How to modify a network that recognizes single-channel images so that it recognizes three-channel images?
When I run the code below, I can train the model on the MNIST dataset. But when I switched to an RGB dataset, the batch size of the model's output tripled. Is this because the number of channels changed from 1 to 3? How can I fix it?
import tensorflow as tf
class local_embedding(tf.keras.Model):
    """Local CNN embedding for 14x14 image patches with any number of channels.

    Architecture::

        Conv2D(64, 5x5) -> ReLU -> MaxPool(2)
        -> Conv2D(128, 5x5) -> ReLU -> MaxPool(2)
        -> flatten (4 * 4 * 128 = 2048)
        -> Dense(256) -> ReLU -> Dense(64) -> ReLU -> Dense(10) -> ReLU

    Args:
        seed: passed to ``tf.random.set_seed`` before the layers are created.
            Note that this sets TensorFlow's *global* random seed.
        channels: number of channels per 14x14 patch (1 for grayscale data
            such as MNIST, 3 for RGB). If ``None`` (the default), the count is
            read from the last axis of a rank-4 input, and 1 is assumed for
            any other input (e.g. flattened grayscale patches). This keeps
            the original single-channel behaviour.

    Raises:
        ValueError: if ``channels`` is given and is smaller than 1.
    """

    def __init__(self, seed=1, channels=None):
        super(local_embedding, self).__init__()
        if channels is not None and channels < 1:
            raise ValueError(f"channels must be a positive integer, got {channels!r}")
        tf.random.set_seed(seed)
        self.channels = None if channels is None else int(channels)

        # Stage 1: (14, 14, C) -> 'same' conv -> (14, 14, 64) -> pool -> (7, 7, 64).
        # No input_shape here: the real input size is fixed by the reshape in
        # forward(), and a hard-coded (32, 32, 3) was misleading.
        self.c1 = tf.keras.layers.Conv2D(
            filters=64, kernel_size=(5, 5), padding='same',
            data_format='channels_last', name='c1')
        self.s1 = tf.keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=2, padding='same',
            data_format='channels_last', name='m1')

        # Stage 2: (7, 7, 64) -> conv -> (7, 7, 128) -> 'same' pool -> (4, 4, 128).
        self.c2 = tf.keras.layers.Conv2D(
            filters=128, kernel_size=(5, 5), padding='same',
            data_format='channels_last', name='c2')
        self.s2 = tf.keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=2, padding='same',
            data_format='channels_last', name='m2')

        # Activation functions (sigmoid is not used by forward(); kept for callers).
        self.relu = tf.keras.activations.relu
        self.sigmoid = tf.keras.activations.sigmoid

        # Dense head: 2048 -> 256 -> 64 -> 10.
        self.fc1 = tf.keras.layers.Dense(256, activation=None, name='fc1')
        self.fc2 = tf.keras.layers.Dense(64, activation=None, name='fc2')
        self.fc3 = tf.keras.layers.Dense(10, activation=None, name='fc3')

    def _num_channels(self, x):
        """Return the channel count used to reshape ``x`` into 14x14 patches."""
        if self.channels is not None:
            return self.channels
        if x.shape.rank == 4 and x.shape[-1] is not None:
            return int(x.shape[-1])
        return 1

    def forward(self, input):
        """Run the embedding on a batch of 14x14xC patches.

        Args:
            input: anything ``tf.convert_to_tensor`` accepts whose total size
                is a multiple of 14 * 14 * C, where C is the channel count
                (see the class docstring).

        Returns:
            A tuple ``(middle_input, x, middle_output)``:
              * ``middle_input``: flattened conv features, shape (batch, 2048).
              * ``x``: ReLU(fc3(...)), shape (batch, 10).
              * ``middle_output``: fc1 output before its ReLU, shape (batch, 256).
        """
        x = tf.convert_to_tensor(input)
        # Reshape with the real channel count. Hard-coding 1 here folds the
        # channel axis into the batch axis, so a batch of B RGB images turns
        # into 3*B single-channel images (and 3*B predictions vs B labels).
        x = tf.reshape(x, [-1, 14, 14, self._num_channels(x)])

        x = self.s1(self.relu(self.c1(x)))  # (batch, 7, 7, 64)
        x = self.s2(self.relu(self.c2(x)))  # (batch, 4, 4, 128)

        middle_input = tf.reshape(x, [-1, 4 * 4 * 128])  # (batch, 2048)
        middle_output = self.fc1(middle_input)           # (batch, 256), pre-activation
        x = self.relu(middle_output)
        x = self.relu(self.fc2(x))                       # (batch, 64)
        # NOTE(review): ReLU on the final layer makes these outputs non-negative;
        # confirm that is intended if they are used as logits downstream.
        x = self.relu(self.fc3(x))                       # (batch, 10)
        return middle_input, x, middle_output
class server(tf.keras.Model):
    """Top model: a single dense softmax layer producing 10 class probabilities.

    Args:
        seed: forwarded to ``tf.random.set_seed`` (TensorFlow's global seed)
            before the layer is created.
    """

    def __init__(self, seed=0):
        super().__init__()
        tf.random.set_seed(seed)
        self.last = tf.keras.layers.Dense(10, activation='softmax', name='last')

    def forward(self, input):
        """Apply the softmax dense layer to ``input``.

        Args:
            input: feature tensor. Presumably the features produced by the
                local embeddings; an older comment claimed its size is 1024.
                TODO confirm the expected width.

        Returns:
            Tensor whose last dimension is 10, holding softmax probabilities.
            NOTE(review): these are probabilities, not logits. If they are
            passed to ``tf.nn.softmax_cross_entropy_with_logits``, softmax is
            applied twice. Verify against the loss code.
        """
        return self.last(input)
The following are the details of the error:
Traceback (most recent call last):
File "cafe.py", line 135, in <module>
vfl_cafe()
File "cafe.py", line 76, in vfl_cafe
= take_gradient(number_of_workers, random_lists, real_data, real_labels, local_net, Server)
File "/content/utils.py", line 145, in take_gradient
loss = compute_loss(label, predict)
File "/content/utils.py", line 15, in compute_loss
return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, logits), axis=-1)
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py", line 7186, in raise_from_not_ok_status
raise core._status_to_exception(e) from None # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[120,10] labels_size=[40,10] [Op:SoftmaxCrossEntropyWithLogits]
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|