GroupBy operator for Kotlin Flow

I am trying to switch from RxJava to Kotlin Flow. Flow is really impressive. But is there any operator similar to RxJava's groupBy in Kotlin Flow right now?



Solution 1:[1]

As of Kotlin Coroutines 1.3, the kotlinx.coroutines library doesn't seem to provide this operator. However, since Flow is designed so that all operators are extension functions, there is no fundamental difference between the library providing it and you writing your own.
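To illustrate the point, here is a minimal sketch of a hypothetical custom operator, everyNth (not part of kotlinx.coroutines), which keeps every n-th element. It uses nothing but the public flow { } builder and collect, and it chains with built-in operators such as map and toList the same way a library operator would. The helper everyNthDemo is also hypothetical.

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Hypothetical user-defined operator (not part of kotlinx.coroutines):
// keeps the elements at index 0, n, 2n, ... of the upstream flow.
fun <T> Flow<T>.everyNth(n: Int): Flow<T> = flow {
    require(n > 0) { "n must be positive" }
    var index = 0
    collect { value ->
        if (index % n == 0) emit(value)
        index++
    }
}

// Chains with built-in operators exactly like map or filter would.
fun everyNthDemo(): List<Int> = runBlocking {
    (1..10).asFlow()
        .map { it * 10 }
        .everyNth(3)
        .toList()
}
```

everyNthDemo() returns [10, 40, 70, 100].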

With that in mind, here are some of my ideas on how to approach it.

1. Collect Each Group to a List

If you just need a list of all items for each key, use this simple implementation that emits pairs of (K, List<T>):

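import kotlinx.coroutines.flow.*

// Buffers every item, then emits one (key, items) pair per key after the upstream completes.
fun <T, K> Flow<T>.groupToList(getKey: (T) -> K): Flow<Pair<K, List<T>>> = flow {
    val storage = mutableMapOf<K, MutableList<T>>()
    collect { t -> storage.getOrPut(getKey(t)) { mutableListOf() } += t }
    storage.forEach { (k, ts) -> emit(k to ts) }
}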

For this example:

suspend fun main() {
    val input = 1..10
    input.asFlow()
            .groupToList { it % 2 }
            .collect { println(it) }
}

it prints

(1, [1, 3, 5, 7, 9])
(0, [2, 4, 6, 8, 10])
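Two properties of this implementation follow directly from the code. mutableMapOf returns a LinkedHashMap, so groups come out in the order in which their keys first appear (which is why key 1 is printed before key 0 above). And nothing is emitted until the upstream completes, so groupToList is not suitable for infinite flows. The sketch below repeats groupToList so it compiles on its own and uses a hypothetical helper, groupOrderDemo, to check the first property on an input whose first element is even.

```kotlin
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// groupToList as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupToList(getKey: (T) -> K): Flow<Pair<K, List<T>>> = flow {
    val storage = mutableMapOf<K, MutableList<T>>() // LinkedHashMap: keeps key insertion order
    collect { t -> storage.getOrPut(getKey(t)) { mutableListOf() } += t }
    storage.forEach { (k, ts) -> emit(k to ts) }
}

// Hypothetical helper: key 0 is seen first here, so its group is emitted first.
fun groupOrderDemo(): List<Pair<Int, List<Int>>> = runBlocking {
    flowOf(4, 1, 2, 3, 6).groupToList { it % 2 }.toList()
}
```

groupOrderDemo() returns [(0, [4, 2, 6]), (1, [1, 3])].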

2.a Emit a Flow for Each Group

If you need the full RxJava semantics where you transform the input flow into many output flows (one per distinct key), things get more involved.

Whenever you see a new key in the input, you must emit a new inner flow to the downstream and then, asynchronously, keep pushing more data into it whenever you encounter the same key again.

Here's an implementation that does this:

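import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*

fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    // One buffered channel per distinct key seen so far.
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                // First time we see this key: emit its inner flow downstream.
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t) // suspends while this key's 32-element buffer is full
        }
    } finally {
        // Completes every inner flow when the upstream finishes, fails, or is cancelled.
        storage.values.forEach { chan -> chan.close() }
    }
}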

It sets up a Channel for each key and exposes the channel to the downstream as a flow.
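One consequence of exposing the channels through consumeAsFlow(): according to the kotlinx.coroutines documentation, such a flow can be collected only once, and a second collection throws IllegalStateException. Every inner flow emitted by groupBy should therefore be collected exactly once. The following sketch reproduces the single-key case (the helper name channelAsFlowDemo is hypothetical): a producer keeps sending after the flow object has already been handed out, the first collection sees every item, and the second one is rejected.

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

// Mirrors what groupBy does for one key: a Channel exposed as a Flow,
// filled asynchronously after the Flow object already exists.
fun channelAsFlowDemo(): Pair<List<Int>, Boolean> = runBlocking {
    val channel = Channel<Int>(32)
    val inner = channel.consumeAsFlow()
    // Producer: pushes items later, then closes the channel to complete the flow.
    launch {
        for (i in 1..5) channel.send(i)
        channel.close()
    }
    val firstCollection = inner.toList()
    // A second collection of the same consumeAsFlow() flow is rejected.
    val secondRejected = try {
        inner.toList()
        false
    } catch (e: IllegalStateException) {
        true
    }
    firstCollection to secondRejected
}
```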

2.b Concurrently Collect and Reduce Grouped Flows

Since groupBy keeps emitting the data to the inner flows after emitting the flows themselves to the downstream, you have to be very careful with how you collect them.

You must collect all the inner flows concurrently, with no upper limit on the level of concurrency. Otherwise, the channels of the flows that are queued for later collection eventually fill up (each buffers 32 items) and block the sender, and you end up with a deadlock.
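To make the failure mode concrete, the following sketch collects the groups one at a time, i.e. with a concurrency limit of one. It repeats groupBy so it compiles on its own; the helper name sequentialGroupSums and the 500 ms timeout are arbitrary choices for the demonstration. Because map's transform runs inside groupBy's emit call, the producer stays suspended while the first group waits for items that only that same producer could send, so in this sequential case the deadlock happens immediately, before any buffer fills up.

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull

// groupBy as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

// Anti-pattern: each group is fully collected inside map, one group at a time.
// Returns null when the (arbitrarily chosen) 500 ms timeout expires.
fun sequentialGroupSums(): List<Pair<Int, Int>>? = runBlocking {
    withTimeoutOrNull(500) {
        (1..100).asFlow()
            .groupBy { it % 2 }
            .map { (key, group) -> key to group.toList().sum() }
            .toList()
    }
}
```

sequentialGroupSums() returns null: the timeout is the only thing that ends the collection.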

Here is a function that does this properly:

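import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.*

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
        reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKey
                // Start reducing each group as soon as it arrives.
                .map { (key, group) -> key to async { group.reduce() } }
                // Drain the outer flow, launching one coroutine per key.
                .toList()
                // Emit the results in the order the keys first appeared.
                .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}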

The map stage launches an async coroutine for each inner flow it receives, and that coroutine reduces the inner flow to its final result.

toList() is a terminal operation that collects the entire upstream flow, launching all the async coroutines in the process. The coroutines start consuming the inner flows even while we're still collecting the main flow. This is essential to prevent a deadlock.

Finally, once all the coroutines have been launched, the forEach loop awaits each result in the order in which its key first appeared and emits it downstream.
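As a usage sketch (the helper countsPerKey is hypothetical; groupBy and reducePerKey are repeated so the block compiles on its own), the following counts the integers 1..1000 per remainder mod 3. Every group receives more than 32 items, so the producer repeatedly suspends until the async coroutines drain the channels, and the results come out in first-appearance order of the keys: 1, 2, then 0. Under runBlocking everything runs on a single thread, so what prevents the deadlock here is coroutine interleaving, not parallelism.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// groupBy and reducePerKey as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKey
            .map { (key, group) -> key to async { group.reduce() } }
            .toList()
            .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}

// Hypothetical helper: number of items per remainder mod 3.
fun countsPerKey(): List<Pair<Int, Int>> = runBlocking {
    (1..1000).asFlow()
        .groupBy { it % 3 }
        .reducePerKey { count() }
        .toList()
}
```

countsPerKey() returns [(1, 334), (2, 333), (0, 333)].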

You can implement almost the same behavior in terms of flatMapMerge:

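import kotlinx.coroutines.flow.*

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
        reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flatMapMerge(Int.MAX_VALUE) { (key, group) ->
    // Each group gets its own inner flow; Int.MAX_VALUE removes the concurrency limit.
    flow { emit(key to group.reduce()) }
}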

The difference is in the ordering: the first implementation emits results in the order in which keys first appear in the input, whereas this one emits each result as soon as its group has been reduced. Both perform similarly.
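A sketch of the flatMapMerge variant under the hypothetical name reducePerKeyMerged, so that it could sit next to the first version. flatMapMerge is not a stable API, so the sketch opts in to both FlowPreview and ExperimentalCoroutinesApi to cover the marker used by whichever kotlinx.coroutines version is on the classpath (an assumption; a single opt-in may suffice). Since the emission order is not guaranteed, the check compares the results as a map.

```kotlin
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// groupBy as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

// Hypothetical name for the flatMapMerge-based variant.
@OptIn(FlowPreview::class, ExperimentalCoroutinesApi::class)
fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKeyMerged(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flatMapMerge(Int.MAX_VALUE) { (key, group) ->
    flow { emit(key to group.reduce()) }
}

// Hypothetical helper: counts per remainder mod 3; pair order is not guaranteed.
fun mergedCounts(): List<Pair<Int, Int>> = runBlocking {
    (1..1000).asFlow()
        .groupBy { it % 3 }
        .reducePerKeyMerged { count() }
        .toList()
}
```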

3. Example

This example groups and sums 40 million integers:

suspend fun main() {
    val input = 1..40_000_000
    input.asFlow()
            .groupBy { it % 100 }
            .reducePerKey { sum { it.toLong() } }
            .collect { println(it) }
}

suspend fun <T> Flow<T>.sum(toLong: suspend (T) -> Long): Long {
    var sum = 0L
    collect { sum += toLong(it) }
    return sum
}

I can successfully run this with -Xmx64m. On my 4-core laptop I'm getting about 4 million items per second.
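The timing can't be reproduced here, but the arithmetic of the grouping can. A scaled-down sketch (the helper smallGroupSums is hypothetical; groupBy, reducePerKey and sum are repeated so it compiles on its own) groups 1..1000 by it % 10. Each remainder class holds 100 numbers: for key k in 1..9 they are k, k + 10, ..., k + 990, which sum to 100k + 10 × (0 + 1 + ... + 99) = 100k + 49,500, while key 0 holds 10, 20, ..., 1000, which sum to 50,500. All ten sums add up to 1000 × 1001 / 2 = 500,500.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// groupBy, reducePerKey and sum as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKey
            .map { (key, group) -> key to async { group.reduce() } }
            .toList()
            .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}

suspend fun <T> Flow<T>.sum(toLong: suspend (T) -> Long): Long {
    var sum = 0L
    collect { sum += toLong(it) }
    return sum
}

// Hypothetical helper: the 40-million-item example shrunk to 1..1000 and 10 keys.
fun smallGroupSums(): Map<Int, Long> = runBlocking {
    (1..1000).asFlow()
        .groupBy { it % 10 }
        .reducePerKey { sum { it.toLong() } }
        .toList()
        .toMap()
}
```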

It is simple to redefine the first solution in terms of the new one like this:

fun <T, K> Flow<T>.groupToList(getKey: (T) -> K): Flow<Pair<K, List<T>>> =
        groupBy(getKey).reducePerKey { toList() }
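A sketch to check that the two definitions agree, with the original implementation renamed to the hypothetical groupToListEager and the composed one to groupToListViaGroupBy so that both can coexist in one file. On the earlier 1..10 example both should produce (1, [1, 3, 5, 7, 9]) followed by (0, [2, 4, 6, 8, 10]); the composed version collects each group in its own coroutine through a channel, while the original fills a single map sequentially.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// groupBy and reducePerKey as defined above, repeated so this sketch compiles on its own.
fun <T, K> Flow<T>.groupBy(getKey: (T) -> K): Flow<Pair<K, Flow<T>>> = flow {
    val storage = mutableMapOf<K, SendChannel<T>>()
    try {
        collect { t ->
            val key = getKey(t)
            storage.getOrPut(key) {
                Channel<T>(32).also { emit(key to it.consumeAsFlow()) }
            }.send(t)
        }
    } finally {
        storage.values.forEach { chan -> chan.close() }
    }
}

fun <T, K, R> Flow<Pair<K, Flow<T>>>.reducePerKey(
    reduce: suspend Flow<T>.() -> R
): Flow<Pair<K, R>> = flow {
    coroutineScope {
        this@reducePerKey
            .map { (key, group) -> key to async { group.reduce() } }
            .toList()
            .forEach { (key, deferred) -> emit(key to deferred.await()) }
    }
}

// Hypothetical name for the original, map-based groupToList.
fun <T, K> Flow<T>.groupToListEager(getKey: (T) -> K): Flow<Pair<K, List<T>>> = flow {
    val storage = mutableMapOf<K, MutableList<T>>()
    collect { t -> storage.getOrPut(getKey(t)) { mutableListOf() } += t }
    storage.forEach { (k, ts) -> emit(k to ts) }
}

// Hypothetical name for the redefinition in terms of groupBy and reducePerKey.
fun <T, K> Flow<T>.groupToListViaGroupBy(getKey: (T) -> K): Flow<Pair<K, List<T>>> =
    groupBy(getKey).reducePerKey { toList() }

// Runs both definitions on the same cold input flow.
fun bothGroupings(): Pair<List<Pair<Int, List<Int>>>, List<Pair<Int, List<Int>>>> = runBlocking {
    val input = (1..10).asFlow()
    input.groupToListEager { it % 2 }.toList() to input.groupToListViaGroupBy { it % 2 }.toList()
}
```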

Solution 2:[2]

Not yet, but you can have a look at this library: https://github.com/akarnokd/kotlin-flow-extensions

Solution 3:[3]

In my project, I achieved this without blocking by using Project Reactor's Flux.groupBy: https://projectreactor.io/docs/core/release/api/reactor/core/publisher/Flux.html#groupBy-java.util.function.Function-

I did this while converting the results obtained with Flux into a Flow.
This may not be an appropriate answer for the situation in the question, but I'm sharing it as an example.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1
[2] Solution 2: Dominic Fischer
[3] Solution 3