Merge one tensor into another tensor at specific indexes in PyTorch
Is there an efficient way to merge one tensor into another in PyTorch, but only at specific indexes?
Here is my full problem. I have a list of row indexes for a tensor; in the code below, xy is the original tensor. I need to preserve the rows of xy whose indexes are in the list, apply some function to all the other rows (for simplicity, say the function multiplies them by two), and then merge the result back into the original tensor.
xy = torch.rand(100, 4)
indexes = [1, 2, 55, 44, 66, 99, 3, 65, 47, 88, 99, 0]
This is what I have done so far. First I create a mask tensor:
import numpy as np
import torch

indexes = [1, 2, 55, 44, 66, 99, 3, 65, 47, 88, 99, 0]
xy = torch.rand(100, 4)

# True for rows to modify, False for rows to preserve
mask = []
for i in range(0, xy.shape[0]):
    if i in indexes:
        mask.append(False)
    else:
        mask.append(True)
print(mask)

target_mask = torch.from_numpy(np.array(mask, dtype=bool))
print(target_mask.sum())  # output is 89: the rows other than the preserved ones
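The Python loop and the NumPy round-trip can be replaced by pure tensor operations, which also keeps the mask on the same device as xy. This is a minimal sketch, assuming indexes only holds valid row numbers in range(xy.shape[0]); the names n and idx are illustrative.

```python
import torch

xy = torch.rand(100, 4)
indexes = [1, 2, 55, 44, 66, 99, 3, 65, 47, 88, 99, 0]

n = xy.shape[0]
idx = torch.tensor(indexes, device=xy.device)  # rows to preserve

# Start with "modify every row", then switch off the preserved rows.
target_mask = torch.ones(n, dtype=torch.bool, device=xy.device)
target_mask[idx] = False  # the repeated 99 just writes False twice

print(target_mask.sum())  # tensor(89): 100 rows minus 11 distinct preserved rows
```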
Then I apply the function to the masked rows:
zy = xy[target_mask]
print(zy)
zy = zy * 2
print(zy)
The code above works fine; I posted it only to clarify the problem.
Now I want to merge the tensor zy back into xy, so that the rows listed in indexes keep their original values and every other row takes the corresponding row of zy.
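A separate merge step is needed because boolean-mask indexing such as xy[target_mask] is advanced indexing: it returns a new tensor, not a view, so even an in-place change to zy does not reach xy. A small check, using a deterministic 5x2 tensor and a hand-written mask purely for illustration:

```python
import torch

xy = torch.arange(10, dtype=torch.float32).reshape(5, 2)
target_mask = torch.tensor([False, True, True, False, True])

zy = xy[target_mask]  # a copy of rows 1, 2 and 4
zy.mul_(2)            # in-place on the copy only; xy is not affected

print(xy[1])  # tensor([2., 3.])  (original values are still there)
print(zy[0])  # tensor([4., 6.])
```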
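Here is the approach I came up with. As you can see, it is too complex: it needs nested for loops over indexes and the rows of xy, plus extra bookkeeping to read the rows of zy in order, which would waste a lot of resources.
# loop-based version (works, but slow)
j = 0  # next row to read from zy
for row in range(xy.shape[0]):
    preserved = False
    for masked_row in indexes:  # inner loop over the preserved indexes
        if row == masked_row:
            preserved = True
            break
    if not preserved:
        xy[row] = zy[j]  # take the next zy row and write it here
        j += 1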
But I am not sure what an efficient way to merge them is. I don't want to use NumPy or for loops, because they would make the process slow: the original tensor is very large and I am going to run this on a GPU.
Is there an efficient way to do this in PyTorch?
Solution 1:[1]
Once you have your mask, you can assign the updated values in place:
zy = 2 * xy[target_mask]
xy[target_mask] = zy
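The in-place assignment can be checked on a small tensor. The sketch below uses toy 6x3 data and an illustrative index list; it also shows an out-of-place alternative with torch.where, which broadcasts a per-row mask across the columns. torch.where is a swapped-in technique, not part of the original answer; it is useful when xy must stay unchanged.

```python
import torch

xy = torch.arange(18, dtype=torch.float32).reshape(6, 3)
indexes = [0, 3, 3]  # rows to preserve (the duplicate is harmless)

target_mask = torch.ones(xy.shape[0], dtype=torch.bool)
target_mask[torch.tensor(indexes)] = False

# Out-of-place: build a new tensor and leave xy as it is.
# target_mask[:, None] has shape (6, 1) and broadcasts over the 3 columns.
merged = torch.where(target_mask[:, None], xy * 2, xy)

# In-place: write the doubled rows straight back into xy.
xy[target_mask] = 2 * xy[target_mask]

print(torch.equal(merged, xy))  # True
print(xy[0], xy[1])             # row 0 is preserved, row 1 is doubled
```

Both variants touch every element once on the device, with no Python-level loop over rows.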
As for acquiring the mask, I don't necessarily see a problem with your approach, though using the built-in set operations would probably be more efficient. This also gives a list of indices instead of a mask, which, depending on the number of indices being updated, may be more efficient.
i = list(set(range(len(xy)))-set(indexes))
zy = 2 * xy[i]
xy[i] = zy
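The set-based index list is built in Python on the CPU. If the indices should be computed on the GPU instead, one possible sketch derives the same complement with tensor operations (a boolean mask followed by nonzero); the device selection and the name keep are assumptions for illustration.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
xy = torch.rand(100, 4, device=device)
indexes = [1, 2, 55, 44, 66, 99, 3, 65, 47, 88, 99, 0]

keep = torch.ones(xy.shape[0], dtype=torch.bool, device=device)
keep[torch.tensor(indexes, device=device)] = False

# Row numbers that are NOT preserved, as a sorted int64 tensor.
i = keep.nonzero(as_tuple=True)[0]

before = xy.clone()
xy[i] = 2 * xy[i]

print(i.numel())  # 89
```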
Edit:
To address the comment, specifically to find the complement of the indices in i, we can do
i_complement = list(set(range(len(xy))) - set(i))
However, assuming indexes contains only values between 0 and len(xy)-1, we could equivalently use i_complement = list(set(indexes)), which just removes the repeated values in indexes.
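The equivalence in the edit can be checked directly. A sketch, assuming indexes only contains values in 0..len(xy)-1; torch.unique (which returns sorted values by default) stands in for set() when the indices are already a tensor.

```python
import torch

xy = torch.rand(100, 4)
indexes = [1, 2, 55, 44, 66, 99, 3, 65, 47, 88, 99, 0]

i = list(set(range(len(xy))) - set(indexes))
i_complement = list(set(range(len(xy))) - set(i))

# Same rows as indexes, with the repeated 99 dropped.
print(sorted(i_complement) == sorted(set(indexes)))  # True

# Tensor equivalent: unique preserved rows, sorted ascending.
unique_idx = torch.unique(torch.tensor(indexes))
print(unique_idx.tolist() == sorted(i_complement))  # True
print(len(i_complement))  # 11
```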
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |