sklearn train_test_split on list of 3-dimensional arrays
I want to do image classification. My data_X is a list of 12000 three-dimensional numpy arrays, all of shape 300 x 300 x 3 (height, width, channel). My data_Y is just a list of 12000 ints (between 0 and 5), stating the class each array belongs to. When I use sklearn's train_test_split like this:
X_train, X_test, Y_train, Y_test = train_test_split(data_X, data_Y, test_size=0.2, random_state=42)
The resulting X_train is a list of 9600 two-dimensional arrays of shape 300 x 300. How did I lose the third dimension?
Also when trying to fit a neural network like this:
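from keras.models import Sequential
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
model1 = Sequential()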
model1.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(300, 300, 3)))
model1.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
model1.add(MaxPooling2D(pool_size=(2,2)))
model1.add(Dropout(0.25))
model1.add(Flatten())
model1.add(Dense(6, activation="softmax"))
model1.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
model1.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=80, batch_size=20)
I get this error:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 9600 arrays: [array([[1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1.], ..., [1., 1., 1., ..., 1., 1., 1.], [1., 1., 1., ..., 1., 1., 1....
Please help!
Solution 1:[1]
If the samples in a multi-dimensional dataset are indexed along the second or a later axis rather than the first, you can split a list of sample indices with train_test_split and then use the resulting indices to slice the data along that axis, like this:
import numpy as np
from sklearn.model_selection import train_test_split
data_arr = np.random.rand(4,4,100,3)
all_indices = list(range(100))
train_ind, test_ind = train_test_split(all_indices, test_size=0.2)
train = data_arr[:,:,train_ind,:]
test = data_arr[:,:,test_ind, :]
train.shape, test.shape
The output is: ((4, 4, 80, 3), (4, 4, 20, 3))
Credit for above code: https://towardsdatascience.com/3-things-you-need-to-know-before-you-train-test-split-869dfabb7e50
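For the case in the question, where the samples already lie along the first axis, a minimal sketch is to stack the list into a single array before splitting. The small 8 x 8 x 3 images and the variable names X and y below are stand-ins chosen for illustration, not the asker's data. train_test_split indexes only along the first axis, so it should not drop the channel axis by itself; one possible cause of the 300 x 300 results is that some images in the list (for example grayscale ones) never had a third axis, which np.stack would surface as an error.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in for the question's data: 20 "images" of shape 8 x 8 x 3
# (the question has 12000 images of shape 300 x 300 x 3) with labels 0..5.
rng = np.random.default_rng(0)
data_X = [rng.random((8, 8, 3)) for _ in range(20)]
data_Y = [i % 6 for i in range(20)]

# Stack the list into one array of shape (n_samples, height, width, channels).
# np.stack raises a ValueError if the arrays do not all share the same shape,
# which helps spot images that are missing the channel axis.
X = np.stack(data_X)
y = np.asarray(data_Y)

# The split happens along the first (sample) axis; the other axes are kept.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(X_train.shape, X_test.shape)  # (16, 8, 8, 3) (4, 8, 8, 3)
print(y_train.shape, y_test.shape)  # (16,) (4,)
```

Passing single numpy arrays like these to model1.fit, instead of Python lists, avoids Keras reading the list as 9600 separate model inputs. Note also that categorical_crossentropy expects one-hot encoded labels; with integer labels such as 0 to 5, sparse_categorical_crossentropy is the loss that accepts them directly.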
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | ImanB |