Tensorflow Data API - prefetch
I am trying out the new TF Data API and I am not sure how prefetch works. In the code below,
```python
def dataset_input_fn(...):
    dataset = tf.data.TFRecordDataset(filenames, compression_type="ZLIB")
    dataset = dataset.map(lambda x: parser(...))
    dataset = dataset.map(lambda x, y: image_augmentation(...),
                          num_parallel_calls=num_threads)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
```
does it matter where the prefetch call is placed between these lines?
In a discussion on GitHub I found the following comment from mrry:
Note that in TF 1.4 there will be a Dataset.prefetch() method that
makes it easier to add prefetching at any point in the pipeline, not
just after a map(). (You can try it by downloading the current nightly
build.)
and
For example, Dataset.prefetch() will start a background thread to
populate a ordered buffer that acts like a tf.FIFOQueue, so that
downstream pipeline stages need not block. However, the prefetch()
implementation is much simpler, because it doesn't need to support as
many different concurrent operations as a tf.FIFOQueue.
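The background-thread/FIFO-buffer behaviour described in that comment can be sketched in plain Python, without TensorFlow: a producer thread fills a bounded `queue.Queue` ahead of the consumer, so the consumer only blocks when the buffer is empty. This is a conceptual model of prefetching, not the actual TensorFlow implementation.

```python
import queue
import threading

def prefetch(generator, buffer_size):
    """Wrap a generator so a background thread fills a bounded buffer,
    mimicking the ordered FIFO buffer that Dataset.prefetch() maintains."""
    buf = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def producer():
        for item in generator:
            buf.put(item)       # blocks when the buffer is full
        buf.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buf.get()        # blocks only when the buffer is empty
        if item is sentinel:
            return
        yield item

# Elements come out in the same order they were produced:
print(list(prefetch(iter(range(5)), buffer_size=2)))  # [0, 1, 2, 3, 4]
```

Note that ordering is preserved; the buffer only decouples the producer's and consumer's timing, which is why it is simpler than a general-purpose `tf.FIFOQueue`.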
So it implies that prefetching can be added after any transformation, and that it operates on everything upstream of it. So far I have noticed the biggest performance gains by putting it only at the very end of the pipeline.
There is also a discussion on the meaning of buffer_size in Dataset.map, Dataset.prefetch and Dataset.shuffle, where mrry explains more about prefetching and buffering.
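One point from that discussion is worth illustrating: `shuffle(buffer_size)` does not shuffle the whole dataset, only a sliding window of `buffer_size` elements. A rough pure-Python model of that behaviour (my own sketch, not TensorFlow code):

```python
import random

def windowed_shuffle(iterable, buffer_size, seed=None):
    """Approximate Dataset.shuffle(buffer_size): keep a buffer of
    buffer_size elements and emit a random one as each new element
    arrives, then drain whatever is left at the end."""
    rng = random.Random(seed)
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= buffer_size:
            # emit a random element from the current window
            yield buf.pop(rng.randrange(len(buf)))
    rng.shuffle(buf)  # drain the remaining elements in random order
    yield from buf

# The output is a permutation of the input, but an element can move at
# most buffer_size - 1 positions earlier than where it appeared in the
# input, so the result is only locally shuffled:
out = list(windowed_shuffle(range(100), buffer_size=10, seed=0))
print(sorted(out) == list(range(100)))  # True
```

This is why a buffer_size smaller than the dataset gives only a weak shuffle, while buffer_size equal to the dataset size gives a uniform one.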
Update 2018/10/01:
As of version 1.7.0, the Dataset API (in contrib) has a prefetch_to_device option:
https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/prefetch_to_device