Find Top-K Smallest Values So Far in Data Stream
Let's say that I have a data stream where a single data point is retrieved at a time:
```python
import numpy as np

def next_data_point():
    """
    Mock a data stream. Data points will always be a positive float
    """
    # np.random.uniform has no dtype argument; it already returns a float
    return np.random.uniform(0, 1_000_000)
```
I need to be able to update a NumPy array that tracks the top-K smallest values seen so far in this stream (until the user decides it is okay to stop the analysis via some `check_stop_condition()` function). Say we want to capture the 1,000 smallest values from the stream; a naive way to accomplish this might be:
```python
k = 1000
topk = np.full(k, fill_value=np.inf, dtype='float')
while check_stop_condition():
    # Append the new point, re-sort everything, keep the k smallest
    topk[:] = np.sort(np.append(topk, next_data_point()))[:k]
```
This works fine but is quite inefficient and can be slow if repeated millions of times since we are:
- creating a new array every time
- sorting the concatenated array every time
So, I came up with a different approach to address these 2 inefficiencies:
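```python
k = 1000
topk = np.full(k, fill_value=np.inf)
while check_stop_condition():
    data_point = next_data_point()
    # Binary search for the insertion point (topk is kept sorted)
    idx = np.searchsorted(topk, data_point)
    if idx < k:
        # Shift the tail one slot to the right, dropping the largest value
        topk[idx + 1 :] = topk[idx : -1]
        topk[idx] = data_point
```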
Here, I leverage `np.searchsorted()` to replace `np.sort` and to quickly find the insertion point, `idx`, for the next data point. I believe that `np.searchsorted` uses some sort of binary search and assumes that the array is already sorted. Then, we shift the data in `topk` to accommodate and insert the new data point if and only if `idx < k`.
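A tiny worked example of one insertion step may make the shift easier to follow. This is a minimal sketch assuming `k = 4` and a hand-picked sorted array; note that the elements from `idx` onward have to move one slot to the right (`topk[idx + 1:] = topk[idx:-1]`), so the current largest value falls off the end.

```python
import numpy as np

k = 4
topk = np.array([1.0, 3.0, 5.0, np.inf])  # sorted, one unused (inf) slot

data_point = 2.0
idx = np.searchsorted(topk, data_point)  # binary search: insertion point is 1
if idx < k:
    # Move topk[idx:-1] one slot right; the last element is overwritten.
    # NumPy handles the overlapping source/destination slices correctly.
    topk[idx + 1:] = topk[idx:-1]
    topk[idx] = data_point

print(topk)  # [1. 2. 3. 5.]
```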
I haven't seen this done anywhere, so my question is: is there anything that can make this even more efficient, especially the way I am shifting things around inside the `if` statement?
Solution 1:[1]
Sorting a huge array is very expensive, so it is not surprising that the second method is faster. However, the speed of the second method is probably bounded by the slow array copy. The complexity of the first method is O(n k log(k)), while the second method has a complexity of O(n (log(k) + k p)), with n the number of points and p the probability that the branch is taken.
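To give p a concrete value, one can estimate it under an assumption the answer does not state: the stream is i.i.d. with a continuous distribution (no ties), like the uniform mock in the question. The first k points always enter, and afterwards the i-th point is among the k smallest seen so far with probability k/i, so the expected number of times the O(k) shift runs is roughly:

```latex
\mathbb{E}[\#\text{insertions}]
  = k + \sum_{i=k+1}^{n} \frac{k}{i}
  = k\,\bigl(1 + H_n - H_k\bigr)
  \approx k\left(1 + \ln\frac{n}{k}\right),
\qquad
\text{total cost} \approx O\!\left(n \log k + k^2 \log\frac{n}{k}\right)
```

Under that assumption the average p is about k ln(n/k) / n, which shrinks as the stream grows; adversarial (e.g. decreasing) streams still hit the O(k) shift on every point.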
To build a faster implementation, you can use a tree, for example a self-balancing binary search tree. Here is the algorithm:
```python
# Pseudocode: Tree() stands for any self-balancing binary search tree
k = 1000
topk = Tree()
maxi = np.inf
while check_stop_condition():          # O(n)
    data_point = next_data_point()
    if len(topk) < k:                  # O(1)
        topk.insert(data_point)        # O(log k)
    elif data_point < maxi:            # otherwise discard the value in O(1)
        topk.insert(data_point)        # O(log k)
        topk.deleteMaxNode()           # O(log k)
        maxi = topk.findMaxValue()     # O(log k)
```
The above algorithm runs in O(n log k). One can show that this complexity is optimal when only comparisons between data points are used.
In practice, binary heaps can be a bit faster (with the same complexity). Indeed, they have several advantages over self-balancing binary search trees in this case:
- they can be implemented in a very compact way in memory (reducing cache misses and memory consumption)
- insertion of the first k=1000 items can be done in O(k) time (by heapifying them in one pass) and very quickly
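To illustrate both points, here is a minimal sketch using Python's standard `heapq` module (the values and `k = 8` are arbitrary). A binary heap lives in a plain list where the children of index `i` sit at `2*i + 1` and `2*i + 2`, and `heapq.heapify` builds it in linear time. Because `heapq` is a min-heap, the values are negated so that the root holds the largest kept value, which plays the role of `maxi`.

```python
import heapq
import random

k = 8
random.seed(0)
first_k = [random.uniform(0, 1_000_000) for _ in range(k)]

# heapq only provides a min-heap, so store negated values: -heap[0] is then
# the largest of the k values, i.e. the current threshold maxi.
heap = [-x for x in first_k]
heapq.heapify(heap)  # O(k), in place: no node objects or child pointers

# Heap invariant in the flat list: every parent <= each of its children.
for i in range(k):
    for child in (2 * i + 1, 2 * i + 2):
        if child < k:
            assert heap[i] <= heap[child]

maxi = -heap[0]
print(maxi == max(first_k))  # True
```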
Note that discarded values are handled in constant time. This happens a lot on huge random datasets, as most incoming values quickly end up bigger than maxi. One can even prove that random datasets can be processed in O(n) expected time (optimal).
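The random-data claim can be checked empirically. This rough sketch assumes an i.i.d. uniform stream like the mock in the question and a fixed length `n` in place of `check_stop_condition()`; the helper name `count_insertions` is hypothetical. It counts how many points pass the `data_point < maxi` test; for i.i.d. continuous data the expected count is about k(1 + ln(n/k)), so nearly all of the n points are rejected in O(1).

```python
import heapq
import math
import random

def count_insertions(n, k, seed=0):
    """Return how many of n random points enter the k-smallest set."""
    rng = random.Random(seed)
    heap = []        # max-heap of kept values, stored negated
    insertions = 0
    for _ in range(n):
        x = rng.uniform(0, 1_000_000)
        if len(heap) < k:
            heapq.heappush(heap, -x)
            insertions += 1
        elif x < -heap[0]:               # x < maxi: it must be kept
            heapq.heapreplace(heap, -x)  # pop the current max, push x
            insertions += 1
        # otherwise x is discarded after a single comparison
    return insertions

n, k = 200_000, 1000
observed = count_insertions(n, k)
expected = k * (1 + math.log(n / k))
print(observed, round(expected))
```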
Note that Python 3 provides a standard heap implementation, `heapq`, which is probably a good starting point.
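A minimal sketch of the heap-based version with `heapq` follows, assuming a max-heap emulated by negating values, since `heapq` only offers a min-heap. The function name `topk_smallest` and its `stream` argument are hypothetical: any iterable stands in for the `next_data_point()` / `check_stop_condition()` loop. `heapq.heapreplace` pops the current maximum and pushes the new value in one O(log k) operation, covering the insert and deleteMaxNode steps of the tree algorithm.

```python
import heapq
import numpy as np

def topk_smallest(stream, k=1000):
    """Return the k smallest values seen in `stream`, sorted ascending."""
    heap = []  # negated values: -heap[0] is the largest kept value (maxi)
    for data_point in stream:
        if len(heap) < k:
            heapq.heappush(heap, -data_point)     # fill phase, O(log k)
        elif data_point < -heap[0]:               # O(1) rejection test
            heapq.heapreplace(heap, -data_point)  # drop max, add new, O(log k)
    return np.sort(-np.array(heap, dtype=float))

rng = np.random.default_rng(42)
values = rng.uniform(0, 1_000_000, size=100_000)
result = topk_smallest(values, k=1000)
print(np.array_equal(result, np.sort(values)[:1000]))  # True
```

This sketch fills the heap with `heappush`; collecting the first k values and calling `heapq.heapify` once would make the fill phase O(k), as noted above.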
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |