Replace torch.gather with another operator?

I have a script in which x1 and x2 each have size 1x68x8x8:

 tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
 x1 = x1.view(tmp_batch*tmp_channel, -1)        
 max_ids = torch.argmax(x1, 1)            
 max_ids = max_ids.view(-1, 1)
            
 x2 = x2.view(tmp_batch*tmp_channel, -1)
 outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1

With the code above, I have trouble with torch.gather when I use an old version of ONNX. I would therefore like an alternative that replaces torch.gather with other operators but gives the same output as the code above. Could you please give me some suggestions?



Solution 1:[1]

One workaround is to use the equivalent NumPy function, np.take_along_axis. If you include an import numpy as np statement somewhere, you could do the following (the tensors are converted to NumPy arrays first).

outputs_x_select = torch.from_numpy(np.take_along_axis(x2.numpy(), max_ids.numpy(), 1))

If that gives you a grad-related error (x2 requires grad), detach it first:

outputs_x_select = torch.from_numpy(np.take_along_axis(x2.detach().numpy(), max_ids.numpy(), 1))
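
As a rough, self-contained check of the NumPy route, the sketch below compares it with torch.gather on random data. The random inputs, the seed, and the names x2_flat, picked and reference are assumptions made for illustration; only the shapes come from the question.

```python
import numpy as np
import torch

torch.manual_seed(0)
# Assumed stand-ins for the question's tensors (same shape: 1x68x8x8).
x1 = torch.randn(1, 68, 8, 8)
x2 = torch.randn(1, 68, 8, 8, requires_grad=True)

b, c, h, w = x1.size()
max_ids = torch.argmax(x1.view(b * c, -1), 1).view(-1, 1)  # (68, 1), int64
x2_flat = x2.view(b * c, -1)                               # (68, 64)

# Reference result from the original code.
reference = torch.gather(x2_flat, 1, max_ids)

# NumPy round trip: detach() is needed because x2 requires grad.
picked = np.take_along_axis(x2_flat.detach().numpy(), max_ids.numpy(), 1)
outputs_x_select = torch.from_numpy(picked)

print(outputs_x_select.shape)
print(torch.equal(outputs_x_select, reference.detach()))
print(outputs_x_select.requires_grad)
```

Note that the selection happens outside PyTorch here, so the result carries no gradient, and a tracing-based exporter would most likely not record the NumPy call as a graph operation. Whether that is acceptable depends on how the model is used.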

An approach without NumPy: in this case, max_ids contains exactly one index per row, so I believe indexing each row directly will work:

max_ids = torch.argmax(x1, 1) # do not reshape
            
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids] # size of 68
# use outputs_x_select.view(-1, 1) if the 68 x 1 shape of torch.gather is needed
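
A minimal sketch to check that the indexing version matches torch.gather, assuming random inputs of the size given in the question. It also includes a mask-based variant that is not part of the answer above: it builds a one-hot mask with a comparison and reduces with a sum, which may be an option if indexing is also a problem for the exporter. The names rows, x2_flat, by_index, by_mask and reference are hypothetical.

```python
import torch

torch.manual_seed(0)
# Assumed stand-ins for the question's tensors (same shape: 1x68x8x8).
x1 = torch.randn(1, 68, 8, 8)
x2 = torch.randn(1, 68, 8, 8)

tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
rows = tmp_batch * tmp_channel

max_ids = torch.argmax(x1.view(rows, -1), 1)  # shape (68,), not reshaped
x2_flat = x2.view(rows, -1)                   # shape (68, 64)

# Reference: the original torch.gather call, shape (68, 1).
reference = torch.gather(x2_flat, 1, max_ids.view(-1, 1))

# Indexing version from the answer; unsqueeze to match the (68, 1) shape.
by_index = x2_flat[torch.arange(rows), max_ids].unsqueeze(1)

# Hypothetical variant (an assumption, not from the answer):
# one-hot mask via comparison, then multiply and sum over the columns.
cols = torch.arange(x2_flat.size(1)).unsqueeze(0)        # shape (1, 64)
mask = (cols == max_ids.unsqueeze(1)).to(x2_flat.dtype)  # shape (68, 64)
by_mask = (x2_flat * mask).sum(dim=1, keepdim=True)      # shape (68, 1)

print(torch.equal(by_index, reference))
print(torch.allclose(by_mask, reference))
```

The mask variant touches every element of x2_flat, so it does more arithmetic than gather, but it uses only elementwise and reduction operators.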

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1