关于python:如何在热切的执行模式下使用tf.data数据集?

How can I use tf.data Datasets in eager execution mode?

在2018年TensorFlow开发峰会上的tf.data演讲中,Derek Murray提出了一种将tf.data API与TensorFlow急切执行模式结合的方法(在10:54)。 我尝试了显示在此处的代码的简化版本:

1
2
3
4
5
6
7
import tensorflow as tf
tf.enable_eager_execution()

dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([50, 10]))
dataset = dataset.batch(5)
for batch in dataset:
    print(batch)

引起

1
TypeError: 'BatchDataset' object is not iterable

我还尝试使用dataset.make_one_shot_iterator()dataset.make_initializable_iterator()遍历数据集,但是它们导致

1
RuntimeError: dataset.make_one_shot_iterator is not supported when eager execution is enabled.

1
RuntimeError: dataset.make_initializable_iterator is not supported when eager execution is enabled.

TensorFlow版本:1.7.0,Python版本:3.6

如何在急切执行中使用tf.data API?


make_one_shot_iterator()应该可以在TensorFlow 1.8中使用,但是就目前而言(即针对TensorFlow 1.7),请执行以下操作:

1
2
3
4
5
6
import tensorflow.contrib.eager as tfe

dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([50, 10]))
dataset = dataset.batch(5)
for batch in tfe.Iterator(dataset):
     print(batch)

有了TF 2.1,

您可以这样创建一个迭代器:

1
iterator = iter(dataset)

并获取下一批值:

1
batch = iterator.get_next()