Sum column values based on a condition using Spark Scala

I have a dataframe like this:

JoinKey  period   Age  Amount
Jk1      2022-02  2    200
Jk1      2022-02  3    450
Jk2      2022-03  5    500
Jk3      2022-03  0    200
Jk2      2022-02  8    300
Jk3      2022-03  9    200
Jk2      2022-04  1    100

Is there a way, using Spark Scala, to create two new columns based on a condition?

Column "Amount (Age <= 3)": sum of Amount for rows with Age <= 3. Column "Amount (Age > 3)": sum of Amount for rows with Age > 3.

I need to group by JoinKey and period, and drop the columns "Age" and "Amount".

The desired output is:

JoinKey  period   Amount (Age <= 3)  Amount (Age > 3)
Jk1      2022-02  650                0
Jk2      2022-03  0                  500
Jk2      2022-02  0                  300
Jk2      2022-04  100                0
Jk3      2022-03  200                200


Solution 1:[1]

Of course you can, but what should the output look like? If you expect something like this:

Age   Amount    A    B
 2     200     500  1450
 3     450     500  1450
 5     500     500  1450
 0     200     500  1450
 8     300     500  1450
 9     200     500  1450
 1     100     500  1450

then you need a windowed aggregate (a sum over a window). Unlike groupBy, a window function does not collapse rows: it attaches the aggregated value to each row. With an empty over(), the window is the whole DataFrame, so every row gets the same totals.

import org.apache.spark.sql.functions._

df
  // empty over(): the window spans the whole DataFrame
  .withColumn(
    "A",
    sum(when(col("Age") lt 3, col("Amount")).otherwise(lit(0))).over()
  )
  .withColumn(
    "B",
    sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0))).over()
  )
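To check what the two windowed sums compute, here is a plain-Scala model using only standard collections (no Spark; the object and value names are hypothetical). It assumes the (Age, Amount) pairs from the sample output above and reproduces the totals A = 500 and B = 1450 that every row receives:

```scala
// Plain-Scala model (no Spark) of an unpartitioned window sum:
// every row receives the same two totals computed over the whole dataset.
object UnpartitionedWindowModel {
  // (Age, Amount) pairs from the sample output above
  val rows: Seq[(Int, Int)] =
    Seq((2, 200), (3, 450), (5, 500), (0, 200), (8, 300), (9, 200), (1, 100))

  // Like sum(when(Age < 3, Amount).otherwise(0)).over()
  val a: Int = rows.map { case (age, amount) => if (age < 3) amount else 0 }.sum
  // Like sum(when(Age >= 3, Amount).otherwise(0)).over()
  val b: Int = rows.map { case (age, amount) => if (age >= 3) amount else 0 }.sum

  // Like withColumn("A", ...).withColumn("B", ...): every row gets both totals
  val withTotals: Seq[(Int, Int, Int, Int)] =
    rows.map { case (age, amount) => (age, amount, a, b) }

  def main(args: Array[String]): Unit =
    withTotals.foreach(println)
}
```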

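Note that an unpartitioned window, as in .over(), does not scale: Spark moves all rows into a single partition (see the warning below). Partition the window with Window.partitionBy(...) wherever your logic allows. Here's the output: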

22/04/25 23:54:01 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---+------+---+----+
|Age|Amount|  A|   B|
+---+------+---+----+
|  2|   200|500|1450|
|  3|   450|500|1450|
|  5|   500|500|1450|
|  0|   200|500|1450|
|  8|   300|500|1450|
|  9|   200|500|1450|
|  1|   100|500|1450|
+---+------+---+----+
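To see what partitioning changes, here is a plain-Scala model (no Spark) of the same idea with the window partitioned by (JoinKey, period), as Window.partitionBy("JoinKey", "period") would do in Spark. It is a sketch under assumptions: the Rec case class is hypothetical, and it uses the question's Age <= 3 / Age > 3 split rather than the < 3 / >= 3 split above. Each row is kept but only receives its own partition's totals:

```scala
// Plain-Scala model (no Spark) of a window partitioned by (JoinKey, period):
// each row receives the conditional sums of its own partition only.
object PartitionedWindowModel {
  // Hypothetical record type mirroring the question's columns
  final case class Rec(joinKey: String, period: String, age: Int, amount: Int)

  val rows: Seq[Rec] = Seq(
    Rec("Jk1", "2022-02", 2, 200),
    Rec("Jk1", "2022-02", 3, 450),
    Rec("Jk2", "2022-03", 5, 500),
    Rec("Jk3", "2022-03", 0, 200),
    Rec("Jk2", "2022-02", 8, 300),
    Rec("Jk3", "2022-03", 9, 200),
    Rec("Jk2", "2022-04", 1, 100)
  )

  // Conditional sums computed separately for each (JoinKey, period) partition
  val totals: Map[(String, String), (Int, Int)] =
    rows.groupBy(r => (r.joinKey, r.period)).map { case (key, part) =>
      val le3 = part.filter(_.age <= 3).map(_.amount).sum
      val gt3 = part.filter(_.age > 3).map(_.amount).sum
      (key, (le3, gt3))
    }

  // Like sum(...).over(Window.partitionBy("JoinKey", "period")):
  // the row count is unchanged; each row gains its own partition's totals
  val withTotals: Seq[(Rec, Int, Int)] =
    rows.map { r =>
      val (le3, gt3) = totals((r.joinKey, r.period))
      (r, le3, gt3)
    }

  def main(args: Array[String]): Unit =
    withTotals.foreach(println)
}
```

Because a window never collapses rows, Jk3 2022-03 still appears twice in withTotals; reducing to one row per key is what groupBy does.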

Update: after your edit to the question, I suggest this:

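import org.apache.spark.sql.functions._

df
  // pre-aggregate Amount per (JoinKey, period, Age < 3) bucket
  .groupBy(col("JoinKey"), col("period"), expr("Age < 3").as("under3"))
  .agg(sum(col("Amount")).as("grouped_age_sum"))
  // unpartitioned windows: these totals span the whole DataFrame, not each group
  .withColumn("A", sum(when(col("under3") === true, col("grouped_age_sum")).otherwise(lit(0))).over())
  .withColumn("B", sum(when(col("under3") === false, col("grouped_age_sum")).otherwise(lit(0))).over())
  .drop("grouped_age_sum", "under3")
  // collapse back to one row per (JoinKey, period); min() yields min(A) and min(B)
  .groupBy(col("JoinKey"), col("period")).min()
  .withColumnRenamed("min(A)", "A")
  .withColumnRenamed("min(B)", "B")
  .show()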

Please note that the same partitioning caveat still applies: because the windows are unpartitioned, every (JoinKey, period) row ends up with the same global totals. I only had a small sample and didn't need performance (partitioning would also have added some logic-dependent boilerplate), but you should partition in practice. Here's the output:

22/04/26 22:36:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+-------+-------+---+----+
|JoinKey| period|  A|   B|
+-------+-------+---+----+
|    JK1|2022-02|500|1450|
|    JK2|2022-03|500|1450|
|    JK3|2022-03|500|1450|
|    JK2|2022-02|500|1450|
|    JK2|2022-04|500|1450|
+-------+-------+---+----+

Update #2:

After the clearer explanation you provided, you just need a groupBy with two simple conditional aggregates:

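import org.apache.spark.sql.functions._

df
    .groupBy(col("JoinKey"), col("period"))
    .agg(
      // each row's Amount lands in exactly one of the two sums; the other gets 0
      sum(when(col("Age") leq 3, col("Amount")).otherwise(lit(0))).as("Amount (Age <= 3)"),
      sum(when(col("Age") gt 3, col("Amount")).otherwise(lit(0))).as("Amount (Age > 3)")
    )
    .show()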

Output:

+-------+-------+-----------------+----------------+
|JoinKey| period|Amount (Age <= 3)|Amount (Age > 3)|
+-------+-------+-----------------+----------------+
|    JK1|2022-02|              650|               0|
|    JK2|2022-03|                0|             500|
|    JK3|2022-03|              200|             200|
|    JK2|2022-02|                0|             300|
|    JK2|2022-04|              100|               0|
+-------+-------+-----------------+----------------+
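The Update #2 aggregation can be modelled the same way in plain Scala (no Spark; all names are hypothetical). The split helper mirrors the two when(...).otherwise(0) expressions, so each row contributes its Amount to exactly one column, and grouping then collapses the rows to one per (JoinKey, period). Under these assumptions it reproduces the desired output from the question:

```scala
// Plain-Scala model (no Spark) of the Update #2 aggregation:
// groupBy(JoinKey, period) + two sum(when(...).otherwise(0)) expressions.
object ConditionalSumModel {
  // (JoinKey, period, Age, Amount) rows from the question
  val rows: Seq[(String, String, Int, Int)] = Seq(
    ("Jk1", "2022-02", 2, 200),
    ("Jk1", "2022-02", 3, 450),
    ("Jk2", "2022-03", 5, 500),
    ("Jk3", "2022-03", 0, 200),
    ("Jk2", "2022-02", 8, 300),
    ("Jk3", "2022-03", 9, 200),
    ("Jk2", "2022-04", 1, 100)
  )

  // Mirrors when(Age <= 3, Amount).otherwise(0) and when(Age > 3, Amount).otherwise(0)
  def split(age: Int, amount: Int): (Int, Int) =
    if (age <= 3) (amount, 0) else (0, amount)

  // groupBy(JoinKey, period): one result per key; Age and Amount are gone
  val result: Map[(String, String), (Int, Int)] =
    rows
      .groupBy { case (key, period, _, _) => (key, period) }
      .map { case (k, group) =>
        val parts = group.map { case (_, _, age, amount) => split(age, amount) }
        (k, (parts.map(_._1).sum, parts.map(_._2).sum))
      }

  def main(args: Array[String]): Unit =
    result.foreach { case ((key, period), (le3, gt3)) =>
      println(s"$key $period $le3 $gt3")
    }
}
```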

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1