本篇博客记录如何实现变量复用。
首先介绍变量复用的典型应用场景:
rnn单元,按时间展开后变量是复用的;多gpu工作的时候,可能要统一使用cpu定义的变量。
tf.variable_scope
variable_scope
相当于给变量名前面加一些前缀,避免名称冲突的问题;也可以使得graph在tensorboard中可以更有层次地显示。
变量复用往往和variable_scope
一起使用,当你希望你的变量定义的是不复用的时候,使用:
1 | with tf.variable_scope("scope_name", reuse=False): |
当你希望当前域中的变量是可以复用的时候,使用:
1 | with tf.variable_scope("scope_name", reuse=tf.AUTO_REUSE): |
或者强制复用变量:
1 | with tf.variable_scope("scope_name", reuse=True): |
或者显式地调用:
1 | scope.reuse_variables() |
当然设置完variable_scope只是起点,接下来是获取变量。
tf.get_variable
get_variable
会按照名称去找之前有没有定义好的variable
,有则复用;无则创建。
1 | import tensorflow as tf |