Why does torch.scatter require a smaller shape for indices than values?
A similar question has already been asked here, but I don't think its solution suits my case.
I'm wondering why it isn't possible to do a torch.scatter operation where my index tensor is bigger than my value tensor. In my case I have duplicate indices, e.g. with the following value tensor a and index tensor idx:
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
a.scatter(-1, idx, 1)
raises:
RuntimeError: Expected index [2, 5] to be smaller than self [2, 4] apart from dimension 1 and to be smaller size than src [2, 4]
Is there another way to achieve this?
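For reference: the PyTorch documentation for Tensor.scatter_ requires index.size(d) <= src.size(d) for every dimension d, and index.size(d) <= self.size(d) for every d != dim. Judging from the error message, the scalar 1 was checked as if it were a src of shape [2, 4], which is why an index with 5 columns was rejected. A minimal sketch, assuming a PyTorch version whose Tensor.scatter accepts a tensor src, passes an explicit src of ones with the same shape as idx so that both conditions hold:

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# src has exactly idx's shape [2, 5], so index.size(d) <= src.size(d) for all d,
# and idx.size(0) == a.size(0) for the non-scatter dimension.
# torch.ones_like(idx) is int64, matching a's dtype as scatter requires.
src = torch.ones_like(idx)

# Out-of-place scatter along the last dimension; a itself is left unchanged.
out = a.scatter(-1, idx, src)
print(out)
```

Duplicate indices are acceptable here only because every duplicate writes the same value; the scatter_ documentation notes that when an index repeats with different src values, which one ends up in the output is not guaranteed.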
Solution 1:[1]
Not a solution, but a workaround:
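import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
rows = torch.arange(0, a.size(0))[:, None]  # row indices as a column, shape [2, 1]
n_col = idx.size(1)
# Pair every column index in idx with its row index and set those entries to 1
a[rows.repeat(1, n_col), idx] = 1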
rows.repeat(1, n_col) gives, for each column index in idx, the corresponding row index.
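A possible simplification of the workaround, assuming PyTorch's NumPy-style advanced indexing, in which the index tensors are broadcast against each other: a [2, 1] column of row indices broadcasts against the [2, 5] idx on its own, so the repeat call can be dropped. This is a sketch, not part of the original answer.

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# Column vector of row indices, shape [2, 1].
rows = torch.arange(a.size(0)).unsqueeze(1)

# rows [2, 1] broadcasts against idx [2, 5], pairing each column index
# with its row; duplicate positions simply receive the same value 1 again.
a[rows, idx] = 1
print(a)
```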
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Christian |