Compute maxima and minima of a 4D tensor in PyTorch
Suppose that we have a 4-dimensional tensor, for instance
import torch
X = torch.randn(2, 3, 4, 4)
tensor([[[[-0.9951, 1.6668, 1.3140, 1.4274],
[ 0.2614, 2.6442, -0.3041, 0.7337],
[-1.2690, 0.0125, -0.3885, 0.0535],
[ 1.5270, -0.1186, -0.4458, 0.1389]],
[[ 0.9125, -1.2998, -0.4277, -0.2688],
[-1.6917, -0.8855, -0.2784, -0.6717],
[ 1.1417, 0.4574, 0.4803, -1.6637],
[ 0.7322, 0.2654, -0.1525, 1.7285]],
[[ 1.8310, -1.5765, 0.1392, 1.3431],
[-0.6641, -1.5090, -0.4893, -1.4110],
[ 0.5875, 0.7528, -0.6482, -0.2547],
[-2.3133, 0.3888, 2.1428, 0.2331]]]])
I want to compute the maximum and the minimum values of X over the dimensions 2 and 3, that is, to compute two tensors of size (2,3,1,1), one for the maximum and one for the minimum values of the 4x4 blocks.
I started by trying to do that with torch.max() and torch.min(), but I had no luck. I would expect the dim argument of the above functions to be able to take tuple values, but it can take only an integer. So I don't know how to proceed.
However, specifically for the maximum values, I decided to use torch.nn.MaxPool2d() with kernel_size=4 and stride=4. This indeed did the job:
max_pool = torch.nn.MaxPool2d(kernel_size=4, stride=4)
X_max = max_pool(X)
tensor([[[[2.6442]],
[[1.7285]],
[[2.1428]]]])
But, afaik, there's no similar layer for "min"-pooling. Could you please help me on how to compute the minima similarly to the maxima?
Thank you.
Solution 1:[1]
Just calculate the max over both dimensions sequentially; it gives the same result:
tup = (2,3)
for dim in tup:
    # keepdim=True keeps each reduced dimension as size 1, so the later index stays valid
    X = torch.max(X,dim=dim,keepdim=True)[0]
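The question also asks for the minima, and the same sequential idea carries over with torch.min(). The snippet below is a minimal sketch of that, assuming a random input of the shape used in the question; the names X_min and X_max are only illustrative, and the loop works on copies of the reference so that X itself is not overwritten.

```python
import torch

X = torch.randn(2, 3, 4, 4)

# Reduce one dimension at a time. keepdim=True keeps each reduced axis
# as a size-1 dimension, so the indices 2 and 3 stay valid in every step.
X_min = X
X_max = X
for d in (2, 3):
    # torch.min/torch.max with a dim argument return (values, indices);
    # [0] selects the values.
    X_min = torch.min(X_min, dim=d, keepdim=True)[0]
    X_max = torch.max(X_max, dim=d, keepdim=True)[0]

print(X_min.shape)  # torch.Size([2, 3, 1, 1])
print(X_max.shape)  # torch.Size([2, 3, 1, 1])
```

Both results have the (2,3,1,1) shape requested in the question.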
Solution 2:[2]
If you use torch>=1.11, use the torch.amax function, which accepts a tuple for dim:
dim = (2,3)
x = torch.rand(2,3,4,4)
x_max = torch.amax(x,dim=dim)
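For the minima there is a matching torch.amin function, and passing keepdim=True should give the (2,3,1,1) shape asked for in the question. The following is a small sketch under those assumptions; the cross-check against MaxPool2d, including the negation trick for a "min-pool", is added here for illustration and is not part of the original answer.

```python
import torch

x = torch.randn(2, 3, 4, 4)

# Reduce over dimensions 2 and 3 in one call; keepdim=True retains them as size 1.
x_max = torch.amax(x, dim=(2, 3), keepdim=True)
x_min = torch.amin(x, dim=(2, 3), keepdim=True)

# Cross-check against the pooling approach from the question.
pool = torch.nn.MaxPool2d(kernel_size=4, stride=4)
pooled_max = pool(x)
# There is no MinPool2d layer, but min(x) == -max(-x).
pooled_min = -pool(-x)

print(x_max.shape, x_min.shape)
print(torch.equal(x_max, pooled_max), torch.equal(x_min, pooled_min))
```

The negation trick also works on its own if you prefer to stay with pooling layers, for example when the 4x4 blocks are windows of a larger image rather than the whole spatial extent.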
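However, if you use an older version of PyTorch, use this custom max function:
def torch_max(x,dim):
    # Split the axes into those to keep (s1) and those to reduce (s2).
    s1 = [i for i in range(len(x.shape)) if i not in dim]
    s2 = [i for i in range(len(x.shape)) if i in dim]
    # Move the axes to reduce to the end, then flatten them into a single axis.
    x2 = x.permute(tuple(s1+s2))
    s = [d for (i,d) in enumerate(x.shape) if i not in dim] + [-1]
    x2 = torch.reshape(x2, tuple(s))
    # A single max over the flattened last axis covers every dimension in dim.
    x_max,_ = x2.max(-1)
    return x_max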
Usage of this function is very similar to the original torch.max function.
dim = (2,3)
x = torch.rand(2,3,4,4)
x_max = torch_max(x,dim=dim)
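When the dimensions to reduce are the trailing ones, as in the question, the permute step is not needed and the same flatten-then-reduce idea gives both the minima and the maxima. The helper below is a hypothetical sketch (reduce_trailing and its parameters are names invented here, not a PyTorch API), assuming the last n_trailing dimensions are the ones to collapse.

```python
import torch

def reduce_trailing(x, n_trailing, keepdim=True):
    """Hypothetical helper: min and max over the last n_trailing dimensions."""
    lead = tuple(x.shape[: x.dim() - n_trailing])
    # Flatten the trailing dimensions into one axis and reduce over it once.
    flat = x.reshape(*lead, -1)
    x_min = flat.min(dim=-1).values
    x_max = flat.max(dim=-1).values
    if keepdim:
        # Restore the reduced dimensions as size-1 axes, e.g. (2, 3, 1, 1).
        new_shape = lead + (1,) * n_trailing
        x_min = x_min.reshape(new_shape)
        x_max = x_max.reshape(new_shape)
    return x_min, x_max

x = torch.rand(2, 3, 4, 4)
x_min, x_max = reduce_trailing(x, n_trailing=2)
print(x_min.shape, x_max.shape)
```

With keepdim=False the results have shape (2,3), matching what torch_max returns above.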
If dim contains many dimensions, this custom torch_max is slightly faster.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Chris Holland |
Solution 2 |