How to find the quantile of a row in a PySpark DataFrame?
I have the following PySpark DataFrame, and I want to compute a percentile row-wise.
value      col_a  col_b     col_c
row_a        5.0    0.0      11.0
row_b     3394.0    0.0    4543.0
row_c   136111.0    0.0  219255.0
row_d        0.0    0.0       0.0
row_e        0.0    0.0       0.0
row_f       42.0    0.0      54.0
I want to add a new column to the DataFrame, as shown below.
value      col_a  col_b     col_c      25%
row_a        5.0    0.0      11.0      2.5
row_b     3394.0    0.0    4543.0   1697.0
row_c   136111.0    0.0  219255.0  68055.5
row_d        0.0    0.0       0.0      0.0
row_e        0.0    0.0       0.0      0.0
row_f       42.0    0.0      54.0     21.0
In pandas, I could use this:
df['25%'] = df.quantile(0.25, axis=1)
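For context, here is a minimal, self-contained pandas sketch of that approach (the values are taken from the tables above):

import pandas as pd

# Rebuild the example data from the question
df = pd.DataFrame(
    {'col_a': [5.0, 3394.0, 136111.0, 0.0, 0.0, 42.0],
     'col_b': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
     'col_c': [11.0, 4543.0, 219255.0, 0.0, 0.0, 54.0]},
    index=['row_a', 'row_b', 'row_c', 'row_d', 'row_e', 'row_f'])

# Row-wise 25th percentile (linear interpolation, pandas' default)
df['25%'] = df.quantile(0.25, axis=1)
print(df)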
Solution 1:[1]
You could do it using Spark's pandas_udf, i.e. a vectorized (well-performing) UDF.
from pyspark.sql import functions as F
import pandas as pd

# Vectorized UDF: Spark passes the struct column in as a pd.DataFrame,
# so pandas' row-wise quantile can be applied directly.
@F.pandas_udf('double')
def quantile(q: pd.Series, s: pd.DataFrame) -> pd.Series:
    return s.quantile(q[0], axis=1)
Test:
df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)
df = df.withColumn('25%', quantile(F.lit(0.25), F.struct('col_a', 'col_b', 'col_c')))
df.show()
df.explain()
# +-----+--------+-----+--------+-------+
# |value| col_a|col_b| col_c| 25%|
# +-----+--------+-----+--------+-------+
# |row_a| 5.0| 0.0| 11.0| 2.5|
# |row_b| 3394.0| 0.0| 4543.0| 1697.0|
# |row_c|136111.0| 0.0|219255.0|68055.5|
# |row_d| 0.0| 0.0| 0.0| 0.0|
# |row_e| 0.0| 0.0| 0.0| 0.0|
# |row_f| 42.0| 0.0| 54.0| 21.0|
# +-----+--------+-----+--------+-------+
#
# == Physical Plan ==
# *(2) Project [value#526, col_a#527, col_b#528, col_c#529, pythonUDF0#563 AS 25%#535]
# +- ArrowEvalPython [quantile(0.25, struct(col_a, col_a#527, col_b, col_b#528, col_c, col_c#529))], [pythonUDF0#563], 200
# +- *(1) Scan ExistingRDD[value#526,col_a#527,col_b#528,col_c#529]
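The same UDF works for any quantile; only the literal changes. For example, the row-wise median (just an illustration, reusing the UDF defined above):

df = df.withColumn('50%', quantile(F.lit(0.5), F.struct('col_a', 'col_b', 'col_c')))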
Or build the quantile function from scratch, using only native Spark functions (no UDF of any kind).
from pyspark.sql import functions as F
import math

def quantile(q, *cols):
    if q < 0 or q > 1:
        raise ValueError("Parameter q should be 0 <= q <= 1")
    if not cols:
        raise ValueError("List of columns should be provided")

    # Position of the quantile within the sorted values
    idx = (len(cols) - 1) * q
    i = math.floor(idx)
    j = math.ceil(idx)
    fraction = idx - i

    # Sort the row's values and interpolate linearly between the
    # two neighbouring elements (pandas' default 'linear' method)
    arr = F.array_sort(F.array(*cols))
    return arr.getItem(i) + (arr.getItem(j) - arr.getItem(i)) * fraction
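# Worked example for row_b at q = 0.25: the sorted values are
# [0.0, 3394.0, 4543.0]; idx = (3 - 1) * 0.25 = 0.5, so i = 0, j = 1,
# fraction = 0.5, and the result is 0.0 + (3394.0 - 0.0) * 0.5 = 1697.0.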
df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)
df = df.withColumn('0.25%', quantile(0.25, 'col_a', 'col_b', 'col_c'))
df.show()
df.explain()
# +-----+--------+-----+--------+-------+
# |value| col_a|col_b| col_c| 0.25%|
# +-----+--------+-----+--------+-------+
# |row_a| 5.0| 0.0| 11.0| 2.5|
# |row_b| 3394.0| 0.0| 4543.0| 1697.0|
# |row_c|136111.0| 0.0|219255.0|68055.5|
# |row_d| 0.0| 0.0| 0.0| 0.0|
# |row_e| 0.0| 0.0| 0.0| 0.0|
# |row_f| 42.0| 0.0| 54.0| 21.0|
# +-----+--------+-----+--------+-------+
# == Physical Plan ==
# Project [value#564, col_a#565, col_b#566, col_c#567, (array_sort(array(col_a#565, col_b#566, col_c#567), lambdafunction(if ((isnull(lambda left#573) AND isnull(lambda right#574))) 0 else if (isnull(lambda left#573)) 1 else if (isnull(lambda right#574)) -1 else if ((lambda left#573 < lambda right#574)) -1 else if ((lambda left#573 > lambda right#574)) 1 else 0, lambda left#573, lambda right#574, false))[0] + ((array_sort(array(col_a#565, col_b#566, col_c#567), lambdafunction(if ((isnull(lambda left#575) AND isnull(lambda right#576))) 0 else if (isnull(lambda left#575)) 1 else if (isnull(lambda right#576)) -1 else if ((lambda left#575 < lambda right#576)) -1 else if ((lambda left#575 > lambda right#576)) 1 else 0, lambda left#575, lambda right#576, false))[1] - array_sort(array(col_a#565, col_b#566, col_c#567), lambdafunction(if ((isnull(lambda left#577) AND isnull(lambda right#578))) 0 else if (isnull(lambda left#577)) 1 else if (isnull(lambda right#578)) -1 else if ((lambda left#577 < lambda right#578)) -1 else if ((lambda left#577 > lambda right#578)) 1 else 0, lambda left#577, lambda right#578, false))[0]) * 0.5)) AS 0.25%#572]
# +- *(1) Scan ExistingRDD[value#564,col_a#565,col_b#566,col_c#567]
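As with the UDF version, other quantiles only need a different q; for example, the row-wise median, using just the function defined above:

df = df.withColumn('50%', quantile(0.5, 'col_a', 'col_b', 'col_c'))

One caveat: F.array_sort places nulls at the end of the array, so on rows containing nulls this version will not match pandas' quantile, which skips NaNs.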
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |