forked from jax-ml/jax
Test GitHub app #48
Open
nitins17 wants to merge 611 commits into main from srnitin/github-app
Conversation
…ommit/2930d4d5457abd98a4dd6d682a570f3c2b771b03 PiperOrigin-RevId: 821994830
…mplicit before. This is a prerequisite to making Pallas/Mosaic GPU the default lowering path for GPU. PiperOrigin-RevId: 822071203
cl/821888942 simplified the `Divides` constraint but accidentally broke the merge semantics (likely due to a bad rebase). This change makes the necessary fixes:
- Merging two `Divides` constraints of different lengths yields a new constraint with the smaller length.
- Restored a deleted test that verifies the behavior above.
- Enhanced the subview test to cover more cases of dynamic offsets.
- Removed a couple of TODOs that are no longer relevant.
PiperOrigin-RevId: 822105192
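For illustration only, a minimal sketch of the restored merge rule, using a hypothetical `Divides` stand-in rather than the real Mosaic GPU constraint class:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Divides:
  # Hypothetical stand-in: divisibility requirements on the trailing dims.
  tiling: tuple[int, ...]

def merge(a: Divides, b: Divides) -> Divides:
  # Merging constraints of different lengths yields the one with the
  # smaller length (the weaker requirement), per the restored semantics.
  return a if len(a.tiling) <= len(b.tiling) else b
```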
This will be necessary to support collective kernels, which need to zero-initialize the semaphores. The old pattern matching code was more fragile than the current implementation and didn't properly account for all the possible ways the arguments and results can be produced from the custom call. PiperOrigin-RevId: 822129973
Fix typo in docstring regarding with_sharding_constraint.
…_names PiperOrigin-RevId: 822144673
PiperOrigin-RevId: 822146746
Fixes an outstanding TODO. PiperOrigin-RevId: 822164120
…_documentation PiperOrigin-RevId: 822194566
…neously). PiperOrigin-RevId: 822202176
PiperOrigin-RevId: 822203593
The test checks that `jax.device_put` from a TPU/GPU sharding to a CPU sharding raises a helpful error, but had neglected the case where the TPU/GPU sharding is fully addressable but the CPU sharding isn't. PiperOrigin-RevId: 822216589
PiperOrigin-RevId: 822246710
…-overlapped-sharding PiperOrigin-RevId: 822249041
…duction to expose more ILP. When in shape-invariant mode, fall back to a sequential reduction with a single partial accumulator. PiperOrigin-RevId: 822252526
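For context, a schematic contrast of the two strategies in plain JAX (not the actual Mosaic/TPU lowering): a pairwise tree reduction yields independent partial sums, while a single-accumulator loop serializes every add:

```python
import jax
import jax.numpy as jnp

def tree_sum(x):
  # Pairwise (tree) reduction: log2(n) rounds of independent adds,
  # which exposes more instruction-level parallelism.
  while x.shape[0] > 1:
    half = x.shape[0] // 2
    x = x[:half] + x[half:2 * half]  # assumes a power-of-two length
  return x[0]

def sequential_sum(x):
  # Single partial accumulator: every add depends on the previous one,
  # but it works for lengths that are not statically known powers of two.
  def body(i, acc):
    return acc + x[i]
  return jax.lax.fori_loop(0, x.shape[0], body, jnp.zeros((), x.dtype))

x = jnp.arange(64, dtype=jnp.float32)
assert jnp.allclose(tree_sum(x), sequential_sum(x))
```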
In Pallas, pltpu.repeat is semantically equivalent to jnp.tile, meaning that it repeats the entire input array `repeat` times along `axis`.
- Changes the MLIR lowering rule to lower to jnp.tile for consistency with Pallas behavior.
- Adds a def_impl rule so that pltpu.repeat can be used outside of a jax.jit context.
PiperOrigin-RevId: 822269696
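For illustration, a small sketch of the intended semantics in terms of `jnp.tile` versus `jnp.repeat` (plain `jnp` code, no Pallas kernel context):

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

# pltpu.repeat(x, 2, axis=0) is meant to behave like jnp.tile: the whole
# array is repeated 2 times along axis 0 ...
tiled = jnp.tile(x, (2, 1))             # [[0,1,2],[3,4,5],[0,1,2],[3,4,5]]
# ... unlike jnp.repeat, which repeats each row individually.
per_element = jnp.repeat(x, 2, axis=0)  # [[0,1,2],[0,1,2],[3,4,5],[3,4,5]]
```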
PiperOrigin-RevId: 822387107
PiperOrigin-RevId: 822401100
In some environments only SLURM_JOB_ID might be set, e.g. when using hooks for SSH to a node with an existing allocation. This causes a false positive in the detection and later a `KeyError` on e.g. `SLURM_LOCALID`.
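A minimal sketch of the kind of guard this implies; the helper and variable names below are hypothetical, not the actual cluster-detection code in `jax.distributed`:

```python
import os

# Hypothetical guard: only treat this as a SLURM multi-process environment
# when all of the variables read later are present, not just SLURM_JOB_ID
# (which may be set under an SSH hook into an existing allocation).
_REQUIRED_SLURM_VARS = ("SLURM_JOB_ID", "SLURM_NTASKS",
                        "SLURM_PROCID", "SLURM_LOCALID")

def looks_like_slurm_multiprocess() -> bool:
  return all(v in os.environ for v in _REQUIRED_SLURM_VARS)
```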
…ommit/ccdbd127164aa34c9cf47a3c3c4a9f54eac487f0 PiperOrigin-RevId: 822468441
PiperOrigin-RevId: 822501486
The change revolves around how vmap over an explicitly sharded axis should interact with shard_map. If vmap maps over an explicitly sharded axis, i.e. `vmap(f)(arr: f32[8@i, 4])`, then `f` no longer has access to the `i` mesh axis.
This should be true for any `f`, but for now let's assume `f = lambda x: shard_map(...)`. With that context, consider this example:
```python
@jtu.with_explicit_mesh((2, 2), ('i', 'j'))
def test_explicit_vmap_grad_shmap(self, mesh):
  np_inp = np.arange(6 * 24, dtype=np.float32).reshape(6, 24)
  arr = jax.device_put(np_inp, P('i', None))

  def g(x):
    out = jax.shard_map(jnp.cos, in_specs=P('j'), out_specs=P('j'))(x)
    return out.sum()

  out = jax.jit(jax.vmap(jax.grad(g)))(arr)
  self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
```
**What's the change?**
* **Mesh and Axis names inference change:**
Since `axis_names` is not passed to `jax.shard_map`, this shard_map should go Manual over all available mesh axis names. But which axis names are available?
Since the shard_map is executing under a vmap, `i` is not available anymore! So the shard_map will only be Manual over `j` and Explicit over `i`.
If you were to do `vmap(lambda x: jax.shard_map(..., axis_names={'i', 'j'}))(arr: f32[8@i, 4])`, this would be an error since you can't go manual over `i` when passing `axis_names` argument to shard_map.
We achieve this by adding `explicit_mesh_axis_names` to `AxisEnv` (similar to `spmd_axis_names`), entering that context when we trace vmap, and reading that value when assigning `axis_names` (if not present) to shard_map.
Ideally we would just change the mesh when starting vmap tracing so that the mesh itself doesn't contain the vmapped axis names (I'll try that in a follow-up).
* **shard_map's batching rule change:**
This explanation is dependent on the above code example.
**Before:**
In shard_map's batching rule, the new_in_specs and new_out_specs would have been `P('i', 'j')`. The `i` comes from `axis_data.explicit_mesh_axis`. This means we were going Manual over `i`, since the `in_specs` of shard_map only talk about Manual mesh axes. This seems a bit weird.
**After:**
In shard_map's batching rule, the new_in_specs and new_out_specs would be `P(None, 'j')`. This makes sense because the explicit axis dims are tracked on the inputs. So concretely the shard_map would look like: `shmap(f, in_specs=P(None, 'j'), out_specs=P(None, 'j'), axis_names={'j'})(in_vals: f32[6@i, 24])`. Given that vmap is tracking mesh axis `i`, we don't have to separately add it to `in_specs` at all. This results in a partially manual mesh, but that's perfectly fine!
This is an internal semantics change but it makes perfect sense to me to do this.
PiperOrigin-RevId: 822563882
PiperOrigin-RevId: 822607889
The error for a shape of length 1 or less is now only raised when `in_axis` is `-2`. This allows `_compute_fans` to be used with 1D/0D shapes when custom axes are specified. PiperOrigin-RevId: 822616111
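For example, the variance-scaling initializers call `_compute_fans` under the hood; a sketch of the now-allowed case, assuming the default-axis error still applies only when `in_axis` is `-2`:

```python
import jax
import jax.numpy as jnp

# With explicitly chosen fan axes, a 1D shape no longer raises.
init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode="fan_in", distribution="truncated_normal",
    in_axis=0, out_axis=0)
w = init(jax.random.key(0), (16,), jnp.float32)
```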
PiperOrigin-RevId: 822637293
`k` does not affect the shape, so it does not need to be static. Fixes jax-ml#32994 PiperOrigin-RevId: 826072062
This change updates the Bazel version used in TensorFlow, JAX, and XLA projects from 7.4.1 to 7.7.0 in `.bazelversion` files and build scripts. PiperOrigin-RevId: 826075658
PiperOrigin-RevId: 826115707
…f to infer the input type (and other metadata like layouts) at which to trace f. Because duck_stuff is flexible, it could comprise actual values, but we don't need to hold onto those values. Yet `traced` currently does hold references to any such values! Create an internal `LowerType` that can be passed to `_resolve_and_lower` via `Traced` and the normal jit non-AOT path instead of the actual arguments. We still hold onto consts in `Traced`, though. Currently for HiJAX we hold onto real arguments, but we'll fix that in a follow-up. PiperOrigin-RevId: 826122272
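For reference, the public AOT path this is about (a sketch; the `LowerType` plumbing itself is internal and not visible from here):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

x = jnp.ones((1024, 1024))

traced = jax.jit(f).trace(x)  # infers the input type/metadata needed for lowering
lowered = traced.lower()      # the change aims to avoid `traced` holding `x` itself
compiled = lowered.compile()
y = compiled(x)
```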
…eing held in Traced without converting it to MetaTys. This change makes MetaTys work with HiTypes. Co-authored-by: Matthew Johnson <[email protected]> PiperOrigin-RevId: 826153440
This change introduces an optional `name` string argument to the `pallas_call` primitive and its associated lowering and transformation rules. The `name` is currently passed through but not yet used by the backend-specific lowering implementations for TPU, GPU, and Triton. A follow-up CL will use it to extend the name stack on TPU. PiperOrigin-RevId: 826174027
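A sketch of passing the new argument through the public `pl.pallas_call` wrapper, assuming the keyword is exposed there as described (it is only threaded through for now, not consumed by the backends):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.ones((8, 128), jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    name="add_one",  # optional kernel name, forwarded to the lowering rules
)(x)
```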