How to use countDistinct using a window function in Spark/Scala?

I need a window function that is partitioned by 2 columns, does a distinct count on the 3rd column, and returns that count as the 4th column. I can do a plain count without any issues, but a distinct count throws an exception -

org.apache.spark.sql.AnalysisException: Distinct window functions are not supported:

Is there any workaround for this?



Solution 1:[1]

A previous answer suggested two possible techniques: approximate counting and size(collect_set(...)). Both have problems.

If you need an exact count, which is the main reason to use COUNT(DISTINCT ...) in big data, approximate counting will not do. Also, the actual error rates of approximate counting can vary quite significantly for small data.

size(collect_set(...)) may cause a substantial slowdown in processing of big data because it uses a mutable Scala HashSet, which is a pretty slow data structure. In addition, you may occasionally get strange results, e.g., if you run the query over an empty dataframe, because size(null) produces the counterintuitive -1. Spark's native distinct counting runs faster for a number of reasons, the main one being that it doesn't have to produce all the counted data in an array.

The typical approach to solving this problem is with a self-join. You group by whatever columns you need, compute the distinct count or any other aggregate function that cannot be used as a window function, and then join back to your original data.
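
For example, a minimal sketch of the self-join approach, assuming a dataframe df with partition columns "i" and "j" and a column "k" whose distinct values should be counted (the same hypothetical layout as the example in Solution 2 below):

import org.apache.spark.sql.functions._

// exact distinct count per (i, j) group
val distinctCounts = df
  .groupBy("i", "j")
  .agg(countDistinct("k").as("cnt"))

// join the per-group count back onto every original row
val result = df.join(distinctCounts, Seq("i", "j"))
result.show()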

Solution 2:[2]

Use approx_count_distinct, or collect_set together with size, over a window to mimic countDistinct functionality.

Example:

df.show()
//+---+---+---+
//|  i|  j|  k|
//+---+---+---+
//|  1|  a|  c|
//|  2|  b|  d|
//|  1|  a|  c|
//|  2|  b|  e|
//+---+---+---+

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val windowSpec = Window.partitionBy("i","j")

df.withColumn("cnt",size(collect_set("k").over(windowSpec))).show()

//or using approx_count_distinct

df.withColumn("cnt",approx_count_distinct("k").over(windowSpec)).show()

//+---+---+---+---+
//|  i|  j|  k|cnt|
//+---+---+---+---+
//|  2|  b|  d|  2|
//|  2|  b|  e|  2|
//|  1|  a|  c|  1| // c repeats in the (1, a) partition, so the distinct count is 1
//|  1|  a|  c|  1|
//+---+---+---+---+

Solution 3:[3]

Expanding on Sim's answer: if you want to do this:

//val newColName: String = ...
//val colToCount: Column = ...
//val aggregatingCols: Seq[Column] = ...

df.withColumn(newColName, countDistinct(colToCount).over(Window.partitionBy(aggregatingCols: _*)))

You must instead do this:

//val aggregatingCols: Seq[String] = ...

df.groupBy(aggregatingCols.head, aggregatingCols.tail:_*)
  .agg(countDistinct(colToCount).as(newColName))
  .select(newColName, aggregatingCols:_*)
  .join(df, usingColumns = aggregatingCols)
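
As a hypothetical instantiation, here is the same pattern applied to the example dataframe from Solution 2 (the column names and the values of aggregatingCols, colToCount and newColName are assumptions):

val aggregatingCols = Seq("i", "j")   // assumed partition columns
val colToCount = col("k")             // assumed column to count distinctly
val newColName = "cnt"

df.groupBy(aggregatingCols.head, aggregatingCols.tail: _*)
  .agg(countDistinct(colToCount).as(newColName))
  .select(newColName, aggregatingCols: _*)
  .join(df, aggregatingCols)
  .show()

The select before the join keeps only the join keys and the new count column, so the join does not duplicate the remaining columns of df.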

Solution 4:[4]

This returns the number of distinct elements in the partition, using the dense_rank() function. When we add the ascending and descending ranks, we always get the total number of distinct elements plus 1:

dense_rank().over(Window.partitionBy("i").orderBy(c.asc)) + dense_rank().over(Window.partitionBy("i").orderBy(c.desc)) - 1
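
As a minimal, self-contained sketch, here is the same trick applied to the example dataframe from Solution 2, counting distinct values of k per (i, j) partition (the column names are assumptions; c above stands for the column being counted):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("i", "j")

df.withColumn(
    "cnt",
    dense_rank().over(w.orderBy(col("k").asc)) +
    dense_rank().over(w.orderBy(col("k").desc)) - 1
  )
  .show()

Note that, unlike countDistinct, this counts a null in the counted column as one additional distinct value, since dense_rank assigns nulls a rank of their own.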

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Sim
Solution 2
Solution 3
Solution 4 Vitamon