
TransformedVariable does not follow array convention for equality (==) comparison #1241

Open
@st--

Description


A tfp.util.TransformedVariable generally behaves like any other tensor-like object. For example, ordering comparisons such as less-than-or-equal (<=) and greater-than-or-equal (>=) work element-wise, as expected:

import tensorflow as tf
import tensorflow_probability as tfp
var = tf.Variable(1.0)
assert var <= 1.0 and var >= 1.0 and var == 1.0  # passes
t_var = tfp.util.TransformedVariable(1.0, tfp.bijectors.Identity())
assert t_var <= 1.0 and t_var >= 1.0  # passes

However, equality comparison does not behave likewise, and the following fails:

assert t_var == 1.0  # FAILS

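This is unexpected behaviour that can lead to very subtle downstream bugs: rather than comparing element-wise, t_var == 1.0 appears to fall back to Python's default identity comparison and simply returns False.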

This seems to be caused by __eq__ and __ne__ being explicitly removed from the set of overloaded operators in

operators.difference_update({'__eq__', '__ne__'})
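The effect of that line can be reproduced without TensorFlow. The sketch below is a toy model, not TFP's actual code: make_wrapper_class, Wrapper and forward are hypothetical names, and it assumes the wrapper overloads each operator by materializing a concrete value and applying the matching function from Python's operator module. With __eq__ and __ne__ excluded, == has nothing to forward to and Python falls back to identity.

```python
import operator


def make_wrapper_class(exclude_eq):
    """Build a hypothetical wrapper whose comparison operators forward to a
    materialized value. Loosely mimics operator overloading on a deferred
    variable; not TFP's implementation."""

    class Wrapper:
        def __init__(self, value, transform=lambda x: x):
            self._value = value
            self._transform = transform

        def _materialize(self):
            # Stand-in for converting the wrapper to a concrete tensor.
            return self._transform(self._value)

    def forward(name):
        op = getattr(operator, name)  # e.g. operator.__le__

        def method(self, other):
            return op(self._materialize(), other)

        method.__name__ = name
        return method

    operators = {'__lt__', '__le__', '__gt__', '__ge__', '__eq__', '__ne__'}
    if exclude_eq:
        # Mirrors the difference_update call quoted above.
        operators.difference_update({'__eq__', '__ne__'})
    for name in operators:
        setattr(Wrapper, name, forward(name))
    return Wrapper


Current = make_wrapper_class(exclude_eq=True)
Proposed = make_wrapper_class(exclude_eq=False)

w = Current(1.0)
print(w <= 1.0, w >= 1.0)  # True True: forwarded to float comparisons
print(w == 1.0)            # False: no forwarded __eq__, so identity is used

p = Proposed(1.0)
print(p == 1.0)            # True: forwarded like the other operators
```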

What was the motivation for excluding equality comparisons? Would there be any objections to a PR that removes that line, so that TFP TransformedVariables behave like tf.Variables for equality as well?
