15. TensorFlow Dataset Basics

Getting Familiar with the Interfaces

The interfaces below do not actually process any data when they are called. Like the rest of TensorFlow's graph-construction mechanism, calling them only builds the graph; data is only processed once session.run is invoked.
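
Before going through the helper code below, here is a minimal sketch of that behavior, using a small in-memory array instead of the feature files used later:

import numpy as np
import tensorflow as tf

data = np.arange(6, dtype=np.float32).reshape(3, 2)

# These calls only add ops to the graph; nothing is sliced or read yet.
dataset = tf.data.Dataset.from_tensor_slices(data)
next_element = dataset.make_one_shot_iterator().get_next()  # still just a tensor

with tf.Session() as sess:
    print(sess.run(next_element))  # only now is an element actually produced: [0. 1.]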

import tensorflow as tf
import numpy as np
import time

train_file = "/nfs/project/han_new/all_data/data_kuai_che_fea/kc.train.feature.npy"
test_file = "/nfs/project/han_new/all_data/data_kuai_che_fea/kc.test.feature.npy"

train_dense_feature = np.load(train_file).astype(np.float32)[:200]
test_dense_feature = np.load(test_file).astype(np.float32)[:100]

def get_session():
    """Create a new session that grows GPU memory on demand."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)

def simple_time(func):
    """Decorator that prints how long the wrapped call takes."""
    def wrapper(*arg, **kw):
        t1 = time.time()
        res = func(*arg, **kw)
        t2 = time.time()
        print("time cost: ", t2 - t1)
        return res
    return wrapper

@simple_time
def inspect_dataset(dataset, sess):
    # Inspect the dataset by printing its first two elements.
    iterator = dataset.make_one_shot_iterator()
    data_tensor = iterator.get_next()
    print(sess.run(data_tensor))
    print(sess.run(data_tensor))

@simple_time
def get_dataset_size_by_iter(dataset, sess):
    # Count the dataset's elements by iterating until the iterator is exhausted.
    count = 0
    iterator = dataset.make_one_shot_iterator()
    data_tensor = iterator.get_next()
    while True:
        try:
            sess.run(data_tensor)
            count += 1
        except tf.errors.OutOfRangeError:
            # End of the dataset.
            return count

from_tensor_slices

Creates a dataset in which each element is one slice of the input data along its first dimension.

t = time.time()
dataset = tf.data.Dataset.from_tensor_slices(train_dense_feature)
print(time.time() - t)
0.0039520263671875
print(dataset)
<TensorSliceDataset shapes: (3,), types: tf.float32>
sess = get_session()

inspect_dataset(dataset, sess)
[-0.37590072 -0.11182938  0.40723833]
[ 0.1867499 1.4743736 -0.02962399]
time cost: 0.053751230239868164

Two records are printed (inspect_dataset calls get_next twice).

get_dataset_size_by_iter(dataset, sess)
time cost:  0.038236379623413086

200
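
from_tensor_slices also accepts a tuple (or dict) of arrays and slices them in lockstep, which is the usual way to pair features with labels. A minimal sketch, with toy feature/label arrays made up for illustration:

features = np.random.randn(4, 3).astype(np.float32)  # 4 samples, 3 features each
labels = np.array([0, 1, 0, 1], dtype=np.int32)

# Both arrays are sliced along their first dimension, so each element
# of the dataset is one (feature_row, label) pair.
paired = tf.data.Dataset.from_tensor_slices((features, labels))
print(paired)  # <TensorSliceDataset shapes: ((3,), ()), types: (tf.float32, tf.int32)>

inspect_dataset(paired, sess)  # prints two (feature, label) pairs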

repeat

Repeats the dataset, i.e. concatenates count copies of it.

dataset = dataset.repeat(count=2)

print(dataset)
<RepeatDataset shapes: (3,), types: tf.float32>
get_dataset_size_by_iter(dataset, sess)
time cost:  0.06870865821838379

400

The iteration now yields twice as many elements as before.
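
One detail not shown above: calling repeat() with no count (or count=None) repeats the dataset indefinitely, so the iterator never raises the end-of-sequence error. This is the usual setup when the number of training steps is controlled elsewhere. A sketch:

# Repeats forever; get_dataset_size_by_iter would never return on this dataset.
endless = tf.data.Dataset.from_tensor_slices(train_dense_feature).repeat()

inspect_dataset(endless, sess)  # still fine: it only pulls two elements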

batch

Splits the dataset into batches.

dataset1 = dataset.batch(5)  # batch size = 5

print(dataset1)
<BatchDataset shapes: (?, 3), types: tf.float32>
get_dataset_size_by_iter(dataset, sess)

time cost:  0.07429909706115723

400

The count above is taken on dataset (the un-batched, repeated dataset), so it is still 400: batching builds a new dataset rather than modifying the original. The ? in dataset1's shape appears because TensorFlow cannot statically guarantee that every batch is full; a batch size that does not divide the element count evenly (for example 400 / 32 = 12.5) would leave a smaller final batch.

inspect_dataset(dataset1, sess)
[[-0.37590072 -0.11182938  0.40723833]
[ 0.1867499 1.4743736 -0.02962399]
[-0.83972436 -0.33842978 -0.2310548 ]
[ 0.31345868 -0.6216803 -0.28837252]
[-0.69582295 -0.7916306 -0.07001938]]
[[ 0.33324835 -0.7349805 -0.22900774]
[ 0.79372746 -0.5650302 -1.0580674 ]
[ 0.5861701 4.1369286 -0.5831493 ]
[-0.20523202 -0.8482807 1.7013708 ]
[ 0.16267557 -0.11182938 0.05143963]]
time cost: 0.015297651290893555
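
If a statically known batch dimension is needed (so the shape reads (5, 3) instead of (?, 3)), later TF 1.x releases let you drop the possibly-smaller final batch. A sketch, assuming drop_remainder is available in the installed version:

# Drop the last batch if it has fewer than 5 elements; the batch
# dimension then becomes statically known.
dataset_full_batches = dataset.batch(5, drop_remainder=True)
print(dataset_full_batches)  # <BatchDataset shapes: (5, 3), types: tf.float32>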

shuffle

shuffle works by first pulling some elements into a buffer and then, at each step, sampling one element from that buffer while refilling it from the input. buffer_size is the maximum size of this buffer.

dataset11 = dataset1.shuffle(buffer_size=1)

inspect_dataset(dataset11, sess)
[[-0.37590072 -0.11182938  0.40723833]
[ 0.1867499 1.4743736 -0.02962399]
[-0.83972436 -0.33842978 -0.2310548 ]
[ 0.31345868 -0.6216803 -0.28837252]
[-0.69582295 -0.7916306 -0.07001938]]
[[ 0.33324835 -0.7349805 -0.22900774]
[ 0.79372746 -0.5650302 -1.0580674 ]
[ 0.5861701 4.1369286 -0.5831493 ]
[-0.20523202 -0.8482807 1.7013708 ]
[ 0.16267557 -0.11182938 0.05143963]]
time cost: 0.016983985900878906

As you can see, with buffer_size=1 there is no shuffling at all. Also note that shuffle reorders whole elements; the data inside an element is never reordered (here each element is already a 5-row batch, because shuffle was applied after batch).

dataset12 = dataset1.shuffle(buffer_size=2)

inspect_dataset(dataset12, sess)
[[ 0.33324835 -0.7349805  -0.22900774]
[ 0.79372746 -0.5650302 -1.0580674 ]
[ 0.5861701 4.1369286 -0.5831493 ]
[-0.20523202 -0.8482807 1.7013708 ]
[ 0.16267557 -0.11182938 0.05143963]]
[[-0.37590072 -0.11182938 0.40723833]
[ 0.1867499 1.4743736 -0.02962399]
[-0.83972436 -0.33842978 -0.2310548 ]
[ 0.31345868 -0.6216803 -0.28837252]
[-0.69582295 -0.7916306 -0.07001938]]
time cost: 0.021335840225219727

With buffer_size=2, the first and second elements come out in swapped order.

Therefore, buffer_size must be set to a reasonably large value; for a fully random shuffle it should be at least as large as the number of elements being shuffled.
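
The buffer mechanics are easiest to see on a tiny dataset of consecutive integers: with a small buffer each output is sampled from a narrow window of the input, so the order stays mostly local, while a buffer at least as large as the dataset can produce any permutation. A sketch:

numbers = np.arange(10, dtype=np.int32)

weak = tf.data.Dataset.from_tensor_slices(numbers).shuffle(buffer_size=2)
strong = tf.data.Dataset.from_tensor_slices(numbers).shuffle(buffer_size=10)

for name, ds in [("buffer_size=2", weak), ("buffer_size=10", strong)]:
    it = ds.make_one_shot_iterator().get_next()
    print(name, [sess.run(it) for _ in range(10)])
# buffer_size=2 stays close to the original 0..9 order;
# buffer_size=10 can produce any permutation.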

prefetch

prefetch fetches elements ahead of time and keeps them in a buffer; buffer_size is the maximum size of that buffer. It is used to overlap data loading with computation, reducing the time the two spend running one after the other.

dataset11 = dataset1.prefetch(buffer_size=1)

inspect_dataset(dataset11, sess)
[[-0.37590072 -0.11182938  0.40723833]
[ 0.1867499 1.4743736 -0.02962399]
[-0.83972436 -0.33842978 -0.2310548 ]
[ 0.31345868 -0.6216803 -0.28837252]
[-0.69582295 -0.7916306 -0.07001938]]
[[ 0.33324835 -0.7349805 -0.22900774]
[ 0.79372746 -0.5650302 -1.0580674 ]
[ 0.5861701 4.1369286 -0.5831493 ]
[-0.20523202 -0.8482807 1.7013708 ]
[ 0.16267557 -0.11182938 0.05143963]]
time cost: 0.01833629608154297
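
prefetch does not change what the pipeline produces, only when elements are prepared, which is why the output above is identical to the un-prefetched one. A common pattern is to put it at the very end of the pipeline so that the next batch is assembled while the model is still computing on the current one. A sketch of a typical ordering:

pipeline = (tf.data.Dataset.from_tensor_slices(train_dense_feature)
            .shuffle(buffer_size=200)   # shuffle individual examples first
            .repeat(count=2)            # then repeat across epochs
            .batch(5)                   # then form batches
            .prefetch(buffer_size=1))   # prepare the next batch in the background

inspect_dataset(pipeline, sess)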