Numpy argmax - random tie breaking

In numpy.argmax, ties between multiple maximum elements are broken by returning the index of the first one. Is there a way to randomize the tie breaking so that all maximum elements have an equal chance of being selected?

Below is an example directly from numpy.argmax documentation.

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1

I am looking for a way to have indices 1 and 5 (the two positions holding the maximum) returned with equal probability.

Thank you!



Solution 1:[1]

Use np.random.choice -

np.random.choice(np.flatnonzero(b == b.max()))

Let's verify for an array with three max candidates -

In [298]: b
Out[298]: array([0, 5, 2, 5, 4, 5])

In [299]: c=[np.random.choice(np.flatnonzero(b == b.max())) for i in range(100000)]

In [300]: np.bincount(c)
Out[300]: array([    0, 33180,     0, 33611,     0, 33209])
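As a minimal sketch, the same idea can be wrapped in a small helper. The name random_argmax and its rng parameter are made up for illustration, and the use of NumPy's Generator API (np.random.default_rng) instead of the legacy np.random.choice is an assumption, not part of the original answer.

```python
import numpy as np

def random_argmax(values, rng=None):
    """Return the index of a maximum element, ties broken uniformly at random.

    `random_argmax` and `rng` are illustrative names, not a NumPy API.
    """
    rng = np.random.default_rng() if rng is None else rng
    values = np.asarray(values)
    # Indices (into the flattened array) of every element equal to the maximum
    candidates = np.flatnonzero(values == values.max())
    # Uniform pick among the tied indices
    return int(rng.choice(candidates))

b = np.array([0, 5, 2, 3, 4, 5])
idx = random_argmax(b)  # either 1 or 5
```

Passing a seeded generator, for example np.random.default_rng(0), makes the pick reproducible.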

Solution 2:[2]

In the case of a multi-dimensional array, choice won't work.

An alternative is

def randargmax(b,**kw):
  """ a random tie-breaking argmax"""
  return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw)

If generating random floats turns out to be slower than some other method, np.random.random can be replaced with that method.
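To make the masking trick concrete, here is a small sketch on a 2-D array with no axis argument. The array and variable names are illustrative only, and np.unravel_index is added here to turn the flat result into coordinates; it is not part of the answer above.

```python
import numpy as np

b = np.array([[0, 5, 2],
              [5, 4, 5]])

# True where an element equals the global maximum
mask = (b == b.max())
# Independent random values in [0, 1), one per element
noise = np.random.random(b.shape)
# Non-maximum positions become 0; tied maxima keep their random value
scores = noise * mask

# argmax over the flattened scores picks one of the tied positions
flat_index = np.argmax(scores)
# Convert the flat index back into (row, column) coordinates
row, col = np.unravel_index(flat_index, b.shape)
```

Because every tied position receives an independent random score, each one is equally likely to end up with the largest score.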

Solution 3:[3]

Easiest way is

np.random.choice(np.where(b == b.max())[0])

Solution 4:[4]

Since the accepted answer may not be obvious, here is how it works:

  • b == b.max() returns an array of booleans, with True where an item equals the maximum and False everywhere else
  • flatnonzero() does two things: it skips the False values (the nonzero part) and returns the indices of the True values. In other words, you get an array with the indices of the items matching the max value
  • Finally, np.random.choice picks one index at random from that array
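The three steps above can be sketched with the intermediate values spelled out, using the example array from Solution 1 (the variable names are illustrative only):

```python
import numpy as np

b = np.array([0, 5, 2, 5, 4, 5])

# Step 1: boolean mask marking the maximum elements
is_max = (b == b.max())               # [False, True, False, True, False, True]

# Step 2: indices of the True entries
tie_indices = np.flatnonzero(is_max)  # [1, 3, 5]

# Step 3: uniform random pick among those indices
pick = np.random.choice(tie_indices)
```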

Solution 5:[5]

In addition to @Manux's answer:

Changing b.max() to np.amax(b,**kw, keepdims=True) will let you do it along axes.

def randargmax(b,**kw):
    """ a random tie-breaking argmax"""
    # keepdims=True keeps the per-axis maxima broadcastable against b
    return np.argmax(np.random.random(b.shape) * (b==np.amax(b,**kw, keepdims=True)), **kw)

randargmax(b,axis=None) 
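A short usage sketch of the per-axis variant follows. The function is restated with an explicit axis parameter under the hypothetical name randargmax_axis so that the block is self-contained; the sample array is made up.

```python
import numpy as np

def randargmax_axis(b, axis=None):
    """Random tie-breaking argmax along `axis` (illustrative name)."""
    b = np.asarray(b)
    # Per-axis maxima, kept broadcastable against b
    mask = (b == np.amax(b, axis=axis, keepdims=True))
    # Random scores survive only at maxima; argmax picks one per slice
    return np.argmax(np.random.random(b.shape) * mask, axis=axis)

b = np.array([[1, 7, 7],
              [3, 3, 0]])

# One index per row: 1 or 2 for the first row, 0 or 1 for the second
per_row = randargmax_axis(b, axis=1)
```

With axis=None the mask falls back to the global maximum, so the result is a single index into the flattened array, as with np.argmax.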

Solution 6:[6]

Here is a comparison between the two main solutions by @divakar and @shyam-padia:

method (1) - using np.where

np.random.choice(np.where(b == b.max())[0])

method (2) - using np.flatnonzero

np.random.choice(np.flatnonzero(b == b.max()))

Code

Here is the code I wrote for the comparison:

import time

import numpy as np

def method1(b, bmax):
    # indices of the maxima via np.where, then a uniform random pick
    return np.random.choice(np.where(b == bmax)[0])

def method2(b, bmax):
    # same idea, with np.flatnonzero producing the indices
    return np.random.choice(np.flatnonzero(b == bmax))

def time_it(n):
    b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
    bmax = b.max()

    # time n calls of method1
    start = time.perf_counter()
    for i in range(n):
        method1(b, bmax)
    elapsed1 = time.perf_counter() - start

    # time n calls of method2
    start = time.perf_counter()
    for i in range(n):
        method2(b, bmax)
    elapsed2 = time.perf_counter() - start

    print(f'method1 time: {elapsed1} - method2 time: {elapsed2}')
    return elapsed1, elapsed2
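As an alternative sketch, the same comparison can be written with the standard library's timeit module, which handles the timing loop itself. This is not the code used for the figures in this answer; the iteration count and variable names are arbitrary choices.

```python
import timeit

import numpy as np

b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
bmax = b.max()
n = 10_000  # arbitrary number of repetitions

# Total time in seconds for n calls of each variant
t_where = timeit.timeit(
    lambda: np.random.choice(np.where(b == bmax)[0]), number=n)
t_flat = timeit.timeit(
    lambda: np.random.choice(np.flatnonzero(b == bmax)), number=n)

print(f'np.where: {t_where:.4f} s - np.flatnonzero: {t_flat:.4f} s')
```

Timings depend on the machine, the NumPy version, and the array size, so results may differ from those reported here.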

Results

The following figure shows the computation time of each method for [100, 1000, 10000, 100000, 1000000] iterations, where the x-axis is the number of iterations and the y-axis is the time in seconds. np.where is faster than np.flatnonzero, and the gap becomes more visible as the number of iterations increases. Note that the x-axis has a logarithmic scale.

[Figure: computation time in seconds against number of iterations (logarithmic x-axis) for both methods]

To show how the two methods compare at lower iteration counts, we can re-plot the previous results with a logarithmic y-axis. np.where remains faster than np.flatnonzero across the whole range.

[Figure: the same results with a logarithmic y-axis]

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Divakar
Solution 2 Manux
Solution 3 shyam padia
Solution 4 upe
Solution 5 asrvnon
Solution 6 NKN