控制流程
1. tf.cond = if else
先看看一个简单的例子:
1 | import tensorflow as tf |
1 | <tf.Tensor 'cond/Merge:0' shape=() dtype=int32> |
1 | with tf.Session() as sess: |
1 | 10 |
也就是说,给tf.cond
分别传入判断条件;”是”的情况下的执行函数true_fn()
;”否”的情况下的执行函数false_fn
。
需要注意的是,true_fn
和false_fn
在构图阶段都会被执行。在run阶段,才会根据实际的条件,取选择具体的数据流向。
尝试着在一个优化问题中,引入条件判断看看:
1 | flag = tf.placeholder(tf.bool) |
上面代码的意思是,goal函数可以是$(2x-3)^2$,也可以是$(3x-3)^2$,由传入的flag决定。看看梯度会怎么计算。
1 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1) |
1 | [(<tf.Tensor 'gradients/AddN:0' shape=(1,) dtype=float32>, <tensorflow.python.ops.variables.Variable object at 0x7f2758386b50>)] |
1 | with tf.Session() as sess: |
1 | flag == true, gradient == 8*x-12, x=[-0.28935], gra=[-14.31480026], gra_true=[-14.31480026] |
也就是说,tensorflow会根据条件判断来构建梯度计算路径,结果和预想的一致。