'How to modify a network that recognizes single-channel images to recognize three-channel?

When I run the code below, I can train the MNIST dataset with the model in the code. But when I modified the dataset to RGB dataset, the batch_size in the parameters output by the model became tripled. Is this due to the number of channels being changed from 1 to 3? How can I modify it?

import tensorflow as tf

class local_embedding(tf.keras.Model):
def __init__(self, seed=1):
    super(local_embedding, self).__init__()
    tf.random.set_seed(seed)
    # convolutional layer 1
    self.c1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(5, 5), padding='same', data_format='channels_last', input_shape=(32, 32, 3), name='c1')
    # max pooling 1
    self.s1 = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same', data_format='channels_last', name='m1')
    # convolutional layer 2
    self.c2 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5, 5), padding='same', data_format='channels_last', name='c2')
    # max pooling 2
    self.s2 = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same', data_format='channels_last', name='m2')
    # convolutional layer 3
    # self.c3 = tf.keras.layers.Conv2D(filters=64, kernel_size=(5, 5), padding='same', name='c3')
    # max pooling 3
    # self.s3 = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same', name='m3')
    # activation function
    self.relu = tf.keras.activations.relu
    self.sigmoid = tf.keras.activations.sigmoid
    self.fc1 = tf.keras.layers.Dense(256, activation = None, name = 'fc1')
    self.fc2 = tf.keras.layers.Dense(64, activation = None, name = 'fc2')
    self.fc3 = tf.keras.layers.Dense(10, activation = None, name = 'fc3')

def forward(self, input):
    # input size should be 14x14x3
    print(input.shape)
    x = tf.reshape(input, [-1, 14, 14, 1])
    x = self.c1(x)
    x = self.relu(x)
    x = self.s1(x)
    # x should be 7x7x64

    x = self.c2(x)
    x = self.relu(x)
    x = self.s2(x)
    # x should be 4x4x128

    print(x.shape)

    '''
    x = self.c3(x)
    x = self.relu(x)
    x = self.s3(x)
    # x should be 6x6x64
    '''

    middle_input = tf.reshape(x, [-1, 4 * 4 * 128]) # 2304

    middle_output = self.fc1(middle_input)
    x = self.relu(middle_output)
    # x = self.sigmoid(middle_output)
    # x should be 256
    x = self.fc2(x)
    x = self.relu(x)
    # x = self.sigmoid(x)
    # x should be 64
    x = self.fc3(x)
    x = self.relu(x)
    # x = self.sigmoid(x)
    return middle_input, x, middle_output

class server(tf.keras.Model):
def __init__(self, seed=0):
    super(server, self).__init__()
    tf.random.set_seed(seed)
    self.last = tf.keras.layers.Dense(10, activation = 'softmax', name = 'last')

def forward(self, input):
    # the size of input should be 1024
    output = self.last(input)
    # the size of output is 10
    return output

The following is the details of the code error

Traceback (most recent call last):
File "cafe.py", line 135, in <module>
  vfl_cafe()
File "cafe.py", line 76, in vfl_cafe
  = take_gradient(number_of_workers, random_lists, real_data, real_labels, local_net, Server)
File "/content/utils.py", line 145, in take_gradient
  loss = compute_loss(label, predict)
File "/content/utils.py", line 15, in compute_loss
  return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, logits), axis=-1)
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
  raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py", line 7186, in raise_from_not_ok_status
  raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[120,10] labels_size=[40,10] [Op:SoftmaxCrossEntropyWithLogits]


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source