How to return null in SUM if some values are null?

I have a case where I may have null values in the column that needs to be summed up in a group.

If I encounter a null in a group, I want the sum of that group to be null. But PySpark by default ignores the null rows and sums up the rest of the non-null values.

For example:

(input DataFrame shown as an image: product/price rows where some prices are null)

import pyspark.sql.functions as f

dataframe = dataframe.groupBy('product') \
                     .agg(f.sum('price'))

Expected output is:

(expected output shown as an image: any group containing a null price has a null sum)

But I am getting:

(actual output shown as an image: the nulls are ignored and the remaining prices are summed per group)

Any suggestions would help!

Thanks in advance!
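
The images above are not available here, so as a point of reference, here is a minimal sketch of an input DataFrame; its rows are an assumption, chosen only to match the per-group results shown in the solutions below (A: 250, B: 200, C: contains a null). The solutions refer to it as df:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Assumed sample data: group C contains a null price.
df = spark.createDataFrame(
    [("A", 100), ("A", 150),
     ("B", 200),
     ("C", 300), ("C", None)],
    "product string, price int",
)

# Default behaviour: the null in group C is ignored, so a plain
# f.sum('price') returns 300 for C instead of the desired null.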



Solution 1:[1]

You can replace nulls with NaN using coalesce; since NaN propagates through sum, any group that contains a null price will sum to NaN:

import pyspark.sql.functions as F

df2 = df.groupBy('product').agg(
    F.sum(
        F.coalesce(F.col('price'), F.lit(float('nan')))
    ).alias('sum(price)')
).orderBy('product')

df2.show()
+-------+----------+
|product|sum(price)|
+-------+----------+
|      A|     250.0|
|      B|     200.0|
|      C|       NaN|
+-------+----------+

If you want to keep the integer type, you can convert the NaN back to null using nanvl (which returns its first argument unless that argument is NaN, in which case it returns the second):

df2 = df.groupBy('product').agg(
    F.nanvl(
        F.sum(
            F.coalesce(F.col('price'), F.lit(float('nan')))
        ),
        F.lit(None)
    ).cast('int').alias('sum(price)')
).orderBy('product')

df2.show()
+-------+----------+
|product|sum(price)|
+-------+----------+
|      A|       250|
|      B|       200|
|      C|      null|
+-------+----------+

Solution 2:[2]

The sum function returns NULL only if all values in the group are null; otherwise, nulls are simply ignored.

You can use conditional aggregation: if count(price) < count(*), the group contains some null values, so return null; otherwise, return sum(price):

from pyspark.sql.functions import col, lit, sum, when, count

df.groupby(col("product")).agg(
    when(count(col("price")) < count("*"), lit(None)).otherwise(sum(col("price"))).alias("sum_price")
  ).show()

#+-------+---------+
#|product|sum_price|
#+-------+---------+
#|      B|      200|
#|      C|     null|
#|      A|      250|
#+-------+---------+
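
As a quick check of the default behaviour described at the start of this solution, here is a small made-up illustration (it assumes a SparkSession named spark, as in the sketch near the top of the question):

from pyspark.sql.functions import col, sum

# Group C has one null among non-null values; group D is entirely null.
demo = spark.createDataFrame(
    [("C", 300), ("C", None), ("D", None), ("D", None)],
    "product string, price int",
)
demo.groupby("product").agg(sum(col("price")).alias("plain_sum")).show()
# C -> 300 (the single null is ignored), D -> null (all values are null).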

Since Spark 3.0, one can also use the any function:

from pyspark.sql.functions import col, lit, sum, when, expr

df.groupby(col("product")).agg(
    when(expr("any(price is null)"), lit(None)).otherwise(sum(col("price"))).alias("sum_price")
  ).show()
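
The same condition can also be written directly in Spark SQL; this is a sketch assuming the DataFrame is registered as a temporary view named df and a SparkSession named spark is available:

df.createOrReplaceTempView("df")

spark.sql("""
    SELECT product,
           CASE WHEN any(price IS NULL) THEN NULL
                ELSE sum(price)
           END AS sum_price
    FROM df
    GROUP BY product
""").show()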

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: mck
Solution 2: