sklearn train_test_split on list of 3-dimensional arrays

I want to do image classification and I have as data_X a list of 12000 three-dimensional numpy arrays. Those arrays all have the shape 300 x 300 x 3 (height, width, channel). My data_Y is just a list of 12000 ints (between 0 and 5), stating the class the array belongs to. When I use sklearn's train_test_split like:

X_train, X_test, Y_train, Y_test = train_test_split(data_X, data_Y, test_size=0.2, random_state=42)

The resulting X_train is a list of 9600 two-dimensional arrays of the shape 300 x 300. How did I lose the third dimension?

Also when trying to fit a neural network like this:

model1 = Sequential()
model1.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(300, 300, 3)))
model1.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
model1.add(MaxPooling2D(pool_size=(2,2)))
model1.add(Dropout(0.25))
model1.add(Flatten())
model1.add(Dense(6, activation="softmax"))
model1.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
model1.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=80, batch_size=20)

I get this error: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 9600 arrays: [array([[1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1.], ..., [1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1.... Please help!



Solution 1:[1]

If the samples in an N-dimensional dataset are indexed along the second or a later axis rather than the first, you can split a list of indices with train_test_split and then use those indices to select from the data along that axis, as follows:

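import numpy as np
from sklearn.model_selection import train_test_split

# Example data: 100 samples indexed along the third axis (axis 2)
data_arr = np.random.rand(4,4,100,3)
# Split the sample indices instead of the array itself
all_indices = list(range(100))
train_ind, test_ind = train_test_split(all_indices, test_size=0.2)

# Select the train and test samples along axis 2
train = data_arr[:,:,train_ind,:]
test = data_arr[:,:,test_ind, :]

train.shape, test.shape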

The output is: ((4, 4, 80, 3), (4, 4, 20, 3))

Credit for above code: https://towardsdatascience.com/3-things-you-need-to-know-before-you-train-test-split-869dfabb7e50
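For the case in the question, where the samples are already on the first axis but stored as a Python list, one possible approach is to stack the list into a single NumPy array before splitting. The following is only a minimal sketch under stated assumptions: the image size (8 x 8 x 3) and sample count (20) are small stand-ins for the question's 300 x 300 x 3 and 12000, and the random data is made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for the question's data (assumed sizes, chosen to keep this fast):
# a list of 20 arrays of shape 8 x 8 x 3 and a list of 20 integer labels in 0..5.
rng = np.random.default_rng(0)
data_X = [rng.random((8, 8, 3)) for _ in range(20)]
data_Y = [i % 6 for i in range(20)]

# Stack the list into one array so that the sample axis comes first.
X = np.stack(data_X)    # shape (20, 8, 8, 3)
y = np.asarray(data_Y)  # shape (20,)

# train_test_split splits along the first axis and keeps the remaining axes.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(X_train.shape, X_test.shape)  # (16, 8, 8, 3) (4, 8, 8, 3)
print(Y_train.shape, Y_test.shape)  # (16,) (4,)
```

Passing a single array such as X_train, rather than a list of arrays, should avoid the "Expected to see 1 array(s), but instead got the following list of 9600 arrays" error quoted in the question, since that message refers to a list being passed as the model input. Note also that categorical_crossentropy expects one-hot encoded labels, so with integer labels like these either one-hot encode them or use sparse_categorical_crossentropy.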

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 ImanB