Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Jan 24, 2025
1 parent 6822925 commit 3f7c67a
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/torch_xla2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ jobs:
working-directory: torchax
shell: bash
run: |
pytest test/
pytest --forked test/
XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/
1 change: 1 addition & 0 deletions torchax/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ absl-py
immutabledict
pytest
pytest-xdist
pytest-forked
sentencepiece
expecttest
optax
2 changes: 1 addition & 1 deletion torchax/torchax/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def tolist(self):
return self._elem.tolist()

def shard_(self, sharding):
self.apply_(jax.lax.with_sharding_constraint, sharding)
self.apply_jax_(jax.lax.with_sharding_constraint, sharding)


def debug_accuracy(func, args, kwargs, current_output):
Expand Down

0 comments on commit 3f7c67a

Please sign in to comment.