Sum a column's values based on a condition using Spark Scala
I have a dataframe like this:
JoinKey | period | Age | Amount |
---|---|---|---|
Jk1 | 2022-02 | 2 | 200 |
Jk1 | 2022-02 | 3 | 450 |
Jk2 | 2022-03 | 5 | 500 |
Jk3 | 2022-03 | 0 | 200 |
Jk2 | 2022-02 | 8 | 300 |
Jk3 | 2022-03 | 9 | 200 |
Jk2 | 2022-04 | 1 | 100 |
Is there any way, using Spark Scala, to create two new columns based on a condition?
Column Amount (Age <= 3) >> sum of Amount where Age <= 3, and column Amount (Age > 3) >> sum of Amount where Age > 3.
I need to group by JoinKey and period and drop the columns "Age" and "Amount".
The desired output will be:
JoinKey | period | Amount (Age <= 3) | Amount (Age > 3) |
---|---|---|---|
Jk1 | 2022-02 | 650 | 0 |
Jk2 | 2022-03 | 0 | 500 |
Jk2 | 2022-02 | 0 | 300 |
Jk2 | 2022-04 | 100 | 0 |
Jk3 | 2022-03 | 200 | 200 |
Solution 1:[1]
Of course you can, but what shape do you expect your data to be in? If you expect your output to be something like:
Age Amount A B
2 200 500 1450
3 450 500 1450
5 500 500 1450
0 200 500 1450
8 300 500 1450
9 200 500 1450
1 100 500 1450
Then this is a windowed aggregate (a sum computed over a window). A window function is used here to place the aggregated value on every row.
import org.apache.spark.sql.functions.{col, lit, sum, when}

df
  .withColumn(
    "A",
    // sum of Amount over rows with Age < 3; over() with no window spec
    // places the same global aggregate on every row
    sum(when(col("Age") lt 3, col("Amount")).otherwise(lit(0)))
      .over()
  )
  .withColumn(
    "B",
    // sum of Amount over rows with Age >= 3
    sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0)))
      .over()
  )
Note that using over without partitioning is not performant at all; partition your window whenever possible (a partitioned sketch follows the output below).
Here's the output:
22/04/25 23:54:01 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---+------+---+----+
|Age|Amount| A| B|
+---+------+---+----+
| 2| 200|500|1450|
| 3| 450|500|1450|
| 5| 500|500|1450|
| 0| 200|500|1450|
| 8| 300|500|1450|
| 9| 200|500|1450|
| 1| 100|500|1450|
+---+------+---+----+
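For illustration, here is a minimal sketch of the same aggregation with a partitioned window. The partition key is an assumption here (JoinKey is just one sensible choice); with it, A and B become per-JoinKey sums rather than global ones, and the single-partition warning goes away:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lit, sum, when}

// Assumed partition key: one window partition per JoinKey, so Spark
// can distribute the work instead of pulling everything onto one node.
val byJoinKey = Window.partitionBy("JoinKey")

df
  .withColumn(
    "A",
    sum(when(col("Age") lt 3, col("Amount")).otherwise(lit(0)))
      .over(byJoinKey)
  )
  .withColumn(
    "B",
    sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0)))
      .over(byJoinKey)
  )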
Update: after you updated the question, I suggest you do this:
import org.apache.spark.sql.functions.{col, expr, lit, sum, when}

df
  .groupBy(
    col("JoinKey"), col("period"), expr("Age < 3").as("under3")
  )
  .agg(sum(col("Amount")).as("grouped_age_sum"))
  // spread the bucketed sums across all rows via a global window
  .withColumn(
    "A",
    sum(when(col("under3") === true, col("grouped_age_sum")).otherwise(lit(0))).over()
  )
  .withColumn(
    "B",
    sum(when(col("under3") === false, col("grouped_age_sum")).otherwise(lit(0))).over()
  )
  .drop("grouped_age_sum", "under3")
  // collapse back to one row per (JoinKey, period); min works because
  // A and B carry the same value on every row
  .groupBy(col("JoinKey"), col("period")).min()
  .withColumnRenamed("min(A)", "A")
  .withColumnRenamed("min(B)", "B")
  .show
Please note that the same caveat about partitioning still applies. I only had a few rows of sample data and didn't really need the performance (partitioning would also have added some logic-dependent boilerplate to the solution), but you should do it; a sketch of what that could look like follows the output. Here's the output:
22/04/26 22:36:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+-------+-------+---+----+
|JoinKey| period| A| B|
+-------+-------+---+----+
| JK1|2022-02|500|1450|
| JK2|2022-03|500|1450|
| JK3|2022-03|500|1450|
| JK2|2022-02|500|1450|
| JK2|2022-04|500|1450|
+-------+-------+---+----+
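For what it's worth, here is a minimal sketch of the partitioning recommended above, assuming you partition the window by the same keys you group by. Note that this changes A and B from the global sums shown above to per-(JoinKey, period) sums (still split at Age < 3, as in the snippet above):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, expr, lit, sum, when}

// Assumption: partition by the grouping keys, so each window only spans
// the (at most two) under3 buckets of a single (JoinKey, period) group.
val byGroup = Window.partitionBy("JoinKey", "period")

df
  .groupBy(col("JoinKey"), col("period"), expr("Age < 3").as("under3"))
  .agg(sum(col("Amount")).as("grouped_age_sum"))
  .withColumn("A", sum(when(col("under3") === true, col("grouped_age_sum")).otherwise(lit(0))).over(byGroup))
  .withColumn("B", sum(when(col("under3") === false, col("grouped_age_sum")).otherwise(lit(0))).over(byGroup))
  .drop("grouped_age_sum", "under3")
  .groupBy(col("JoinKey"), col("period")).min()
  .withColumnRenamed("min(A)", "A")
  .withColumnRenamed("min(B)", "B")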
Update #2:
After the clearer explanation you provided, you just need grouping with two simple aggregate functions:
import org.apache.spark.sql.functions.{col, lit, sum, when}

df
  .groupBy(col("JoinKey"), col("period"))
  .agg(
    // Age is integral, so "lt 4" is equivalent to Age <= 3
    sum(when(col("Age") lt 4, col("Amount")).otherwise(lit(0))).as("Amount (Age <= 3)"),
    sum(when(col("Age") gt 3, col("Amount")).otherwise(lit(0))).as("Amount (Age > 3)")
  )
Output:
+-------+-------+-----------------+----------------+
|JoinKey| period|Amount (Age <= 3)|Amount (Age > 3)|
+-------+-------+-----------------+----------------+
| JK1|2022-02| 650| 0|
| JK2|2022-03| 0| 500|
| JK3|2022-03| 200| 200|
| JK2|2022-02| 0| 300|
| JK2|2022-04| 100| 0|
+-------+-------+-----------------+----------------+
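If you want to reproduce this end to end, here is a minimal sketch for building the sample DataFrame, assuming a SparkSession in scope as spark (as in spark-shell); key casing follows the output tables above:

import spark.implicits._

val df = Seq(
  ("JK1", "2022-02", 2, 200),
  ("JK1", "2022-02", 3, 450),
  ("JK2", "2022-03", 5, 500),
  ("JK3", "2022-03", 0, 200),
  ("JK2", "2022-02", 8, 300),
  ("JK3", "2022-03", 9, 200),
  ("JK2", "2022-04", 1, 100)
).toDF("JoinKey", "period", "Age", "Amount")

Running the groupBy/agg snippet above on this df produces the values in the output table shown (row order may differ).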
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |