Implementation of a WGAN-GP in TensorFlow

Using TensorFlow, I'm trying to reimplement the following architecture (for now I'm focusing on the generator part):

[Image: architecture diagram (not available)]

So far, I have defined the generator as follows:

import tensorflow as tf

N_Z = 128  # dimension of the latent noise vector

generator = [
    tf.keras.layers.Dense(units=6144, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(6, 4, 256)),  # 6 * 4 * 256 = 6144
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(5,5), strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
    ),
]

Generator = tf.keras.models.Sequential(generator)

But if I take some random noise and let the model process it, this is the final shape I get back:

noise = tf.random.normal((64,128))

result = Generator(noise)

result.shape

TensorShape([64, 28, 28, 1])

What am I doing wrong here? I also checked the original implementation for additional details, but I couldn't find anything that explains this.
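One way to debug this is to track the shapes by hand. The sketch below is pure Python and assumes only the Keras rule that a Conv2DTranspose with padding="same" (and the default output_padding=None) multiplies each spatial size by the stride along that axis; same_transpose is a hypothetical helper, not a Keras function.

```python
# Hand-tracing shapes through the question's generator (pure Python).
# Assumption: a Keras Conv2DTranspose with padding="same" and the default
# output_padding=None returns size * stride along each spatial axis.


def same_transpose(hw, strides):
    """Hypothetical helper: (height, width) after a 'same' transposed conv."""
    return (hw[0] * strides[0], hw[1] * strides[1])


# (strides, filters) for the seven Conv2DTranspose layers in the question
layers = [((2, 2), 128), ((2, 1), 128), ((1, 1), 64), ((2, 1), 64),
          ((1, 1), 32), ((2, 1), 32), ((1, 1), 1)]

hw, channels = (6, 4), 256  # after Dense(6144) and Reshape((6, 4, 256))
for strides, filters in layers:
    hw, channels = same_transpose(hw, strides), filters
    print(hw + (channels,))  # per-sample shape after this layer

final_shape = hw + (channels,)
print("final:", final_shape)  # final: (96, 8, 1)
```

Under that rule the layer list ends at (96, 8, 1) per sample rather than the reported (28, 28, 1), so it may be worth confirming that Generator was actually rebuilt from exactly this list. In any case, repeated doubling from 6 only yields heights of the form 6 * 2^k, so a target such as the 85 x 8 used in the answer below cannot be reached this way.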



Solution 1:[1]

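It becomes clear once you check the input and output shape of every layer. With padding="same", a strided Conv2DTranspose multiplies the height by exactly its stride, so the sizes need some help at the upper levels: in the sample below, a Resizing layer after each transposed convolution brings the height to the target size while the width stays at 8.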
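As an alternative to the Resizing layers used in the sample below (this is not what the answer does), Keras's Conv2DTranspose also accepts an output_padding argument. When it is set, the output length along an axis is (n - 1) * stride + kernel - 2 * pad + output_padding, with pad = kernel // 2 for padding="same". A minimal pure-Python sketch of that arithmetic, assuming the target heights 11, 22, 43 and 85 from the sample; transpose_out_len is a hypothetical helper:

```python
# Output length of a transposed convolution when output_padding is given,
# following the formula Keras uses for Conv2DTranspose (dilation 1 assumed).


def transpose_out_len(n, kernel, stride, output_padding, padding="same"):
    """Hypothetical helper, not a Keras API."""
    if padding == "same":
        pad = kernel // 2
    elif padding == "valid":
        pad = 0
    else:
        raise ValueError("only 'same' and 'valid' are handled here")
    return (n - 1) * stride + kernel - 2 * pad + output_padding


# Height path of the strided layers: (input, kernel, stride, output_padding)
steps = [(6, 5, 2, 0), (11, 3, 2, 1), (22, 3, 2, 0), (43, 3, 2, 0)]
heights = [transpose_out_len(n, k, s, op) for n, k, s, op in steps]
print(heights)  # [11, 22, 43, 85]

# Width of the first layer: 4 -> 8 needs output_padding=1
print(transpose_out_len(4, 5, 2, 1))  # 8

# Stride-1 axes keep their size with output_padding=0 (kernel 3)
print(transpose_out_len(8, 3, 1, 0))  # 8
```

Under these assumptions, passing output_padding=(0, 1) to the first layer (kernel (5, 5), strides (2, 2)) and output_padding=(1, 0), (0, 0), (0, 0) to the later strides=(2, 1) layers would reach 85 x 8 without interpolation. Keras requires each output_padding entry to be smaller than the stride on that axis; whether this matches the original architecture is an assumption.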

[ Sample ]:

import tensorflow as tf

# Model initialization
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(6144,)),
    tf.keras.layers.Dense( 48 * 128, activation="linear" ),
    tf.keras.layers.BatchNormalization( momentum=0.99, epsilon=0.00001 ),
    tf.keras.layers.Reshape(target_shape=( 6, 4, 256 )),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(5,5), strides=(2, 2), padding="same", activation="relu"
    ),
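    # Bilinear resize of the (12, 8) feature map to (11, 8); no cropping
    tf.keras.layers.Resizing( 11, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    # Same shape as the Resizing output, so this Reshape leaves the tensor unchanged
    tf.keras.layers.Reshape(target_shape=(11, 8, 128)),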
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    tf.keras.layers.Reshape(target_shape=(22, 8, 128)),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    tf.keras.layers.Reshape(target_shape=(22, 8, 64)),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    tf.keras.layers.Reshape(target_shape=(43, 8, 64)),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    tf.keras.layers.Reshape(target_shape=(43, 8, 32)),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Resizing( 85, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
    tf.keras.layers.Reshape(target_shape=(85, 8, 32)),

])

model.summary()

[ Output ]:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 6144)              37754880

 batch_normalization (BatchN  (None, 6144)             24576
 ormalization)

 reshape (Reshape)           (None, 6, 4, 256)         0

 conv2d_transpose (Conv2DTra  (None, 12, 8, 128)       819328
 nspose)

 resizing (Resizing)         (None, 11, 8, 128)        0

 reshape_1 (Reshape)         (None, 11, 8, 128)        0

 conv2d_transpose_1 (Conv2DT  (None, 22, 8, 128)       147584
 ranspose)

 resizing_1 (Resizing)       (None, 22, 8, 128)        0

 reshape_2 (Reshape)         (None, 22, 8, 128)        0

 conv2d_transpose_2 (Conv2DT  (None, 22, 8, 64)        73792
 ranspose)

 resizing_2 (Resizing)       (None, 22, 8, 64)         0

 reshape_3 (Reshape)         (None, 22, 8, 64)         0

 conv2d_transpose_3 (Conv2DT  (None, 44, 8, 64)        36928
 ranspose)

 resizing_3 (Resizing)       (None, 43, 8, 64)         0

 reshape_4 (Reshape)         (None, 43, 8, 64)         0

 conv2d_transpose_4 (Conv2DT  (None, 43, 8, 32)        18464
 ranspose)

 resizing_4 (Resizing)       (None, 43, 8, 32)         0

 reshape_5 (Reshape)         (None, 43, 8, 32)         0

 conv2d_transpose_5 (Conv2DT  (None, 86, 8, 32)        9248
 ranspose)

 resizing_5 (Resizing)       (None, 85, 8, 32)         0

 reshape_6 (Reshape)         (None, 85, 8, 32)         0

=================================================================
Total params: 38,884,800
Trainable params: 38,872,512
Non-trainable params: 12,288
_________________________________________________________________
(1, 85, 8, 32)
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
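Almost all of the 38.9 million parameters come from the first Dense layer, which maps a 6144-dimensional input to 6144 units. The pure-Python arithmetic below (helper names are hypothetical) reproduces the Param # column: a Dense layer has n_in * n_out weights plus n_out biases, a Conv2DTranspose has kh * kw * c_in * c_out kernel weights plus c_out biases, and BatchNormalization keeps four vectors of length 6144, of which the moving mean and variance (12,288 values) are non-trainable.

```python
# Reproduce the "Param #" column of the model summary by hand.


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


def conv_t_params(kh, kw, c_in, c_out):
    return kh * kw * c_in * c_out + c_out  # kernel + bias


bn_total = 4 * 6144  # gamma, beta (trainable) + moving mean/var (non-trainable)
conv = [conv_t_params(5, 5, 256, 128), conv_t_params(3, 3, 128, 128),
        conv_t_params(3, 3, 128, 64), conv_t_params(3, 3, 64, 64),
        conv_t_params(3, 3, 64, 32), conv_t_params(3, 3, 32, 32)]

total = dense_params(6144, 6144) + bn_total + sum(conv)
print(dense_params(6144, 6144))  # 37754880
print(conv)                      # [819328, 147584, 73792, 36928, 18464, 9248]
print(total)                     # 38884800
print(dense_params(128, 6144))   # 792576
```

With the 128-dimensional noise used in the question (N_Z = 128) as the input instead, the first Dense layer would have 792,576 parameters.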


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Martijn Pieters