@@ -105,7 +105,18 @@ def backward(output_):
105105 index = output_ .optrace [id (output_ )][1 ]
106106 arg = output_ .optrace [id (output_ )][2 ]
107107
108- # check if upstream index exists
108+ # If the last operation is indexing, there is no downstream map op to
109+ # populate Jacobian values. In this case, the Jacobian block values are
110+ # identity matrices placed at the indexed columns.
111+ if not hasattr (output_ , 'jactrace' ):
112+ if output_ .ndim == 1 :
113+ eye_blocks = torch .ones ((output_ .shape [0 ], 1 , 1 ), device = output_ .device , dtype = output_ .dtype )
114+ else :
115+ block_dim = output_ .shape [- 1 ]
116+ eye = torch .eye (block_dim , device = output_ .device , dtype = output_ .dtype )
117+ eye_blocks = eye .unsqueeze (0 ).repeat (output_ .shape [0 ], 1 , 1 )
118+ output_ .jactrace = (None , eye_blocks )
119+
109120 if type (output_ .jactrace ) is tuple :
110121 if output_ .jactrace [0 ] is not None :
111122 upstream_index = output_ .jactrace [0 ]
@@ -124,7 +135,7 @@ def backward(output_):
124135
125136
126137def jacobian (output , params ):
127- assert output .optrace [id (output )][0 ] == 'map' , "The last operation in compute graph being indexing transform is not meaningful "
138+ assert output .optrace [id (output )][0 ] in ( 'map' , 'index' ), "Unsupported last operation in compute graph"
128139 backward (output )
129140 res = []
130141 for param in params :
0 commit comments