How to make a new PySpark df column that's the average of the last n values by day of week?
What I'm trying to do is make a PySpark dataframe with item and date and another column, "3_avg", that's the average of the sales on the last three occurrences of the same day of week before the given date. Said another way, 2022-5-5 is a Thursday, so I want the 3_avg value for that row to be the average sales for that item over the last three Thursdays: 4/28, 4/21, and 4/14.
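The date arithmetic in that example can be checked with Python's standard `datetime` module. This is a small sketch; the helper name `previous_same_weekdays` is hypothetical. Note that Python's `date.weekday()` counts Monday as 0 (so Thursday is 3), whereas Spark's `F.dayofweek` returns 1 for Sunday through 7 for Saturday (so Thursday is 5).

```python
from datetime import date, timedelta
from typing import List


def previous_same_weekdays(d: date, n: int = 3) -> List[date]:
    """Return the n dates on the same weekday as d, stepping back one week at a time."""
    return [d - timedelta(weeks=k) for k in range(1, n + 1)]


thursday = date(2022, 5, 5)
print(thursday.weekday())                # 3, i.e. Thursday (Monday == 0)
print(previous_same_weekdays(thursday))  # 2022-04-28, 2022-04-21, 2022-04-14
```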
I've got this so far, but it just averages the whole column for that day of week. I can't figure out how to make it distinct by item and date and use only the last three. I was trying to get it to work with day_of_week, but I can't connect that to what I need to happen.
df_fcst_dow = (
    df_avg
    .withColumn("day_of_week", F.dayofweek(F.col("trn_dt")))
    .groupBy("item", "date", "day_of_week")
    .agg(
        F.sum(F.col("sales") / 3).alias("3_avg")
    )
)
Solution 1:[1]
You can do this with a window or with a groupBy. Here I'd encourage groupBy, as it will distribute the work better among the worker nodes. We create an array containing the current date and the next two dates, then explode that array. This duplicates each row across all the dates we want, so we can then group it back up and compute an average.
import pyspark.sql.functions as F

df_avg = spark.table("trn_dt")  # the sample data shown below
df_avg.show()
+----+----------+-----+
|item| date|sales|
+----+----------+-----+
| 1|2016-01-03| 16.0|
| 1|2016-01-02| 15.0|
| 1|2016-01-05| 9.0|
| 1|2016-01-04| 10.0|
| 1|2016-01-01| 11.0|
| 1|2016-01-07| 10.0|
| 1|2016-01-06| 7.0|
+----+----------+-----+
(
    df_avg.withColumn(
        "dates",
        F.array(  # building array of dates: the row's date and the next two days
            df_avg["date"],
            F.date_add(df_avg["date"], 1),
            F.date_add(df_avg["date"], 2),
        ),
    )
    .select(
        F.col("item"),
        F.explode("dates").alias("ThreeDayAve"),  # tripling our data
        F.col("sales"),
    )
    .groupBy("item", "ThreeDayAve")
    .agg(F.avg("sales").alias("3_avg"))
    .show()
)
+----+-----------+------------------+
|item|ThreeDayAve| 3_avg|
+----+-----------+------------------+
| 1| 2016-01-05|11.666666666666666|
| 1| 2016-01-04|13.666666666666666|
| 1| 2016-01-07| 8.666666666666666|
| 1| 2016-01-01| 11.0|
| 1| 2016-01-03| 14.0|
| 1| 2016-01-02| 13.0|
| 1| 2016-01-09| 10.0|
| 1| 2016-01-06| 8.666666666666666|
| 1| 2016-01-08| 8.5|
+----+-----------+------------------+
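To see where the numbers in that table come from, the explode-then-group logic can be traced in plain Python without Spark. This is a sketch for checking the arithmetic only; `explode_and_average` is a hypothetical helper, and it handles a single item.

```python
from collections import defaultdict
from datetime import date, timedelta


def explode_and_average(sales_by_date, offsets=(0, 1, 2)):
    """Plain-Python trace of explode + groupBy + avg for a single item."""
    buckets = defaultdict(list)
    for day, sales in sales_by_date.items():
        for off in offsets:  # "explode": one copy of the row per target date
            buckets[day + timedelta(days=off)].append(sales)
    return {day: sum(v) / len(v) for day, v in buckets.items()}  # "groupBy + avg"


sample = {  # item 1 from the sample table
    date(2016, 1, 1): 11.0, date(2016, 1, 2): 15.0, date(2016, 1, 3): 16.0,
    date(2016, 1, 4): 10.0, date(2016, 1, 5): 9.0, date(2016, 1, 6): 7.0,
    date(2016, 1, 7): 10.0,
}
three_day = explode_and_average(sample)
print(three_day[date(2016, 1, 5)])  # (16 + 10 + 9) / 3, the 2016-01-05 row
```

Passing `offsets=(7, 14, 21)` instead gives the same-weekday variant the question asks for, where each date's value is the mean of the sales one, two and three weeks earlier.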
You could likely use a window function for this, but it wouldn't perform as well on large data sets.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Matt Andruff |