Convert model.fit_generator to model.fit
I have the following code:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augment the training images; only rescale the validation images
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)

# Read images from class subfolders and yield (images, labels) batches
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    'data/validation',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')
Training is done with model.fit_generator as follows:
model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800)
Since model.fit_generator is deprecated, what is the proper way to convert this call to model.fit?
Solution 1:[1]
You just have to change model.fit_generator() to model.fit(). As of TensorFlow 2.1, model.fit() also accepts generators as input. As simple as that.
From TensorFlow's official documentation:
Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Please use Model.fit, which supports generators.
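A minimal runnable sketch of this change, assuming TensorFlow 2.1 or later. The question's flow_from_directory iterators need an image folder on disk, so the hypothetical random_batches generator below stands in for them by yielding random (images, labels) batches of the same image size; the tiny model is likewise only illustrative and exists so there is something to train.

```python
import numpy as np
import tensorflow as tf

IMAGE_SHAPE = (150, 150, 3)  # matches target_size=(150, 150) with RGB images


def random_batches(batch_size=8, seed=0):
    """Hypothetical stand-in for flow_from_directory: yields (images, labels) forever."""
    rng = np.random.default_rng(seed)
    while True:  # fit() draws steps_per_epoch batches per epoch, so never stop
        images = rng.random((batch_size, *IMAGE_SHAPE), dtype=np.float32)
        labels = rng.integers(0, 2, size=(batch_size, 1)).astype(np.float32)
        yield images, labels


# A deliberately tiny binary classifier; it only exists to demonstrate the fit() call.
model = tf.keras.Sequential([
    tf.keras.Input(shape=IMAGE_SHAPE),
    tf.keras.layers.Conv2D(4, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Same keyword arguments as the question's fit_generator call; only the method
# name changes (step counts are shrunk here so the sketch runs quickly).
history = model.fit(
    random_batches(seed=0),
    steps_per_epoch=4,
    epochs=2,
    validation_data=random_batches(seed=1),
    validation_steps=2,
    verbose=0,
)
print(sorted(history.history))  # per-epoch training and validation metrics
```

With the question's own iterators, the converted call would be model.fit(train_generator, steps_per_epoch=2000, epochs=50, validation_data=validation_generator, validation_steps=800).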
Solution 2:[2]
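Rename fit_generator to fit and get rid of 'generator='. fit() has no generator argument, so pass the generator as the first positional argument (or as x=).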
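To see why the keyword has to go, you can inspect the signature of tf.keras.Model.fit in your installed TensorFlow version. This is a minimal sketch; the exact parameter list varies between releases, but the data argument is named x.

```python
import inspect

import tensorflow as tf

# Parameter names accepted by Model.fit in the installed TensorFlow/Keras version.
fit_params = list(inspect.signature(tf.keras.Model.fit).parameters)

print(fit_params[:3])             # expected to start with 'self', 'x', 'y'
print("generator" in fit_params)  # False: generator= is not an accepted keyword
```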
Old training:
model.fit_generator(generator=train_generator,
                    steps_per_epoch=2048 // 36, epochs=10,
                    validation_data=validation_generator, validation_steps=832 // 16)
New training:
model.fit(train_generator,
          steps_per_epoch=2048 // 36, epochs=10,
          validation_data=validation_generator, validation_steps=832 // 16)
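The step counts in the old training call use floor division (2048 // 36), which makes an epoch slightly shorter than one full pass over the data, while ceiling division also counts the final partial batch. A small sketch comparing the two, assuming 2048 and 832 are sample counts and 36 and 16 are batch sizes (the answer does not say); steps_for is a hypothetical helper, not a Keras function:

```python
import math


def steps_for(num_samples, batch_size, include_partial=False):
    """Hypothetical helper: how many batches make up one pass over num_samples."""
    if include_partial:
        return math.ceil(num_samples / batch_size)  # last batch may be smaller
    return num_samples // batch_size  # full batches only


print(steps_for(2048, 36))                        # 56 (56 * 36 = 2016 samples)
print(steps_for(2048, 36, include_partial=True))  # 57 (last batch has 32 samples)
print(steps_for(832, 16))                         # 52 (divides evenly)
```

For iterators returned by flow_from_directory, which are Keras Sequence objects, steps_per_epoch can also be omitted, in which case fit() uses len(train_generator), the ceiling-division count.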
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Timbus Calin |
Solution 2 | Makyen |