(Py)Spark weighted average taking account of missing values

Is there a canonical way to compute a weighted average in PySpark that ignores missing values in the denominator sum?

Take the following example:

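```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()

# create data
data2 = [(1, 1, 1, 1),
         (1, None, 1, 2),
         (2, 1, 1, 1),
         (2, 3, 1, 2)]

schema = StructType([
    StructField("group", IntegerType(), True),
    StructField("var1", IntegerType(), True),
    StructField("var2", IntegerType(), True),
    StructField("wght", IntegerType(), True),
])

df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
```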

```
+-----+----+----+----+
|group|var1|var2|wght|
+-----+----+----+----+
|1    |1   |1   |1   |
|1    |null|1   |2   |
|2    |1   |1   |1   |
|2    |3   |1   |2   |
+-----+----+----+----+
```

I can compute the weighted average as documented elsewhere:

```python
(df.groupBy("group").agg(
     (F.sum(col("var1") * col("wght")) / F.sum("wght")).alias("wgtd_var1"),
     (F.sum(col("var2") * col("wght")) / F.sum("wght")).alias("wgtd_var2"))
   .show(truncate=False))
```

```
+-----+------------------+---------+
|group|wgtd_var1         |wgtd_var2|
+-----+------------------+---------+
|1    |0.3333333333333333|1.0      |
|2    |2.3333333333333335|1.0      |
+-----+------------------+---------+
```
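To see where the 0.333 comes from, here is a minimal plain-Python re-computation of var1 for group 1 (not Spark code; None stands in for Spark's null, and the values are copied from the table above). The numerator already ignores the null row, because null times a weight is null and Spark's sum() skips nulls, but F.sum("wght") still adds the weight 2 from that row:

```python
# Group 1 rows as (var1, wght); None stands in for Spark's null.
rows = [(1, 1), (None, 2)]

# Numerator: var1 * wght is null for the second row, and Spark's sum() skips nulls.
num = sum(v * w for v, w in rows if v is not None)        # 1 * 1 = 1

# Naive denominator: F.sum("wght") adds every weight, including the null row's.
den_naive = sum(w for _, w in rows)                       # 1 + 2 = 3

# Corrected denominator: only the weights of rows where var1 is present.
den_fixed = sum(w for v, w in rows if v is not None)      # 1

naive = num / den_naive   # matches the 0.333... in the first output
fixed = num / den_fixed   # the value the question expects
print(naive, fixed)
```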

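But the problem is that for group 1 the weighted average of var1 should be 1, because the second observation (where var1 is null) should not be used. I can get the right result by building a separate weight column for each variable that is null wherever that variable is null: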

```python
# get new weights: null wherever the corresponding variable is null
df = (df.withColumn("wghtvar1", F.when(col("var1").isNull(), None)
                                 .otherwise(col("wght")))
        .withColumn("wghtvar2", F.when(col("var2").isNull(), None)
                                 .otherwise(col("wght"))))

# compute correct weighted average
(df.groupBy("group").agg(
     (F.sum(col("var1") * col("wghtvar1")) / F.sum("wghtvar1")).alias("wgtd_var1"),
     (F.sum(col("var2") * col("wghtvar2")) / F.sum("wghtvar2")).alias("wgtd_var2"))
   .show(truncate=False))
```

```
+-----+------------------+---------+
|group|wgtd_var1         |wgtd_var2|
+-----+------------------+---------+
|1    |1.0               |1.0      |
|2    |2.3333333333333335|1.0      |
+-----+------------------+---------+
```
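Written out, the corrected aggregate computes, for each group g and each variable x with weights w, a weighted mean in which both sums run only over the rows where x is non-null:

$$
\bar{x}_g = \frac{\sum_{i \in g,\; x_i \neq \mathrm{null}} w_i \, x_i}{\sum_{i \in g,\; x_i \neq \mathrm{null}} w_i}
$$

For var1 in group 1 this gives (1 \cdot 1) / 1 = 1, whereas the first query divided the same numerator by 1 + 2 = 3.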

Is there a canonical way to do this?



Solution 1:[1]

It is not much different, but at least it saves you from creating a new weight column per variable.

Use conditional aggregation:

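```python
# when() without otherwise() returns null, and sum() skips nulls,
# so rows where var1 is null drop out of both sums.
# Assign to a new name so df keeps its original columns.
result = (df.groupby('group')
          .agg(
              (F.sum(F.when(F.col('var1').isNotNull(), F.col('var1') * F.col('wght')))
               /
               F.sum(F.when(F.col('var1').isNotNull(), F.col('wght')))
              ).alias('wgtd_var1')
          ))
```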
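The trick relies on two null rules: F.when(...) without .otherwise(...) evaluates to null when its condition is false, and Spark's sum() ignores nulls. The following plain-Python sketch mimics those semantics for var1 in group 1; sql_when, sql_mul and sql_sum are hypothetical helpers written only for illustration, not PySpark functions:

```python
def sql_when(cond, value):
    """Like F.when(cond, value) with no .otherwise(): null (None) if cond is false."""
    return value if cond else None

def sql_mul(a, b):
    """Like Spark's *: the product is null if either operand is null."""
    return None if a is None or b is None else a * b

def sql_sum(values):
    """Like Spark's sum(): skip nulls; null if there is nothing left to add."""
    present = [v for v in values if v is not None]
    return sum(present) if present else None

# Group 1 rows as (var1, wght); None stands in for null.
rows = [(1, 1), (None, 2)]

numerator = sql_sum(sql_when(v is not None, sql_mul(v, w)) for v, w in rows)
denominator = sql_sum(sql_when(v is not None, w) for v, w in rows)
print(numerator / denominator)
```

A consequence of the same rules: if a variable is null in every row of a group, both sums are null and the average for that group comes out as null rather than raising an error.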

To apply this to multiple variables, you can use a list comprehension:

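```python
# One conditional weighted average per variable, unpacked into agg().
result = (df.groupby('group')
          .agg(*[
              (F.sum(F.when(F.col(x).isNotNull(), F.col(x) * F.col('wght')))
               /
               F.sum(F.when(F.col(x).isNotNull(), F.col('wght')))
              ).alias(f'wgtd_{x}')
              for x in ['var1', 'var2']
          ]))
result.show(truncate=False)
```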
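To check the Spark output on small data, a plain-Python reference implementation of the same rule for several variables at once may be handy. This is a sketch, not part of the answer: weighted_averages is a hypothetical helper, and it assumes rows laid out as (group, var1, var2, wght) as in the question's data2:

```python
from collections import defaultdict

# Rows from the question: (group, var1, var2, wght); None stands in for null.
data2 = [(1, 1, 1, 1),
         (1, None, 1, 2),
         (2, 1, 1, 1),
         (2, 3, 1, 2)]

def weighted_averages(rows, cols, wght_idx=3):
    """Per group: sum(x * w) / sum(w), both over rows where x and w are non-null.

    cols maps a column name to its position in each row tuple (hypothetical
    layout); the group key is assumed to be at position 0.
    """
    num, den = defaultdict(int), defaultdict(int)
    groups = dict.fromkeys(row[0] for row in rows)  # keeps first-seen order
    for row in rows:
        g, w = row[0], row[wght_idx]
        for name, idx in cols.items():
            x = row[idx]
            # Same filter as the conditional aggregation: skip null values or weights.
            if x is not None and w is not None:
                num[(g, name)] += x * w
                den[(g, name)] += w
    # None when a group has no usable rows for a column (or zero total weight).
    return {
        g: {f"wgtd_{name}": (num[(g, name)] / den[(g, name)]
                             if den.get((g, name)) else None)
            for name in cols}
        for g in groups
    }

result = weighted_averages(data2, {"var1": 1, "var2": 2})
print(result)
```

On the example data this reproduces the corrected table: 1.0 and 1.0 for group 1, and 7/3 and 1.0 for group 2.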

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Emma