Trainable variables created by tfp.experimental.vi.util.build_trainable_linear_operator_block are lost after the resulting bijector is wrapped in a tfb.Chain #1997

Open
@bwalker1

Description

Creating a linear operator via tfp.experimental.vi.util.build_trainable_linear_operator_block and passing it to tfb.ScaleMatvecLinearOperatorBlock produces a bijector whose trainable_variables property is populated as expected. However, if that bijector is then wrapped in a tfb.Chain, the chain's trainable_variables comes back empty, so the variables are no longer discoverable by reflection. This does not happen when the linear operator is constructed manually. My usage of build_trainable_linear_operator_block follows the Variational_Inference_and_Joint_Distributions tutorial.

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
print(tf.__version__)
# 2.18.0
print(tfp.__version__)
# 0.25.0


# Broken example (should print True)
operators = ((tf.linalg.LinearOperatorDiag,),)
block_tril_linop = tfp.experimental.vi.util.build_trainable_linear_operator_block(
    operators, (1,)
)
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(block_tril_linop)
assert len(scale_bijector.trainable_variables) > 0
c = tfb.Chain([scale_bijector])

print(len(c.trainable_variables) > 0)
# False
# Working example
LO = tf.linalg.LinearOperatorBlockDiag([tf.linalg.LinearOperatorDiag(tf.Variable([1.0]))])
scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(LO)
assert len(scale_bijector.trainable_variables) > 0
c = tfb.Chain([scale_bijector])

print(len(c.trainable_variables) > 0)
# True
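
Until the tracking issue is resolved, one possible workaround (a sketch under assumptions, not a confirmed fix; TrackedChain is a hypothetical helper, not a TFP API) is to hold the chain's bijectors on a tf.Module of your own. tf.Module tracks tf.Variables reachable through its attributes, including sub-Modules stored in lists, so the variables stay discoverable even if tfb.Chain loses them:

```python
import tensorflow as tf

# Sketch of a workaround (assumption, not a confirmed fix): tf.Module tracks
# tf.Variables reachable through its attributes, including sub-Modules stored
# in lists, so holding the bijectors on a Module of our own keeps their
# variables discoverable even when tfb.Chain fails to surface them.
class TrackedChain(tf.Module):  # hypothetical helper name
    def __init__(self, bijectors):
        super().__init__()
        # TFP bijectors subclass tf.Module, so storing them in a list
        # attribute lets tf.Module's attribute traversal find their variables.
        self.bijectors = list(bijectors)

# Stand-in for a bijector that owns a trainable variable (tfb.Bijector is a
# tf.Module subclass, so a plain tf.Module models the tracking behaviour):
class VariableHolder(tf.Module):
    def __init__(self):
        super().__init__()
        self.scale = tf.Variable([1.0])

tracked = TrackedChain([VariableHolder()])
print(len(tracked.trainable_variables) > 0)
# True
```

In practice you would store the real tfb.Chain and its component bijectors side by side on such a Module and pass tracked.trainable_variables to the optimizer.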
