Dividing a 3D tensor into a set of smaller 3D tensors in TensorFlow

I have a tensor of shape (x, y, z) and want to divide it into n smaller tensors of shape (a, a, z). The operation is similar to dividing an image into a set of smaller patches, but here the tensor has a larger depth. For images I have used the following layer, built on TensorFlow's tf.image.extract_patches; now I want to do the same with tensors of larger depth.

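import tensorflow as tf
from tensorflow.keras import layers


class Patches(layers.Layer):
    """Splits a batch of images into flattened, non-overlapping square patches."""

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        # images: (batch, height, width, channels)
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            # stride == patch size, so the patches do not overlap
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Each patch is flattened to patch_size * patch_size * channels values
        patch_dims = patches.shape[-1]
        # Result: (batch, num_patches, patch_dims)
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches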
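
One possible way to get (a, a, z) patches from a single (x, y, z) tensor is a reshape followed by a transpose. This is only a minimal sketch, not an official TensorFlow recipe. The function name split_into_patches is hypothetical, and the sketch assumes the tensor has a fully known static shape and that rows or columns which do not fill a whole patch can be dropped (similar to padding="VALID" above).

```python
import tensorflow as tf


def split_into_patches(tensor, a):
    """Split a (x, y, z) tensor into non-overlapping (a, a, z) patches.

    Hypothetical helper: assumes a static shape and drops any trailing
    rows/columns that do not fill a complete patch.
    """
    x, y, z = tensor.shape
    rows, cols = x // a, y // a

    # Crop so both spatial dimensions are exact multiples of a
    cropped = tensor[: rows * a, : cols * a, :]

    # (rows * a, cols * a, z) -> (rows, a, cols, a, z)
    blocks = tf.reshape(cropped, [rows, a, cols, a, z])

    # Bring the two patch-grid axes to the front: (rows, cols, a, a, z)
    blocks = tf.transpose(blocks, [0, 2, 1, 3, 4])

    # Merge the grid axes into one patch axis: (rows * cols, a, a, z)
    return tf.reshape(blocks, [rows * cols, a, a, z])
```

With a (4, 6, 5) input and a = 2 this returns a tensor of shape (6, 2, 2, 5), with patches ordered row by row across the grid. The depth axis is never touched, so the approach does not depend on the value of z.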


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
