How to mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor?

Suppose I have a 3D tensor X:

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

and need to mask it using a 2D tensor:

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

Calling torch.masked_select on it leads to the following error:

torch.masked_select(X, (mask == 1))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
     12 
     13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
     15 #Y = X * mask_
     16 print(Y)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2

How can I mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor? Any hints will be appreciated.



Solution 1:[1]

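Essentially, the mask has to match the dimensions of the tensor being masked. Broadcasting aligns shapes starting from the last dimension, so a (4, 3) mask is compared against the trailing (3, 2) of X, and 3 does not match 2; that is exactly what the RuntimeError reports.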
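To make the shape rule concrete, here is a small sketch (assuming a reasonably recent PyTorch release) that reproduces the error from the question and then adds a trailing singleton dimension to the mask. Because `torch.masked_select` broadcasts its mask, `unsqueeze(-1)` alone is enough in this sketch; the explicit `expand` used in the approaches is not strictly required.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Shapes are aligned from the right: (4, 3) vs (4, 3, 2) compares 3 with 2,
# so the mask cannot be broadcast against X.
try:
    torch.masked_select(X, mask == 1)
    raised = False
except RuntimeError:
    raised = True

# A trailing singleton dimension gives the mask shape (4, 3, 1),
# which broadcasts against (4, 3, 2) along the last axis.
bool_mask = (mask == 1).unsqueeze(-1)

print(raised)                             # True
print(bool_mask.shape)                    # torch.Size([4, 3, 1])
print(torch.masked_select(X, bool_mask))  # tensor([ 0,  1,  8,  9, 18, 19])
```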

There are two ways to do it.

Approach 1: does not preserve the original tensor dimensions.

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a trailing dimension to the mask and expand it to the size of the original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)

The output for approach 1:

tensor([ 0,  1,  8,  9, 18, 19])
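Because `masked_select` always returns a flat 1-D tensor, an alternative that is not part of the original answer is boolean indexing: indexing X with a (4, 3) boolean mask selects whole rows along the last dimension, so that dimension survives. A minimal sketch, assuming the mask is built with `dtype=torch.bool`:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

# A boolean mask over the leading (4, 3) dims picks entire X[i, j, :] rows,
# in row-major order, giving shape (number_of_True_entries, 2).
Y = X[mask]
print(Y)
# tensor([[ 0,  1],
#         [ 8,  9],
#         [18, 19]])
```

This is the same data as the approach 1 output, just grouped by the last dimension instead of flattened.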

Approach 2: preserves the original tensor dimensions (entries that are not selected are set to zero).

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a trailing dimension to the mask and expand it to the size of the original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Zero out the unselected entries by element-wise multiplication with the expanded mask
Y = X * mask_
print(Y)

The output for approach 2 (X, the mask, the expanded mask, and the result, printed in that order):

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]])
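Multiplying by the mask can only fill with zeros, and for floating-point tensors it turns `inf` entries into `nan` (since `inf * 0` is `nan`). As a hedged alternative not taken from the original answer, `Tensor.masked_fill` accepts a boolean mask that is broadcastable to the tensor plus an arbitrary fill value. A sketch under the same setup:

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# masked_fill writes the fill value wherever its mask is True, so invert the
# selection to overwrite the entries that were NOT selected.
drop = (mask == 0).unsqueeze(-1)   # shape (4, 3, 1), broadcasts to (4, 3, 2)
Y = X.masked_fill(drop, 0)         # same values as X * mask_
Y_neg = X.masked_fill(drop, -1)    # any fill value works, not just 0

print(Y.shape)                     # torch.Size([4, 3, 2])
```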

Solution 2:[2]

A simpler way to preserve the dimensions is to let broadcasting do the expansion:

torch.mul(X, mask.unsqueeze(-1))

The result is the same:

tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]])
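The two dimension-preserving variants can be compared directly. This sketch (using the same X and mask as above; the variable names `expanded` and `broadcast` are illustrative) checks that the explicit `expand` and plain broadcasting give identical results, and shows that `expand` returns a view with stride 0 along the new dimension rather than a copy of the mask.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Solution 1, approach 2: expand the mask to X's shape, then multiply.
expanded = X * mask.unsqueeze(-1).expand(X.size())
# Solution 2: multiply by the (4, 3, 1) mask and let broadcasting expand it.
broadcast = torch.mul(X, mask.unsqueeze(-1))

print(torch.equal(expanded, broadcast))              # True
# expand() does not copy data: the new last dimension has stride 0.
print(mask.unsqueeze(-1).expand(X.size()).stride())  # (3, 1, 0)
```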

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Eric Gong