How to replace make_one_shot_iterator() from Google Machine Learning Crash Course
I am following the Google Machine Learning Crash Course. It uses TensorFlow 1.x, so I planned to adapt the exercises to run in TensorFlow 2.0, but I am stuck on one exercise.
Specifically the code:
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
    """Trains a linear regression model of one feature.

    Args:
      features: pandas DataFrame of features
      targets: pandas DataFrame of targets
      batch_size: Size of batches to be passed to the model
      shuffle: True or False. Whether to shuffle the data.
      num_epochs: Number of epochs for which data should be repeated. None = repeat indefinitely
    Returns:
      Tuple of (features, labels) for next data batch
    """
    # Convert pandas data into a dict of np arrays.
    features = {key: np.array(value) for key, value in dict(features).items()}

    # Construct a dataset, and configure batching/repeating.
    ds = Dataset.from_tensor_slices((features, targets))  # warning: 2GB limit
    ds = ds.batch(batch_size).repeat(num_epochs)

    # Shuffle the data, if specified.
    if shuffle:
        ds = ds.shuffle(buffer_size=10000)

    # Return the next batch of data.
    features, labels = ds.make_one_shot_iterator().get_next()
    return features, labels
I have replaced features, labels = ds.make_one_shot_iterator().get_next()
with features, labels = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
and it seems to work, but make_one_shot_iterator() is deprecated, so how can I replace it?
Also according to https://github.com/tensorflow/tensorflow/issues/29252 , I have tried
features, labels = ds.__iter__()
next(ds.__iter__())
return features, labels
but it returns the error: __iter__() is only supported inside of tf.function or when eager execution is enabled.
I am quite inexperienced in Python and follow the course as a hobbyist. Any ideas on how to solve it? Thank you.
Solution 1:[1]
After several tests, the Python hang turned out to be a local problem.
To replace features, labels = ds.make_one_shot_iterator().get_next()
I have tried several things:
features, labels = ds.__iter__().get_next()
iterator = ds.__iter__()
features, labels = iterator.get_next()
it = iter(ds)
features, labels = next(it)
All three cases return the error: __iter__() is only supported inside of tf.function or when eager execution is enabled.
So I tried:
features, labels = ds
return ds
And also just:
return features, labels
Both return the same error. Finally I tried:
return ds
And mysteriously it worked, I have no idea why, but it did.
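A likely explanation, offered as a sketch rather than a definitive answer: Estimator-style training calls the input function while building a graph, where eager execution is off, so calling iter(ds) inside the function fails. Those APIs accept a tf.data.Dataset as the return value and create the iterator themselves, which is why returning ds works. The version below assumes TensorFlow 2.x with eager execution enabled at the top level; the toy DataFrame and its column name are hypothetical, and targets is wrapped in np.array as an extra precaution that the original code did not take.

```python
import numpy as np
import pandas as pd
import tensorflow as tf


def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
    """Builds a tf.data.Dataset of (features, labels) batches."""
    # Convert pandas data into a dict of np arrays.
    features = {key: np.array(value) for key, value in dict(features).items()}

    # Construct a dataset, and configure batching/repeating.
    ds = tf.data.Dataset.from_tensor_slices((features, np.array(targets)))
    ds = ds.batch(batch_size).repeat(num_epochs)

    # Shuffle the data, if specified.
    if shuffle:
        ds = ds.shuffle(buffer_size=10000)

    # Return the dataset itself; the caller creates the iterator.
    return ds


# Hypothetical toy data, only for illustration.
toy_features = pd.DataFrame({"total_rooms": [1.0, 2.0, 3.0, 4.0]})
toy_targets = pd.Series([10.0, 20.0, 30.0, 40.0])

ds = my_input_fn(toy_features, toy_targets, batch_size=2,
                 shuffle=False, num_epochs=1)

# Outside of graph construction, eager mode allows plain Python iteration.
batch_features, batch_labels = next(iter(ds))
```

When training, the function would typically be passed wrapped in a lambda, for example input_fn=lambda: my_input_fn(toy_features, toy_targets), so that the dataset is built inside the graph the training API controls.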
Solution 2:[2]
1) I doubt that you really got what you wanted. If your model really needs multiple inputs, then returning ds as-is is unlikely to suit; you just need the list of tensors, something like this:
features = tf.compat.v1.data.make_one_shot_iterator(train_dataset).get_next()
image, label = features['image'], features['label']
2) Concerning the iterator: it now belongs to tf.data, with the tf.data.Iterator.get_next() method, as opposed to the previous ds.make_one_shot_iterator() call on a tf.data.Dataset. Perhaps this was a Dependency Inversion (the D in the SOLID principles of development), perhaps just a refactoring. The new Iterator entity can now be used with tf.data.Dataset.from_generator() objects, fed in async mode by a generator function that yields each chunk of data. Here is an example of overriding a custom tfds.core.GeneratorBasedBuilder...
I think the overall architecture of the TensorFlow library was refactored a little, because the input now consumes the data batch by batch itself (thanks to the developers' implementation), so applying make_one_shot_iterator to a Dataset is no longer needed. Even for debugging there is .as_numpy_iterator(), so the developers no longer consider make_one_shot_iterator necessary,
though sometimes people use:
iterator = iter(batched_dataset)
next_element = iterator.get_next()
- I cannot say yet where this would be needed.
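As a minimal sketch of the two inspection styles mentioned above, assuming TensorFlow 2.1 or later with eager execution enabled: as_numpy_iterator() yields plain NumPy values, while iter()/next() yields tensors. The toy dataset and the feature key "x" are hypothetical.

```python
import tensorflow as tf

# Hypothetical toy dataset of (features, labels) pairs, batched in twos.
dataset = tf.data.Dataset.from_tensor_slices(
    ({"x": [1.0, 2.0, 3.0, 4.0]}, [10.0, 20.0, 30.0, 40.0])
).batch(2)

# as_numpy_iterator() yields NumPy values, which is handy for debugging.
numpy_batches = list(dataset.as_numpy_iterator())

# In eager mode, iter()/next() returns the same batches as tensors.
first_features, first_labels = next(iter(dataset))
```

Both forms are meant for inspecting data at the top level of a script or notebook; inside an input function that a training API calls during graph construction, the dataset itself is what gets returned.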
P.S. By the way, as I remember something from the debugger: if your container is hashable or not iterable (correct me if I am wrong), you can try:
iterator = iter(dataset)
# Alternatively: batch_features, batch_labels = iterator.get_next()
el = iterator.get_next()
# Assumes each element is a (features, labels) tuple.
batch_features = el[0]
print(batch_features)
batch_labels = el[-1]
print(batch_labels)
- works OK
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Sir Fred |
Solution 2 |