十二. tensorflow Variable复用/get_variable

本篇博客记录如何实现变量复用。

首先介绍变量复用的典型应用场景:

rnn单元,按时间展开后变量是复用的;多gpu工作的时候,可能要统一使用cpu定义的变量。

tf.variable_scope

variable_scope相当于给变量名前面加一些前缀,避免名称冲突的问题;也可以使得graph在tensorboard中可以更有层次地显示。
变量复用往往和variable_scope一起使用,当你希望你的变量定义的是不复用的时候,使用:

1
with tf.variable_scope("scope_name", reuse=False):

当你希望当前域中的变量是可以复用的时候,使用:

1
with tf.variable_scope("scope_name", reuse=tf.AUTO_REUSE):

或者强制复用变量:

1
with tf.variable_scope("scope_name", reuse=True):

或者显式地调用:

1
scope.reuse_variables()

当然设置完variable_scope只是起点,接下来是获取变量。

tf.get_variable

get_variable会按照名称去找之前有没有定义好的variable,有则复用;无则创建。

1
2
3
4
5
6
7
8
9
10
import tensorflow as tf

# 定义scope
with tf.variable_scope("foo"):
# 第一次创建变量
v = tf.get_variable("v", [1])
tf.get_variable_scope().reuse_variables()
# 第二次直接复用
v1 = tf.get_variable("v", [1])
assert v1 is v