9. TensorFlow Control Flow

Control flow

1. tf.cond = if/else

Let's start with a simple example:

```python
import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

def f1(): return tf.multiply(x, y)
def f2(): return tf.add(x, y)

r = tf.cond(tf.less(x, y), f1, f2)
```

```python
r
```
```
<tf.Tensor 'cond/Merge:0' shape=() dtype=int32>
```
```python
with tf.Session() as sess:
    print(sess.run(r))
```
```
10
```

In other words, tf.cond takes a predicate, a function true_fn() to execute when the condition holds, and a function false_fn() to execute when it does not.

Note that both true_fn and false_fn are called during graph construction. Only in the run phase, based on the actual value of the condition, does TensorFlow choose which branch's data actually flows onward.
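To make this "build both, run one" behavior concrete, here is a toy pure-Python sketch of a cond-like node (my own illustration, not TensorFlow's actual implementation): both branch functions are invoked when the node is constructed, and the choice between their results is deferred to run time.

```python
class Cond:
    """Toy graph node: builds both branches up front, picks one when run."""

    def __init__(self, pred_fn, true_fn, false_fn):
        # Both branch functions are called NOW, at "graph construction" time,
        # so any side effects they have happen unconditionally.
        self.pred_fn = pred_fn
        self.true_branch = true_fn()
        self.false_branch = false_fn()

    def run(self):
        # Only at "run" time do we choose which branch's value flows onward.
        return self.true_branch if self.pred_fn() else self.false_branch


built = []
node = Cond(lambda: 2 < 5,
            lambda: built.append("true_fn") or 10,
            lambda: built.append("false_fn") or 7)
print(built)       # ['true_fn', 'false_fn'] -- both branches were built
print(node.run())  # 10 -- only the selected branch's value is returned
```

This is why, in real tf.cond code, side effects inside true_fn/false_fn (such as variable creation) happen for both branches even though only one branch's value is produced.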

Let's try introducing a conditional into an optimization problem:

```python
flag = tf.placeholder(tf.bool)
x = tf.Variable(tf.truncated_normal([1]))
r = tf.cond(flag, true_fn=lambda: tf.multiply(x, 2), false_fn=lambda: tf.multiply(x, 3))
goal = tf.pow(r - 3, 2)
```

The code above means the objective goal can be either $(2x-3)^2$ or $(3x-3)^2$, depending on the flag that is fed in. Let's see how the gradient is computed.

```python
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

gra_and_var = optimizer.compute_gradients(goal)
print(gra_and_var)
print("x's gradient to goal:", gra_and_var[0][0])
print("x var itself:", gra_and_var[0][1])
```
```
[(<tf.Tensor 'gradients/AddN:0' shape=(1,) dtype=float32>, <tensorflow.python.ops.variables.Variable object at 0x7f2758386b50>)]
x's gradient to goal: Tensor("gradients/AddN:0", shape=(1,), dtype=float32)
x var itself: Tensor("Variable/read:0", shape=(1,), dtype=float32)
```
```python
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x_, gra_var = sess.run([x, gra_and_var], {flag: True})
    print("flag == true, gradient == 8*x-12, x={}, gra={}, gra_true={}".format(x_, gra_var[0][0], 8*x_-12))
    x_, gra_var = sess.run([x, gra_and_var], {flag: False})
    print("flag == false, gradient == 18*x-18, x={}, gra={}, gra_true={}".format(x_, gra_var[0][0], 18*x_-18))
```
```
flag == true, gradient == 8*x-12, x=[-0.28935], gra=[-14.31480026], gra_true=[-14.31480026]
flag == false, gradient == 18*x-18, x=[-0.28935], gra=[-23.20829964], gra_true=[-23.20829964]
```
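As a sanity check on the two analytic gradients, $\frac{d}{dx}(2x-3)^2 = 8x-12$ and $\frac{d}{dx}(3x-3)^2 = 18x-18$, we can compare them against a central finite difference in plain Python (no TensorFlow needed; the helper below is just for this check):

```python
def goal_true(x):
    """The objective when flag == True: (2x - 3)^2."""
    return (2 * x - 3) ** 2

def goal_false(x):
    """The objective when flag == False: (3x - 3)^2."""
    return (3 * x - 3) ** 2

def numeric_grad(f, x, eps=1e-6):
    """Central finite-difference approximation of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 0.5
print(numeric_grad(goal_true, x), 8 * x - 12)    # both approximately -8.0
print(numeric_grad(goal_false, x), 18 * x - 18)  # both approximately -9.0
```

The numerical estimates agree with the analytic formulas, matching what TensorFlow produced for each branch.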

In other words, TensorFlow builds the gradient computation path according to the conditional, and the results match our expectations.