Get index of selected element when using unsorted_segment_max

I have the following tensor and a list of segment ids:

import tensorflow as tf

aa = tf.constant([5, 1, 2, 3, 4])
ids = [1, 0, 0, 1, 0]

I would like to find the maximum value for each id, together with the index of that value in the tensor.

For example, the output I expect is:

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[4, 4],
       [5, 0]], dtype=int32)>

I am using the following code, which stacks each value with its position and then takes the segment maximum:

bb = tf.stack([aa, tf.range(5)], axis=-1)
tf.math.unsorted_segment_max(bb, ids, 2)

The output I am getting is below. The second row has index 3 instead of the expected 0:

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[4, 4],
       [5, 3]], dtype=int32)>
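
The mismatch seems to come from unsorted_segment_max reducing each column of bb independently: the second column ends up holding the largest index in each segment (3 for id 1, whose members sit at positions 0 and 3), not the position of the largest value. One possible workaround, offered as a sketch rather than a definitive answer, is to compute the segment maximum first and then recover its position with unsorted_segment_min. The names num_segments, seg_max, is_max, candidates and seg_argmax are introduced here for illustration, and ties are assumed to resolve to the first occurrence.

```python
import tensorflow as tf

aa = tf.constant([5, 1, 2, 3, 4])
ids = tf.constant([1, 0, 0, 1, 0])
num_segments = 2  # assumed to be known up front, as in the question

# Maximum value per segment.
seg_max = tf.math.unsorted_segment_max(aa, ids, num_segments)

# Mark the positions whose value equals the maximum of their own segment.
is_max = tf.equal(aa, tf.gather(seg_max, ids))

# Keep the position where the value is a segment maximum; elsewhere use
# a sentinel (the tensor size) that is larger than any valid index.
positions = tf.range(tf.size(aa))
candidates = tf.where(is_max, positions, tf.size(aa))

# The smallest surviving position per segment is the (first) argmax.
seg_argmax = tf.math.unsorted_segment_min(candidates, ids, num_segments)

result = tf.stack([seg_max, seg_argmax], axis=-1)
print(result)
```

For the tensors in the question, result should contain the rows [4, 4] and [5, 0]. A segment id that has no elements is not handled specially in this sketch.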


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
