Higher order gradient of tfq layers #285

Open
@refraction-ray

Both the first-order derivative of tfq layers and the higher-order gradient of non-tfq layers work with GradientTape, as follows:

import tensorflow as tf

# Second-order gradient of an ordinary TensorFlow expression.
x = tf.Variable(initial_value=0.2)
with tf.GradientTape() as t:
    with tf.GradientTape() as t2:
        y = x**3
        g = t2.gradient(y, x)
        print(g)
    g2 = t.gradient(g, x)
print(g2)
# First-order gradient of a tfq layer.
import cirq
import sympy as sy
import tensorflow_quantum as tfq

a = sy.Symbol("a")
c = cirq.Circuit()
c.append(cirq.rx(a)(cirq.GridQubit(0,0)))
model = tfq.layers.PQC(c, operators=[cirq.Z(cirq.GridQubit(0,0))])
with tf.GradientTape() as t:
    o = model(tfq.convert_to_tensor([cirq.Circuit()]))[0,0]
    g = t.gradient(o, model.variables)
print(g)

However, the following does not work. It fails with these TensorFlow errors:
InvalidArgumentError: Operation 'cond' has no attr named '_XlaCompile'.
ValueError: Insufficient elements in branch_graphs[0].outputs. Expected: 6 Actual: 5

# Second-order gradient of a tfq layer (this is the case that fails).
a = sy.Symbol("a")
c = cirq.Circuit()
c.append(cirq.rx(a)(cirq.GridQubit(0,0)))
model = tfq.layers.PQC(c, operators=[cirq.Z(cirq.GridQubit(0,0))])
with tf.GradientTape() as t:
    with tf.GradientTape() as t2:
        o = model(tfq.convert_to_tensor([cirq.Circuit()]))[0,0]
        g = t2.gradient(o, model.variables)
        print(g)
    g2 = t.gradient(g, model.variables)
print(g2)

Is my code wrong, or do tfq layers have special issues with higher-order gradients because of the way automatic differentiation is implemented in these layers?
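For reference, here is a minimal sketch of what the second derivative should be for this circuit, computed without tfq. It assumes the layer's gradient follows a two-term parameter-shift rule, which I have not confirmed for this setup. It also uses a hypothetical parameter value a = 0.2, whereas the PQC layer picks its own initial value. The helper names (expectation_z, shift_gradient, shift_second_derivative) are made up for illustration and are not part of any library.

```python
import math


def expectation_z(a: float) -> float:
    """<Z> after RX(a) acts on |0>, from the single-qubit state vector."""
    amp0 = complex(math.cos(a / 2), 0.0)   # amplitude of |0>
    amp1 = complex(0.0, -math.sin(a / 2))  # amplitude of |1>
    return abs(amp0) ** 2 - abs(amp1) ** 2  # equals cos(a)


def shift_gradient(f, a: float, shift: float = math.pi / 2) -> float:
    """First derivative via the two-term parameter-shift rule."""
    return (f(a + shift) - f(a - shift)) / 2.0


def shift_second_derivative(f, a: float) -> float:
    """Second derivative: apply the shift rule to the shifted gradient."""
    return shift_gradient(lambda t: shift_gradient(f, t), a)


a = 0.2  # hypothetical value, chosen only for this illustration
first = shift_gradient(expectation_z, a)
second = shift_second_derivative(expectation_z, a)

# Compare against the closed forms d/da cos(a) and d^2/da^2 cos(a).
print(first, -math.sin(a))
print(second, -math.cos(a))
```

If nesting tapes around the layer keeps failing, evaluating the layer's expectation at shifted parameter values in this way might serve as a manual workaround. This only covers a single rotation gate with one symbol.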

Labels: status/going-stale (has been inactive for a while and may be closed soon)
