Implementation of a WGAN-GP in TensorFlow
Using TensorFlow, I'm trying to reimplement the following architecture (for now I'm focusing on the generator part):
So far, I have defined the generator as follows:
import tensorflow as tf

N_Z = 128  # dimensionality of the latent noise vector

generator = [
    # Project the noise to 6 * 4 * 256 = 6144 units, then reshape to a feature map.
    tf.keras.layers.Dense(units=6144, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(6, 4, 256)),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(5, 5), strides=(2, 2), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=128, kernel_size=(3, 3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=(3, 3), strides=(2, 1), padding="SAME", activation="relu"
    ),
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"
    ),  # comma was missing here
    tf.keras.layers.Conv2DTranspose(
        filters=32, kernel_size=(3, 3), strides=(2, 1), padding="SAME", activation="relu"
    ),  # comma was missing here
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"
    ),
]
Generator = tf.keras.models.Sequential(generator)
But if I take some random noise and let the model process it, this is the final shape I get back:
noise = tf.random.normal((64,128))
result = Generator(noise)
result.shape
TensorShape([64, 28, 28, 1])
What am I doing wrong here? I also checked the original implementation for additional details, but I couldn't find anything that explains it.
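One way to sanity-check the shapes by hand: with padding="SAME" and no explicit output_padding, a Keras Conv2DTranspose layer multiplies each spatial dimension by its stride, regardless of kernel size. The following is a minimal sketch in plain Python, not TensorFlow code; the helper name `trace_same_transpose` is hypothetical and only illustrates the arithmetic.

```python
# Hypothetical helper (not a TensorFlow API): traces (height, width) through
# a stack of Conv2DTranspose layers that all use padding="SAME".
# Under "SAME" padding, output size = input size * stride.
def trace_same_transpose(start_hw, strides):
    h, w = start_hw
    shapes = [(h, w)]
    for stride_h, stride_w in strides:
        h, w = h * stride_h, w * stride_w
        shapes.append((h, w))
    return shapes

# Strides of the seven Conv2DTranspose layers in the generator listing,
# starting from the (6, 4) feature map produced by the Reshape layer.
generator_strides = [(2, 2), (2, 1), (1, 1), (2, 1), (1, 1), (2, 1), (1, 1)]
shapes = trace_same_transpose((6, 4), generator_strides)

for hw in shapes:
    print(hw)
```

Under that rule the listing as written ends at a 96 x 8 feature map, so if the shape printed by the real model differs, it may be worth confirming that the model being run is the one listed.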
Solution 1:[1]
This is easy to debug once you look at the input and output shape of each layer; the model needed some help in its first layers.
[ Sample ]:
import tensorflow as tf

# Model initialization
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(6144,)),
tf.keras.layers.Dense( 48 * 128, activation="linear" ),
tf.keras.layers.BatchNormalization( momentum=0.99, epsilon=0.00001 ),
tf.keras.layers.Reshape(target_shape=( 6, 4, 256 )),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(5,5), strides=(2, 2), padding="same", activation="relu"
),
tf.keras.layers.Resizing( 11, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(11, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(22, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(22, 8, 64)),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(43, 8, 64)),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(43, 8, 32)),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 85, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(85, 8, 32)),
])
model.summary()
[ Output ]:
Model: "sequential"
_______________________________________________________________________
Layer (type)                              Output Shape         Param #
=======================================================================
dense (Dense)                             (None, 6144)         37754880
batch_normalization (BatchNormalization)  (None, 6144)         24576
reshape (Reshape)                         (None, 6, 4, 256)    0
conv2d_transpose (Conv2DTranspose)        (None, 12, 8, 128)   819328
resizing (Resizing)                       (None, 11, 8, 128)   0
reshape_1 (Reshape)                       (None, 11, 8, 128)   0
conv2d_transpose_1 (Conv2DTranspose)      (None, 22, 8, 128)   147584
resizing_1 (Resizing)                     (None, 22, 8, 128)   0
reshape_2 (Reshape)                       (None, 22, 8, 128)   0
conv2d_transpose_2 (Conv2DTranspose)      (None, 22, 8, 64)    73792
resizing_2 (Resizing)                     (None, 22, 8, 64)    0
reshape_3 (Reshape)                       (None, 22, 8, 64)    0
conv2d_transpose_3 (Conv2DTranspose)      (None, 44, 8, 64)    36928
resizing_3 (Resizing)                     (None, 43, 8, 64)    0
reshape_4 (Reshape)                       (None, 43, 8, 64)    0
conv2d_transpose_4 (Conv2DTranspose)      (None, 43, 8, 32)    18464
resizing_4 (Resizing)                     (None, 43, 8, 32)    0
reshape_5 (Reshape)                       (None, 43, 8, 32)    0
conv2d_transpose_5 (Conv2DTranspose)      (None, 86, 8, 32)    9248
resizing_5 (Resizing)                     (None, 85, 8, 32)    0
reshape_6 (Reshape)                       (None, 85, 8, 32)    0
=======================================================================
Total params: 38,884,800
Trainable params: 38,872,512
Non-trainable params: 12,288
_______________________________________________________________________
2022-04-03 03:37:10.354570: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
(1, 85, 8, 32)
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
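The summary above alternates a stride-based upsampling step with a Resizing step that forces the height to a fixed target (12 to 11, 44 to 43, 86 to 85). A plain-Python sketch of that height bookkeeping follows; the function name `trace_heights` and the `stages` encoding are assumptions made for illustration, not TensorFlow API.

```python
# Hypothetical bookkeeping (not a TensorFlow API): each entry is
# (height stride of the Conv2DTranspose layer,
#  target height of the Resizing layer that follows it).
stages = [(2, 11), (2, 22), (1, 22), (2, 43), (1, 43), (2, 85)]

def trace_heights(start_height, stages):
    rows = []
    h = start_height
    for stride, target in stages:
        after_conv = h * stride  # padding="SAME": height is multiplied by the stride
        rows.append((after_conv, target))
        h = target               # Resizing interpolates to the target height
    return rows

rows = trace_heights(6, stages)
for after_conv, after_resize in rows:
    print(after_conv, "->", after_resize)
```

Each pair reproduces the conv2d_transpose and resizing heights listed in the summary, which shows that the odd target heights come from the Resizing layers rather than from the transposed convolutions themselves.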
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Martijn Pieters |