Training a simple CNN-LSTM model
I have a task for my project paper and I do not understand how to train the model. The model is supposed to take an image and segment it into different classes. The hard part is that the structures to be segmented look the same, but I would like to differentiate between them. When I build a model with convolutional layers and an LSTM, it only predicts the background class.
Here is my model:
import tensorflow as tf
from tensorflow.keras.layers import Input, TimeDistributed, Conv2D, ConvLSTM2D
from tensorflow.keras.models import Model

def LSTMconv10x9(input_size=(200, 9, 10, 1)):
    # Input: 200 frames of 9x10 single-channel images
    input = Input(input_size)
    # Per-frame convolutions
    conv1 = TimeDistributed(Conv2D(32, 3, padding="same", activation="relu"))(input)
    conv2 = TimeDistributed(Conv2D(64, 3, padding="same", activation="relu"))(conv1)
    lstm = ConvLSTM2D(32, 3, return_sequences=True, padding="same", activation="softmax")(conv2)
    conv4 = TimeDistributed(Conv2D(64, 3, padding="same", activation="relu"))(lstm)
    conv5 = TimeDistributed(Conv2D(32, 3, padding="same", activation="relu"))(conv4)
    # 11 output channels (logits), one per class
    output = ConvLSTM2D(11, 1, return_sequences=True, padding="same", activation=None)(conv5)
    model = Model(inputs=input, outputs=output)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=["accuracy"], sample_weight_mode="temporal")
    return model
And the way I train the model:
# Weight background pixels (label 0) at 0.1 and everything else at 0.9
weights = np.where(train_y == 0, 0.1, 0.9)
model1 = LSTMconv10x9()
model1.fit(train_x, train_y, epochs=20, batch_size=32,
           validation_data=(test_x, test_y), sample_weight=weights)
The training set has shape (2000, 200, 9, 10, 1) and the validation set has shape (1000, 200, 9, 10, 1), i.e. the training set contains 2000 videos of 200 frames each. The videos show 10 structures that look the same, but I would like to number them so that they are treated as different structures. This is a segmentation problem.
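To make this layout concrete, the following sketch builds small placeholder arrays with the same axis order and counts how often each label occurs. It is an illustration only: the label shape (videos, frames, height, width), the integer encoding (0 for background, 1 to 10 for the structures) and the helper name `class_fractions` are assumptions, since train_y itself is not shown, and the sizes are reduced so that it runs quickly.

```python
import numpy as np

NUM_CLASSES = 11  # assumed: background (0) + 10 structures, matching the 11 output channels


def class_fractions(labels, num_classes=NUM_CLASSES):
    """Return the fraction of pixels that belongs to each class."""
    counts = np.bincount(labels.ravel(), minlength=num_classes)
    return counts / counts.sum()


# Placeholder data: 4 videos of 5 frames instead of 2000 videos of 200 frames.
rng = np.random.default_rng(0)
x = rng.random((4, 5, 9, 10, 1)).astype("float32")  # (videos, frames, height, width, channels)
y = np.zeros((4, 5, 9, 10), dtype=np.int64)         # hypothetical integer label maps

# Put one pixel of each structure (labels 1..10) into row 4 of every frame.
y[:, :, 4, :] = np.arange(1, 11)

fractions = class_fractions(y)
print("background fraction:", fractions[0])
print("per-structure fractions:", fractions[1:])
```

Running the same helper on the real train_y would show how dominant the background actually is and whether all 10 structure labels occur at all.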
The data is very unbalanced: each video contains objects that I want to separate, but the background makes up about 90% of the videos. I have tried weighting the samples with sample_weight_mode='temporal' in TensorFlow, but it did not seem to work. The most important thing for the model is to find the structures.
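As an illustration of what per-pixel weighting is meant to achieve here, the sketch below computes a weighted sparse categorical cross-entropy from logits in plain NumPy. It is not the Keras implementation: the function name `weighted_sparse_ce`, the label layout and the normalisation by the sum of the weights are assumptions made for the example. The weights 0.1 and 0.9 are the ones used in the training call.

```python
import numpy as np


def weighted_sparse_ce(logits, labels, weights):
    """Weighted mean of the per-pixel cross-entropy.

    logits:  float array, shape (..., num_classes)
    labels:  int array, shape (...), class index per pixel
    weights: float array, shape (...), weight per pixel
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class for every pixel.
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return float((-picked * weights).sum() / weights.sum())


# Tiny example: one 9x10 frame, 11 classes, one row of foreground pixels.
labels = np.zeros((9, 10), dtype=np.int64)
labels[4, :] = np.arange(1, 11)

# Logits of a hypothetical model that always predicts background confidently.
logits = np.zeros((9, 10, 11))
logits[..., 0] = 5.0

uniform = np.ones_like(labels, dtype=float)
weighted = np.where(labels == 0, 0.1, 0.9)

loss_uniform = weighted_sparse_ce(logits, labels, uniform)
loss_weighted = weighted_sparse_ce(logits, labels, weighted)
print("uniform weights:", loss_uniform)
print("0.1 / 0.9 weights:", loss_weighted)
```

In this toy case the always-background prediction gets a higher loss under the 0.1 / 0.9 weighting than under uniform weights, which is the effect the weighting is supposed to have during training.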
Does anyone have any solutions to my problems?
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow