File tree 2 files changed +13
-9
lines changed
2 files changed +13
-9
lines changed Original file line number Diff line number Diff line change @@ -5192,6 +5192,6 @@ def update_indices(i, _indices):
5192
5192
@op (torch .ops .aten .linear )
5193
5193
def linear (input , weight , bias = None ):
5194
5194
res = input @ jnp .transpose (weight )
5195
- if bias :
5195
+ if bias is not None :
5196
5196
res += bias
5197
5197
return res
Original file line number Diff line number Diff line change @@ -295,16 +295,20 @@ def get_as_jax_device(self, device: Any):
295
295
296
296
if isinstance (device , torch .device ):
297
297
device = str (device )
298
- if (self .config .use_torch_native_for_cpu_tensor and
299
- not device .startswith ('jax' ) and not device .startswith ('cuda' )):
300
- return None
301
298
302
- if not self .config .treat_cuda_as_jax_device and device .startswith ('cuda' ):
303
- return None
304
-
305
- if device == 'cpu' :
299
+ print ('device ' , device )
300
+ if (not self .config .use_torch_native_for_cpu_tensor and
301
+ device .startswith ('cpu' )):
306
302
return jax .devices ('cpu' )[0 ]
307
- return jax .local_devices ()[0 ]
303
+
304
+ if self .config .treat_cuda_as_jax_device and device .startswith ('cuda' ):
305
+ return jax .local_devices ()[0 ]
306
+
307
+ if device .startswith ('jax' ):
308
+ return jax .local_devices ()[0 ]
309
+
310
+ return None # fallback to torch
311
+
308
312
309
313
310
314
def load_ops (self ):
You can’t perform that action at this time.
0 commit comments