Description
Creating a linear operator with tfp.experimental.vi.util.build_trainable_linear_operator_block and passing it to tfb.ScaleMatvecLinearOperatorBlock produces a bijector whose trainable_variables are populated as expected. However, once that bijector is wrapped in a tfb.Chain, the chain's variable tracking no longer finds those variables. The problem does not occur when the linear operator is constructed manually. My usage of tfp.experimental.vi.util.build_trainable_linear_operator_block follows the tutorial Variational_Inference_and_Joint_Distributions.
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
print(tf.__version__)
# 2.18.0
print(tfp.__version__)
# 0.25.0
# Broken example (should print True)
operators = ((tf.linalg.LinearOperatorDiag,),)
block_tril_linop = tfp.experimental.vi.util.build_trainable_linear_operator_block(
    operators, (1,)
)
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(block_tril_linop)
assert len(scale_bijector.trainable_variables) > 0
c = tfb.Chain([scale_bijector])
print(len(c.trainable_variables) > 0)
# False
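A possible workaround (a sketch only, not a fix for the underlying tracking issue): the variables are still reachable on the component bijector, so they can be gathered manually via the chain's bijectors property.
# Hedged workaround sketch: collect variables directly from the chain's
# component bijectors instead of relying on c.trainable_variables.
manual_vars = [v for b in c.bijectors for v in b.trainable_variables]
print(len(manual_vars) > 0)
# Expected: True, since scale_bijector itself still reports its variables.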
# Working example
LO = tf.linalg.LinearOperatorBlockDiag([tf.linalg.LinearOperatorDiag(tf.Variable([1.0]))])
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(LO)
assert len(scale_bijector.trainable_variables) > 0
c = tfb.Chain([scale_bijector])
print(len(c.trainable_variables) > 0)
# True
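Another workaround I would expect to help (an untested sketch, assuming tf.Module attribute tracking picks up explicit references) is to keep references to the component bijectors alongside the chain in a tf.Module:
# Hedged workaround sketch: keep explicit references to the component
# bijectors so tf.Module's attribute tracking finds their variables even
# if tfb.Chain does not propagate them.
class TrackedChain(tf.Module):
    def __init__(self, bijectors, name=None):
        super().__init__(name=name)
        self._components = list(bijectors)  # tracked by tf.Module
        self.chain = tfb.Chain(list(bijectors))

tracked = TrackedChain([scale_bijector])
print(len(tracked.trainable_variables) > 0)
# I would expect True here, including for the bijector built from
# build_trainable_linear_operator_block in the broken example above.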