Replace torch.gather with other operators?
I have a script in which x1 and x2 both have size 1x68x8x8:
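import torch
# x1, x2: tensors of shape (1, 68, 8, 8)
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)  # (68, 64): one row per channel
max_ids = torch.argmax(x1, 1)  # column of the maximum in each row, shape (68,)
max_ids = max_ids.view(-1, 1)  # (68, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)  # (68, 64)
outputs_x_select = torch.gather(x2, 1, max_ids)  # shape (68, 1)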
With the code above, I run into trouble with torch.gather when using an old version of ONNX. Hence, I would like an alternative that replaces torch.gather with other operators but gives the same output as the code above. Could you please give me some suggestions?
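To make the goal concrete, here is a small sketch of what the snippet computes, using hypothetical toy tensors of shape (1, 2, 2, 2) instead of (1, 68, 8, 8): for each channel, it takes the value of x2 at the spatial position where x1 is largest.

```python
import torch

# Hypothetical toy inputs: 1 batch, 2 channels, 2x2 spatial maps
x1 = torch.tensor([[[[0.1, 0.9],
                     [0.3, 0.2]],
                    [[0.5, 0.4],
                     [0.8, 0.0]]]])
x2 = torch.tensor([[[[10., 20.],
                     [30., 40.]],
                    [[50., 60.],
                     [70., 80.]]]])

b, c, h, w = x1.size()
flat1 = x1.view(b * c, -1)                     # (2, 4): one row per channel
flat2 = x2.view(b * c, -1)                     # (2, 4)
max_ids = torch.argmax(flat1, 1).view(-1, 1)   # column of the max in each row, (2, 1)
out = torch.gather(flat2, 1, max_ids)          # x2 value at that column, (2, 1)

print(max_ids.flatten().tolist())  # [1, 2]
print(out.flatten().tolist())      # [20.0, 70.0]
```

Any replacement for torch.gather has to reproduce this per-row selection.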
Solution 1:[1]
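One workaround is to use the equivalent NumPy function, np.take_along_axis. With import numpy as np at the top of your script, convert the tensors to NumPy arrays explicitly (.numpy() only works on CPU tensors, so call .cpu() first for GPU tensors) and convert the result back:
outputs_x_select = torch.from_numpy(np.take_along_axis(x2.numpy(), max_ids.numpy(), 1))
If x2 requires gradients, calling .numpy() on it raises an error, so detach it first:
outputs_x_select = torch.from_numpy(np.take_along_axis(x2.detach().numpy(), max_ids.numpy(), 1))
Note that the NumPy step is invisible to autograd and to the ONNX tracer, so in an exported model its result would be recorded as a constant. The approach below stays inside PyTorch.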
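A self-contained sketch of the NumPy route, assuming CPU tensors and random inputs of the question's shape (the seed and random data are illustrative only), checks that it reproduces torch.gather. torch.from_numpy is used so the result keeps x2's dtype.

```python
import numpy as np
import torch

torch.manual_seed(0)
x1 = torch.randn(1, 68, 8, 8)
x2 = torch.randn(1, 68, 8, 8, requires_grad=True)  # requires_grad shows why detach() is needed

b, c, h, w = x1.size()
max_ids = torch.argmax(x1.view(b * c, -1), 1).view(-1, 1)  # (68, 1)
flat2 = x2.view(b * c, -1)                                  # (68, 64)

reference = torch.gather(flat2, 1, max_ids)  # (68, 1)

# NumPy equivalent: detach() drops the autograd graph so .numpy() is allowed
selected = torch.from_numpy(
    np.take_along_axis(flat2.detach().numpy(), max_ids.numpy(), 1)
)

print(selected.shape)                             # torch.Size([68, 1])
print(torch.equal(selected, reference.detach()))  # True
print(selected.requires_grad)                     # False: autograd does not see the NumPy step
```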
An approach without NumPy: torch.argmax returns exactly one column index per row, so instead of torch.gather you can index x2 with a row index paired with that column index (advanced indexing):
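# x1 already reshaped to (68, 64) as in the question
max_ids = torch.argmax(x1, 1)  # shape (68,); do not reshape
x2 = x2.view(tmp_batch*tmp_channel, -1)
rows = torch.arange(tmp_batch*tmp_channel, device=x2.device)  # one row index per channel
outputs_x_select = x2[rows, max_ids].unsqueeze(1)  # picks x2[i, max_ids[i]]; unsqueeze(1) gives (68, 1) like gather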
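A sketch that checks the indexing approach against torch.gather on random inputs of the question's shape (the seed and data are illustrative). Advanced indexing returns shape (68,), so .unsqueeze(1) is applied to match gather's (68, 1).

```python
import torch

torch.manual_seed(0)
x1 = torch.randn(1, 68, 8, 8)
x2 = torch.randn(1, 68, 8, 8)

b, c, h, w = x1.size()
n = b * c
flat1 = x1.view(n, -1)  # (68, 64)
flat2 = x2.view(n, -1)  # (68, 64)

# Reference: the original gather-based selection, shape (68, 1)
max_ids_col = torch.argmax(flat1, 1).view(-1, 1)
reference = torch.gather(flat2, 1, max_ids_col)

# Gather-free version: pair row i with column max_ids[i]
max_ids = torch.argmax(flat1, 1)               # (68,)
rows = torch.arange(n, device=flat2.device)    # (68,)
selected = flat2[rows, max_ids].unsqueeze(1)   # (68,) -> (68, 1)

print(torch.equal(selected, reference))  # True
```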
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |