熟悉接口
下面这些接口在调用时并不会真正处理数据,类似其他的tensorflow构图机制,只是构图,真正启用需要调用session run。
1 | import tensorflow as tf |
from_tensor_slices
创造一个数据集,每个元素来自于传入数据的每个切片。
1 | t = time.time() |
1 | 0.0039520263671875 |
1 | print(dataset) |
1 | <TensorSliceDataset shapes: (3,), types: tf.float32> |
1 | sess = get_session() |
1 | [-0.37590072 -0.11182938 0.40723833] |
输出两条数据
1 | get_dataset_size_by_iter(dataset, sess) |
1 | time cost: 0.038236379623413086 |
repeat
重复拷贝数据集。
1 | dataset = dataset.repeat(count=2) |
1 | <RepeatDataset shapes: (3,), types: tf.float32> |
1 | get_dataset_size_by_iter(dataset, sess) |
1 | time cost: 0.06870865821838379 |
输出的数据条数是之前的两倍。
batch
数据集分批。
1 | dataset1 = dataset.batch(5) # batch size = 5 |
1 | <BatchDataset shapes: (?, 3), types: tf.float32> |
1 | get_dataset_size_by_iter(dataset, sess) |
1 | time cost: 0.07429909706115723 |
1 | 400 / 32 |
1 | 12.5 |
1 | inspect_dataset(dataset1, sess) |
1 | [[-0.37590072 -0.11182938 0.40723833] |
shuffle
shuffle的实现是先在取一些元素到buffer中,然后每次从buffer中sample一个元素。buffer_size
代表这个buffer的最大尺寸。
1 | dataset11 = dataset1.shuffle(buffer_size=1) |
1 | [[-0.37590072 -0.11182938 0.40723833] |
可以发现,如果buffer_size
是1,则完全没有起到shuffle的作用。另外可以认识到,shuffle是针对element在shuffle,而element内部的数据不会被shuffle。
1 | dataset12 = dataset1.shuffle(buffer_size=2) |
1 | [[ 0.33324835 -0.7349805 -0.22900774] |
buffer_size
是2,发现第一第二个数据变换了顺序。
因此,在设置buffer_size的时候一定要设置较大的数。
prefetch
prefetch将提前fetch一些element保存在buffer中,buffer_size
代表这个buffer的最大尺寸。用于减少数据读取和计算的串行时间。
1 | dataset11 = dataset1.prefetch(buffer_size=1) |
1 | [[-0.37590072 -0.11182938 0.40723833] |