Skip to content

Commit e8b7025

Browse files
Googlertensorflower-gardener
Googler
authored andcommitted
Passes kernel.num_tasks to tf.linalg.LinearOperatorIdentity in multitask_gaussian_process_regression_model.py
PiperOrigin-RevId: 737695798
1 parent 2297429 commit e8b7025

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _scale_from_precomputed(precomputed_cholesky, kernel):
158158
for param in params['kronecker_orths']:
159159
if 'identity' in param:
160160
ops.append(tf.linalg.LinearOperatorIdentity(
161-
param['identity'], dtype=params['diag'].dtype))
161+
kernel.num_tasks, dtype=params['diag'].dtype))
162162
elif 'unitary' in param:
163163
ops.append(
164164
linear_operator_unitary.LinearOperatorUnitary(param['unitary'])
@@ -187,7 +187,7 @@ def _precomputed_from_scale(observation_scale):
187187
if isinstance(observation_scale, tf.linalg.LinearOperatorComposition):
188188
kronecker_op, diag_op = observation_scale.operators
189189
kronecker_orths = [
190-
{'identity': k.domain_dimension_tensor()}
190+
{'identity': None}
191191
if isinstance(k, tf.linalg.LinearOperatorIdentity)
192192
else {'unitary': k.matrix} for k in kronecker_op.operators]
193193
return {'separable': {'kronecker_orths': kronecker_orths,

0 commit comments

Comments
 (0)