How to mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor?
Suppose I have a 3D tensor A:
A = torch.arange(24).view(4, 3, 2)
print(A)
and I need to mask it using a 2D tensor:
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
Using PyTorch's masked_select leads to the following error:
torch.masked_select(A, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
How can I mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor? Any hints would be appreciated.
Solution 1:[1]
Essentially, we need to make the shape of the mask match (or be broadcastable to) the shape of the tensor being masked.
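Concretely, PyTorch broadcasting aligns shapes starting from the trailing dimension, so a (4, 3) mask is lined up against the last two dimensions (3, 2) of the (4, 3, 2) tensor. Since 3 and 2 differ and neither is 1, broadcasting fails, which is the size-mismatch error in the question. The sketch below illustrates this with elementwise multiplication (masked_select's mask must likewise be broadcastable to the input) and shows that a trailing singleton dimension added with unsqueeze(-1) resolves it:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True

# (4, 3, 2) vs (4, 3): aligned from the right, the last dims are 2 vs 3,
# so broadcasting fails with a RuntimeError.
try:
    X * mask
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
print(broadcast_ok)  # False

# unsqueeze(-1) turns the mask into shape (4, 3, 1); the size-1 dimension
# is stretched to 2, so the shapes are now compatible.
print(mask.unsqueeze(-1).shape)  # torch.Size([4, 3, 1])
print((X * mask.unsqueeze(-1)).shape)  # torch.Size([4, 3, 2])
```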
There are two ways to do it.
Approach 1: Does not preserve the original tensor dimensions.
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)
The output for approach 1:
tensor([ 0, 1, 8, 9, 18, 19])
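masked_select always returns a flat 1D tensor. If the goal is to keep each selected pair together, boolean indexing with the 2D mask may be a useful middle ground: indexing a (4, 3, 2) tensor with a (4, 3) boolean mask consumes the first two dimensions and keeps the last one, giving one row per selected position. A sketch, assuming the integer mask is first converted to torch.bool:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# A boolean mask of shape (4, 3) indexes the first two dimensions;
# the trailing dimension of size 2 is kept for each True position.
Z = X[mask.bool()]
print(Z.shape)     # torch.Size([3, 2])
print(Z.tolist())  # [[0, 1], [8, 9], [18, 19]]
```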
Approach 2: Preserves the original tensor dimensions (by zeroing out the unselected entries).
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Zero out the entries where the expanded mask is 0
Y = X * mask_
print(Y)
The output for approach 2:
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
Mask: tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 0],
[1, 0, 0]])
tensor([[[1, 1],
[0, 0],
[0, 0]],
[[0, 0],
[1, 1],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[1, 1],
[0, 0],
[0, 0]]])
tensor([[[ 0, 1],
[ 0, 0],
[ 0, 0]],
[[ 0, 0],
[ 8, 9],
[ 0, 0]],
[[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[18, 19],
[ 0, 0],
[ 0, 0]]])
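Multiplying by the mask always writes 0 into the unselected positions, which can be ambiguous when 0 is also a legitimate value (as X[0, 0, 0] is here). One alternative is Tensor.masked_fill, whose boolean mask only needs to be broadcastable to the tensor's shape. The sentinel value -1 below is an arbitrary choice for illustration:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Boolean "keep" mask of shape (4, 3, 1); it broadcasts over the last dim.
keep = mask.bool().unsqueeze(-1)

# Fill every position NOT selected by the mask with a sentinel value.
# -1 is an arbitrary sentinel chosen for this example.
Y = X.masked_fill(~keep, -1)
print(Y.shape)        # torch.Size([4, 3, 2])
print(Y[0].tolist())  # [[0, 1], [-1, -1], [-1, -1]]
```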
Solution 2:[2]
There is a simpler way to preserve the dimensions, relying on broadcasting:
torch.mul(X, mask.unsqueeze(-1))
The result is the same:
tensor([[[ 0, 1],
[ 0, 0],
[ 0, 0]],
[[ 0, 0],
[ 8, 9],
[ 0, 0]],
[[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[18, 19],
[ 0, 0],
[ 0, 0]]])
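This works because mask.unsqueeze(-1) has shape (4, 3, 1), which broadcasts against (4, 3, 2), so the explicit expand in Solution 1 is not strictly needed (expand only returns a view, so it costs little either way). A quick check, using the same setup as the question, that both forms give identical results:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Solution 1, approach 2: explicitly expand the mask to X's shape.
mask_ = mask.unsqueeze(-1).expand(X.size())
y_expand = X * mask_

# Solution 2: let broadcasting stretch the (4, 3, 1) mask over the last dim.
y_broadcast = torch.mul(X, mask.unsqueeze(-1))

print(mask.unsqueeze(-1).shape)            # torch.Size([4, 3, 1])
print(torch.equal(y_expand, y_broadcast))  # True
```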
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | |
Solution 2 | Eric Gong |