关于tensorflow:结合使用Estimators API和tf.data.Dataset时如何加快批处理准备

How to speed up batch preparation when using Estimators API combined with tf.data.Dataset

我想加快使用Estimator API和使用tf.data.Dataset编写的input_fn的训练例程。

我的实现需要2秒钟来准备一批数据,然后在GPU上进行训练1秒钟,然后重新开始准备一批数据。 这真的很低效。

我正在寻找一种异步准备批次并将其上传到GPU的方法,以加快培训速度。 或者作为一种在input_fn调用之间缓存数据集的方法(dataset.cache()似乎不是一个好选择,因为必须在每个input_fn调用上重新创建数据集)。

这是我的代码的简化版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
     dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process,  num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs) # to iterate over the training set forever
  iterator = dataset.dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

我注意到Estimator API正在积极开发中,并且在tensorflow的master分支中,input_fn已经可以返回数据集了,所以也许我问得太早了并且该功能还没有准备好。 但是,如果是这样,请提供可以跟踪此实施情况的票证。


使用tf.data.Dataset.cache()确实不是一个好选择,因为它将把整个数据集缓存到内存中,这会花费时间并且可能会导致内存溢出。

解决方法是在管道的末尾使用tf.data.Dataset.prefetch(),这将始终确保数据管道包含buffer_size元素。通常在结尾加上buffer_size = 1就足够了:

1
2
3
dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

如@mrry在此答案中所解释的,您还可以尝试稍微增加预提取的批处理数量。

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.

如果与GPU计算相比,您的输入管道仍然较慢,则需要使用tf.data.Dataset.map()num_parallel_calls参数来增加并行工作的线程数。


一些要补充到奥利维尔的答案的要点,主要来自于这篇文章:

  • shuffle之前的repeat稍微快一些,位于模糊的时代边界的下方。在极少数情况下,这可能很重要,但我对此表示怀疑。
  • map ping前使用shuffle-这会减少随机播放缓冲区大小的内存占用,因为它只需要缓冲文件名而不是文件内容。
  • 对于我来说,将第三次地图变换应用于get_next()的输出而不是数据集对我来说更有意义-不确定这是否会大大影响速度。您也可以考虑将其他两个地图调用放在同一个中,以减少调度问题。
  • batch之前尝试使用repeat。可能不会有所作为,但可能很小。如果您如上所述在shuffle之前repeat,则必须这样做。
  • 如Olivier所述,请使用prefetch

修改后的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    return _post_process(_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels