How to use tf.datasets with iterator in Tensorflow
我正在尝试使用 tf.data.TextLineDataset 从 csv 文件中读取数据,将数据集分片到多个工作节点上,然后创建一个迭代器来迭代它们以分批提供数据。我使用了 TensorFlow 的 tf.datasets 程序员指南(https://www.tensorflow.org/programmers_guide/datasets)。
在 tf 会话中运行代码时的问题是我收到以下错误:
1 2 | *** tensorflow.python.framework.errors_impl.NotFoundError: Date,Open,High,Low,Last,Close,Total Trade Quantity,Turnover,close_pct_change_1d,KAMA7-KAMA30,KAMA15-KAMA30,HT_QUAD,TURNOVER,BOP,MFI,MINUS_DI,ROCP,STOCH_SLOWK,NATR,EMA7-EMA30-1d,DX-1d,PPO-1d,NATR-1d,HT_INPHASOR-2d,day_0,day_1,day_2,day_3; No such file or directory [[Node: IteratorGetNext_5 = IteratorGetNext[output_shapes=[[], [], [], [], [], ..., [], [], [], [], []], output_types=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_8)]] |
现在,"日期"、"开放"、"高"等是我要加载的数据集中的列。因此,我知道错误与加载数据集无关。
加载数据集时,我使用
有人知道这个错误是从哪里来的吗?有没有人能解决这个问题?
请看以下代码说明:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | def create_pipeline(bs, nr, ep): def _X_parse_csv(file): record_defaults=[[0]]*20 splits = tf.decode_csv(file, record_defaults) input = splits return input def _y_parse_csv(file): record_defaults=[[0]]*20 splits = tf.decode_csv(file, record_defaults) label = splits[0] return label # Dataset for input data file = tf.gfile.Glob("./NSEOIL.csv") num_workers = 1 # for testing; simulate 1 node for sharding below task_index = 0 ds_file = tf.data.TextLineDataset(file) ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers ds = ds.shard(num_workers, task_index).repeat(ep) X_train = ds.map(_X_parse_csv) ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(2))) #remove CSV headers + shift forward 1 day ds = ds.shard(num_workers, task_index).repeat(ep) y_train = ds.map(_y_parse_csv) X_iterator = X_train.make_initializable_iterator() y_iterator = y_train.make_initializable_iterator() return X_iterator, y_iterator |
这两行似乎是问题的根源:
1 2 3 | ds_file = tf.data.TextLineDataset(file) ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers |
第一行从
幸运的是,修复相对简单,因为您可以使用
1 2 3 4 5 6 | # Create a dataset of filenames. ds_file = tf.data.Dataset.list_files("./NSEOIL.csv") # For each filename in `ds_file`, read the lines from that file (skipping the # header). ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) |