How to speed up batch preparation when using Estimators API combined with tf.data.Dataset
我想加快使用Estimator API和使用
我的实现需要2秒钟来准备一批数据,然后在GPU上进行训练1秒钟,然后重新开始准备一批数据。 这真的很低效。
我正在寻找一种异步准备批次并将其上传到GPU的方法,以加快培训速度。 或者作为一种在
这是我的代码的简化版本:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | def input_fn(filenames, labels, epochs): dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads) if shuffle: dataset = dataset.shuffle(buffer_size=len(labels)) dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads) dataset = dataset.map(lambda wav, label: ({'wav': wav}, label)) dataset = dataset.batch(128) dataset = dataset.repeat(epochs) # to iterate over the training set forever iterator = dataset.dataset.make_one_shot_iterator() features, labels = iterator.get_next() return features, labels train_input_fn = lambda : input_fn(train_files, train_labels, None) eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1) train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) |
我注意到Estimator API正在积极开发中,并且在tensorflow的master分支中,input_fn已经可以返回数据集了,所以也许我问得太早了并且该功能还没有准备好。 但是,如果是这样,请提供可以跟踪此实施情况的票证。
使用
解决方法是在管道的末尾使用
1 2 3 | dataset = ... dataset = dataset.batch(128) dataset = dataset.prefetch(1) # prefetch one batch |
如@mrry在此答案中所解释的,您还可以尝试稍微增加预提取的批处理数量。
Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
如果与GPU计算相比,您的输入管道仍然较慢,则需要使用
一些要补充到奥利维尔的答案的要点,主要来自于这篇文章:
-
在
shuffle 之前的repeat 稍微快一些,位于模糊的时代边界的下方。在极少数情况下,这可能很重要,但我对此表示怀疑。 -
在
map ping前使用shuffle -这会减少随机播放缓冲区大小的内存占用,因为它只需要缓冲文件名而不是文件内容。 -
对于我来说,将第三次地图变换应用于
get_next() 的输出而不是数据集对我来说更有意义-不确定这是否会大大影响速度。您也可以考虑将其他两个地图调用放在同一个中,以减少调度问题。 -
在
batch 之前尝试使用repeat 。可能不会有所作为,但可能很小。如果您如上所述在shuffle 之前repeat ,则必须这样做。 -
如Olivier所述,请使用
prefetch 。
修改后的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | def input_fn(filenames, labels, epochs): dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.repeat(epochs) if shuffle: dataset = dataset.shuffle(buffer_size=len(labels)) def combined_map_fn(*args): return _post_process(_read_wav(*args)) dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads) dataset = dataset.batch(128) dataset = dataset.prefetch(1) iterator = dataset.dataset.make_one_shot_iterator() wavs, labels = iterator.get_next() features = {'wav': wavs} return features, labels |