@@ -1568,21 +1568,21 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
15681568 'Please see the jax.Array migration guide for more information '
15691569 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
15701570 f'Got arg shape: { arg .shape } , arg value: { arg } ' )
1571- if not is_unspecified (arg_s ):
1571+ if not isinstance (arg_s , UnspecifiedValue ):
15721572 # jax.jit does not allow resharding across different memory kinds even
15731573 # if the argument is uncommitted. Use jax.device_put for those cases,
15741574 # either outside or inside jax.jit.
15751575 if pjit_in_s .memory_kind != arg_s .memory_kind : # type: ignore
15761576 raise ValueError (
15771577 'Memory kinds passed to jax.jit does not match memory kind on the'
15781578 f' respective arg. Got pjit memory kind: { pjit_in_s .memory_kind } , ' # type: ignore
1579- f'arg memory kind: { arg_s .memory_kind } for ' # pytype: disable=attribute-error
1579+ f'arg memory kind: { arg_s .memory_kind } for '
15801580 f'arg shape: { shaped_abstractify (arg ).str_short ()} ' )
15811581 if (committed and
15821582 not isinstance (arg_s , PmapSharding ) and
15831583 not op_shardings .are_op_shardings_equal (
15841584 pjit_in_s ._to_xla_hlo_sharding (arg .ndim ), # type: ignore
1585- arg_s ._to_xla_hlo_sharding (arg .ndim ))): # type: ignore
1585+ arg_s ._to_xla_hlo_sharding (arg .ndim ))):
15861586 raise ValueError ('Sharding passed to pjit does not match the sharding '
15871587 'on the respective arg. '
15881588 f'Got pjit sharding: { pjit_in_s } ,\n '
0 commit comments