PySpark Window function on entire data frame

Consider a PySpark data frame. I would like to summarize the entire data frame, per column, and append the result to every row.

+-----+----------+-----------+
|index|      col1|       col2|
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|
+-----+----------+-----------+

Expected result

+-----+----------+-----------+--------+---------+--------+---------+
|index|      col1|       col2|col1_min|col1_mean|col2_min|col2_mean|
+-----+----------+-----------+--------+---------+--------+---------+
|  0.0|0.58734024|0.085703015|      -5|      2.3|      -2|      1.4|
|  1.0|0.67304325| 0.17850411|      -5|      2.3|      -2|      1.4|
+-----+----------+-----------+--------+---------+--------+---------+

To my knowledge, I'll need a Window function with the whole data frame as the window, so that the result is kept on every row (instead of, for example, computing the stats separately and then joining back to replicate them for each row).

My questions are:

  1. How do I write a Window without any partitionBy or orderBy?

I know the standard Window with partitionBy and orderBy, but not one that takes the whole data frame as a single partition:

w = Window.partitionBy("col1", "col2").orderBy(desc("col1"))
df = df.withColumn("col1_mean", mean("col1").over(w))

How would I write a Window with everything as one partition?

  2. Is there a way to write this dynamically for all columns?

Let's say I have 500 columns; it does not look great to write this out repeatedly:

df = df.withColumn("col1_mean", mean("col1").over(w)).withColumn("col1_min", min("col1").over(w)).withColumn("col2_mean", mean("col2").over(w))...

Let's assume I want multiple stats for each column, so each colx will spawn colx_min, colx_max, colx_mean.



Solution 1:[1]

Instead of using a window, you can achieve the same result with a custom aggregation combined with a cross join:

import pyspark.sql.functions as F
from pyspark.sql.functions import broadcast
from itertools import chain

df = spark.createDataFrame([
  [1, 2.3, 1],
  [2, 5.3, 2],
  [3, 2.1, 4],
  [4, 1.5, 5]
], ["index", "col1", "col2"])

agg_cols = [(F.min(c).alias("min_" + c),
             F.max(c).alias("max_" + c),
             F.mean(c).alias("mean_" + c))
            for c in df.columns if c.startswith('col')]

stats_df = df.agg(*list(chain(*agg_cols)))

# there is no performance impact from crossJoin since we have only one row on the right table which we broadcast (most likely Spark will broadcast it anyway)
df.crossJoin(broadcast(stats_df)).show() 

# +-----+----+----+--------+--------+---------+--------+--------+---------+
# |index|col1|col2|min_col1|max_col1|mean_col1|min_col2|max_col2|mean_col2|
# +-----+----+----+--------+--------+---------+--------+--------+---------+
# |    1| 2.3|   1|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    2| 5.3|   2|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    3| 2.1|   4|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    4| 1.5|   5|     1.5|     5.3|      2.8|       1|       5|      3.0|
# +-----+----+----+--------+--------+---------+--------+--------+---------+

Note 1: By using broadcast we avoid shuffling, since the broadcasted df will be sent to all the executors.

Note 2: With chain(*agg_cols) we flatten the list of tuples created in the previous step.
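
For illustration, a minimal standalone sketch (plain Python, no Spark needed) of what that flattening does:

from itertools import chain

# each tuple plays the role of the (min, max, mean) aggregates built per input column
pairs = [("min_col1", "max_col1", "mean_col1"), ("min_col2", "max_col2", "mean_col2")]
print(list(chain(*pairs)))
# ['min_col1', 'max_col1', 'mean_col1', 'min_col2', 'max_col2', 'mean_col2']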

UPDATE:

Here is the execution plan for the above program:

== Physical Plan ==
*(3) BroadcastNestedLoopJoin BuildRight, Cross
:- *(3) Scan ExistingRDD[index#196L,col1#197,col2#198L]
+- BroadcastExchange IdentityBroadcastMode, [id=#274]
   +- *(2) HashAggregate(keys=[], functions=[finalmerge_min(merge min#233) AS min(col1#197)#202, finalmerge_max(merge max#235) AS max(col1#197)#204, finalmerge_avg(merge sum#238, count#239L) AS avg(col1#197)#206, finalmerge_min(merge min#241L) AS min(col2#198L)#208L, finalmerge_max(merge max#243L) AS max(col2#198L)#210L, finalmerge_avg(merge sum#246, count#247L) AS avg(col2#198L)#212])
      +- Exchange SinglePartition, [id=#270]
         +- *(1) HashAggregate(keys=[], functions=[partial_min(col1#197) AS min#233, partial_max(col1#197) AS max#235, partial_avg(col1#197) AS (sum#238, count#239L), partial_min(col2#198L) AS min#241L, partial_max(col2#198L) AS max#243L, partial_avg(col2#198L) AS (sum#246, count#247L)])
            +- *(1) Project [col1#197, col2#198L]
               +- *(1) Scan ExistingRDD[index#196L,col1#197,col2#198L]

Here we see a BroadcastExchange of a SinglePartition, which broadcasts one single row, since stats_df fits into a single partition. Therefore the data being shuffled here is only one row (the minimum possible).
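
The plan above can be reproduced by calling explain() on the final query, a standard DataFrame method:

df.crossJoin(broadcast(stats_df)).explain()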

Solution 2:[2]

We can also use window functions without orderBy or partitionBy clauses by calling over() with no arguments, e.g. min("<col_name>").over().

Example:

//sample data
val df=Seq((1,2,3),(4,5,6)).toDF("i","j","k")

val df1=df.columns.foldLeft(df)((df, c) => {
  df.withColumn(s"${c}_min",min(col(s"${c}")).over()).
  withColumn(s"${c}_max",max(col(s"${c}")).over()).
  withColumn(s"${c}_mean",mean(col(s"${c}")).over())
})

df1.show()
//+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
//|  i|  j|  k|i_min|i_max|i_mean|j_min|j_max|j_mean|k_min|k_max|k_mean|
//+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
//|  1|  2|  3|    1|    4|   2.5|    2|    5|   3.5|    3|    6|   4.5|
//|  4|  5|  6|    1|    4|   2.5|    2|    5|   3.5|    3|    6|   4.5|
//+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
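
Since the question is about PySpark, here is a rough sketch of the same foldLeft pattern ported to PySpark (my translation, not part of the original answer). PySpark's over() requires a window spec, so an empty Window.partitionBy() is used, as described in Solution 3 below:

import pyspark.sql.functions as F
from functools import reduce
from pyspark.sql import Window

# empty partitionBy: the whole data frame is treated as a single partition
w = Window.partitionBy()

df1 = reduce(
    lambda acc, c: (acc
                    .withColumn(f"{c}_min", F.min(c).over(w))
                    .withColumn(f"{c}_max", F.max(c).over(w))
                    .withColumn(f"{c}_mean", F.mean(c).over(w))),
    df.columns,
    df)

df1.show()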

Solution 3:[3]

After looking around, I realized there is another good solution for PySpark 2.0+, where over requires a window argument:

from pyspark.sql import Window
from pyspark.sql.functions import col, min

# c holds the name of the column to summarize
df.withColumn(f"{c}_min", min(col(c)).over(Window.partitionBy()))

If you leave partitionBy empty, it will not do any partitioning.
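
Building on this, a minimal sketch that adds all the stat columns dynamically in one select (question 2), assuming the same df and "col" name prefix as in Solution 1:

import pyspark.sql.functions as F
from pyspark.sql import Window

w = Window.partitionBy()  # no partitioning: the window spans the entire data frame

# build colX_min, colX_max, colX_mean for every matching column
stat_cols = [agg(c).over(w).alias(f"{c}_{name}")
             for c in df.columns if c.startswith("col")
             for name, agg in [("min", F.min), ("max", F.max), ("mean", F.mean)]]

df.select("*", *stat_cols).show()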

Solution 4:[4]

I know this was a while ago, but you could also add a dummy variable column that has the same value for each row. Then the partition contains the entire dataframe.

from pyspark.sql import Window
from pyspark.sql.functions import col

df_dummy = df.withColumn("dummy", col("index") * 0)
w = Window.partitionBy("dummy")
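
A quick usage sketch on top of that window (assuming a col1 column as in the question):

from pyspark.sql.functions import mean

df_dummy.withColumn("col1_mean", mean("col1").over(w)).show()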

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution     Source
Solution 1
Solution 2   notNull
Solution 3   Tim
Solution 4   myranda