How to get the string value out of a tf.Tensor whose dtype is string
I want to use the tf.data.Dataset.list_files function to feed my dataset.
Because the files are not images, I need to load them manually.
The problem is that tf.data.Dataset.list_files passes each file name as a tf.Tensor, and my Python code cannot handle tensors.
How can I get the string value from a tf.Tensor whose dtype is string?
train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))
def load_audio_file(file_path):
    print("file_path: ", file_path)
    # I want to do something like: string_path = convert_tensor_to_string(file_path)
Here file_path is Tensor("arg0:0", shape=(), dtype=string).
I am using TensorFlow 1.13.1 in eager mode.
Thanks in advance.
Solution 1:[1]
You can use tf.py_func to wrap load_audio_file().
import tensorflow as tf
tf.enable_eager_execution()
def load_audio_file(file_path):
    # you should decode bytes type to string type
    print("file_path: ", bytes.decode(file_path), type(bytes.decode(file_path)))
    return file_path
train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
    print(one_element)
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
UPDATE for TF 2
The above solution will not work with TF 2 (tested with 2.2.0), even when replacing tf.py_func with tf.py_function. It fails with:
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
To make it work in TF 2, make the following changes:
- Remove tf.enable_eager_execution() (eager is enabled by default in TF 2, which you can verify with tf.executing_eagerly() returning True)
- Replace tf.py_func with tf.py_function
- Replace all in-function references of file_path with file_path.numpy()
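The three changes above can be sketched as follows. This is a minimal illustration, not a tested audio loader: the paths are hypothetical and are fed through tf.data.Dataset.from_tensor_slices instead of list_files so that the snippet does not depend on files existing on disk, and the function body is only a placeholder for real loading code.

```python
import tensorflow as tf

def load_audio_file(file_path):
    # Under tf.py_function the argument arrives as an EagerTensor,
    # so .numpy() yields bytes, which decode to a Python str.
    path_str = file_path.numpy().decode("utf-8")
    print("file_path:", path_str, type(path_str))
    # Placeholder: a real loader would read the file at path_str here.
    return file_path

# Hypothetical paths standing in for tf.data.Dataset.list_files(...).
paths = ["clean_4s_val/1.wav", "clean_4s_val/2.wav"]
train_dataset = tf.data.Dataset.from_tensor_slices(paths)
train_dataset = train_dataset.map(
    lambda x: tf.py_function(load_audio_file, [x], tf.string)
)

# Outside of map the elements are eager tensors, so .numpy() works directly.
decoded = [element.numpy().decode("utf-8") for element in train_dataset]
print(decoded)
```

Note that file_path.numpy() returns bytes, so the extra .decode("utf-8") is what produces a Python str.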
Solution 2:[2]
If you want to do something completely custom, then wrapping your code in tf.py_function is what you should do. Keep in mind that this costs performance, because the wrapped code runs in the Python interpreter rather than as graph ops. See the documentation and examples here:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
On the other hand, if you are doing something generic, you don't need to wrap your code in py_function; instead, use the methods provided in the tf.strings module. These methods are made to work on string tensors and cover many common operations such as split, join, and length. They will not negatively affect performance: they work on the tensor directly and return a modified tensor.
See the documentation of tf.strings here: https://www.tensorflow.org/api_docs/python/tf/strings
For example, let's say you wanted to extract the name of the label from the file name; you could then write code like this:
ds.map(lambda x: tf.strings.split(x, sep='$')[1])
The above assumes that the label is separated by a $.
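As a small runnable sketch of the same idea, the snippet below applies the split to a few made-up file names. The names and the position of the label (the second $-separated field) are assumptions chosen only for illustration.

```python
import tensorflow as tf

# Hypothetical file names: the label is the second '$'-separated field.
names = ["a$cat$1.wav", "b$dog$2.wav"]
ds = tf.data.Dataset.from_tensor_slices(names)

# tf.strings.split runs as a graph op on the string tensor,
# so no tf.py_function wrapper is needed here.
label_ds = ds.map(lambda x: tf.strings.split(x, sep="$")[1])

# Outside of map the elements are eager tensors and can be decoded.
labels = [label.numpy().decode("utf-8") for label in label_ds]
print(labels)
```

The values only become Python strings at the very end, outside the pipeline; inside map everything stays a tensor.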
Solution 3:[3]
If you really want to unwrap a Tensor to its string content only, you need to serialize it to a TFRecord so that you can call tf_example.SerializeToString() and get a (printable) string value; see here
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | desertnaut |
Solution 2 | user566245 |
Solution 3 | JeeyCi |