Conversation


@nitins17 nitins17 commented Oct 9, 2025

No description provided.

Google-ML-Automation and others added 30 commits October 21, 2025 01:05
…mplicit before.

This is a pre-requisite to making Pallas/Mosaic GPU the default lowering path
for GPU.

PiperOrigin-RevId: 822071203
cl/821888942 simplified the `Divides` constraint but accidentally broke the merge semantics (likely due to a bad rebase). This change makes the necessary fixes:
- Merging two `Divides` constraints of different lengths yields a new constraint with the smaller length (see the sketch below).
- Restored a deleted test that verifies the behavior above.
- Enhanced the subview test to cover more cases of dynamic offsets.
- Removed a couple of TODOs that are no longer relevant.
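
One plausible reading of that merge rule, as a hypothetical sketch (`Divides`, its `tiling` field, and `merge` are illustrative stand-ins, not the actual Mosaic GPU constraint classes):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Divides:
    # Trailing tile sizes the operand's shape must be divisible by.
    tiling: tuple[int, ...]

def merge(a: Divides, b: Divides) -> Divides:
    # Per the rule above, merging constraints of different lengths yields
    # a constraint with the smaller length; here the shorter one wins.
    return a if len(a.tiling) <= len(b.tiling) else b

assert merge(Divides((8, 128)), Divides((128,))) == Divides((128,))
```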

PiperOrigin-RevId: 822105192
This will be necessary to support collective kernels, which need to zero-initialize the semaphores.
The old pattern matching code was more fragile than the current implementation and didn't properly
account for all the possible ways the arguments and results can be produced from the custom call.

PiperOrigin-RevId: 822129973
Fix typo in docstring regarding `with_sharding_constraint`.
Fixes an outstanding TODO.

PiperOrigin-RevId: 822164120
…_documentation

PiperOrigin-RevId: 822194566
The test checks that `jax.device_put` from a TPU/GPU sharding to a CPU sharding raises a helpful error, but it neglected the case where the TPU/GPU sharding is fully addressable while the CPU sharding isn't.

PiperOrigin-RevId: 822216589
…-overlapped-sharding

PiperOrigin-RevId: 822249041
…duction to expose more ILP. When in shape-invariant mode, fall back to a sequential reduction with a single partial accumulator.
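
For illustration, here is a generic NumPy sketch of the multiple-partial-accumulator idea (illustrative only, not the actual lowering): independent accumulators break the serial dependency chain of a reduction, exposing instruction-level parallelism.

```python
import numpy as np

def reduce_with_partials(x: np.ndarray, n_acc: int = 4):
    # Iterations touching different accumulators carry no dependence on
    # each other, so the hardware can execute them in parallel.
    accs = np.zeros(n_acc, dtype=x.dtype)
    for i in range(0, x.size, n_acc):
        chunk = x[i:i + n_acc]
        accs[:chunk.size] += chunk
    return accs.sum()  # combine the partials once at the end

def reduce_sequential(x: np.ndarray):
    # The shape-invariant fallback: a single accumulator, where every
    # iteration depends on the result of the previous one.
    acc = x.dtype.type(0)
    for v in x:
        acc += v
    return acc
```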

PiperOrigin-RevId: 822252526
In Pallas, `pltpu.repeat` is semantically equivalent to `jnp.tile`: it repeats the entire input array `repeat` times along `axis` (sketched below).
- Changes the MLIR lowering rule to lower to jnp.tile for consistency with Pallas behavior.
- Adds a def_impl rule so that pltpu.repeat can be used outside of a jax.jit context.
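
A minimal sketch of these semantics (illustrative only, assuming the `pltpu.repeat(x, repeats, axis)` calling convention; this is not the lowering itself):

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6).reshape(2, 3)

# Repeating the whole array twice along axis 0, as pltpu.repeat(x, 2, 0)
# does inside a Pallas TPU kernel, matches jnp.tile:
np.testing.assert_array_equal(jnp.tile(x, (2, 1)),
                              jnp.concatenate([x, x], axis=0))

# ...whereas jnp.repeat duplicates each element along the axis instead:
np.testing.assert_array_equal(jnp.repeat(x, 2, axis=0),
                              jnp.stack([x[0], x[0], x[1], x[1]]))
```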

PiperOrigin-RevId: 822269696
In some environments, only `SLURM_JOB_ID` might be set, e.g. when using SSH hooks to connect to a node with an existing allocation.

This causes a false positive in the SLURM detection and a later `KeyError` on e.g. `SLURM_LOCALID`.
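A hypothetical sketch of the stricter check (the function name and exact choice of variables are illustrative, not JAX's actual detector):

```python
import os

def running_under_slurm_step() -> bool:
    # SLURM_JOB_ID alone can be inherited by an SSH session into an
    # existing allocation; per-task variables such as SLURM_LOCALID are
    # only set for processes actually launched by srun.
    return "SLURM_JOB_ID" in os.environ and "SLURM_LOCALID" in os.environ
```
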
PiperOrigin-RevId: 822501486
The change revolves around how vmap should interact with shard_map when it maps over an explicitly sharded axis. If vmap maps over an explicitly sharded axis, i.e. `vmap(f)(arr: f32[8@i, 4])`, then `f` no longer has access to the `i` mesh axis.

This should be true for any `f`, but for now assume `f = lambda x: shard_map(...)`. With this context, consider this example:

```python
@jtu.with_explicit_mesh((2, 2), ('i', 'j'))
def test_explicit_vmap_grad_shmap(self, mesh):
  np_inp = np.arange(6 * 24, dtype=np.float32).reshape(6, 24)
  arr = jax.device_put(np_inp, P('i', None))

  def g(x):
    out = jax.shard_map(jnp.cos, in_specs=P('j'), out_specs=P('j'))(x)
    return out.sum()

  out = jax.jit(jax.vmap(jax.grad(g)))(arr)
  self.assertEqual(out.sharding, NamedSharding(mesh, P('i', 'j')))
```

**What's the change?**

* **Mesh and Axis names inference change:**

  Since `axis_names` is not passed to `jax.shard_map`, this shard_map should go Manual over all available mesh axis names. But which axis names are available?

  Since the shard_map is executing under a vmap, `i` is not available anymore! So the shard_map will only be Manual over `j` and Explicit over `i`.

  If you were to do `vmap(lambda x: jax.shard_map(..., axis_names={'i', 'j'}))(arr: f32[8@i, 4])`, this would be an error, since you can't go Manual over `i` when it is passed explicitly via the `axis_names` argument to shard_map (see the sketch below).

  We achieve this by adding `explicit_mesh_axis_names` to `AxisEnv` (similar to `spmd_axis_names`), entering that context when we trace vmap, and reading that value when assigning `axis_names` (if not present) to shard_map.

  Ideally we would just change the mesh when starting vmap tracing so that the mesh itself doesn't have the vmapped axis names (I'll try that in a follow-up).

* **shard_map's batching rule change:**

  This explanation depends on the code example above.

  **Before:**

  In shard_map's batching rule, the new_in_specs and new_out_specs would have been `P('i', 'j')`. The `i` comes from `axis_data.explicit_mesh_axis`. This means we were going Manual over `i`, since the `in_specs` of shard_map only refer to Manual mesh axes. This seems a bit weird.

  **After:**

  In shard_map's batching rule, the new_in_specs and new_out_specs would be `P(None, 'j')`. This makes sense because the explicit axis dims are tracked on the inputs. So concretely the shard_map would look like: `shmap(f, in_specs=P(None, 'j'), out_specs=P(None, 'j'), axis_names={'j'})(in_vals: f32[6@i, 24])`. Given that vmap is tracking mesh axis `i`, we don't have to separately add it to `in_specs` at all. This results in a partially-Manual mesh, but that's perfectly fine!
This is an internal semantics change, but it makes perfect sense to me to do it this way.
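
Putting the two pieces together, a sketch of the resulting behavior (illustrative; it assumes the `(2, 2)` explicit mesh and `arr: f32[6@i, 24]` from the test above, and the error wording is paraphrased):

```python
f = lambda x: jax.shard_map(jnp.cos, in_specs=P('j'), out_specs=P('j'))(x)

# OK under the new semantics: axis_names is inferred as {'j'} because 'i'
# is vmapped over; the shard_map is Manual over 'j' and Explicit over 'i'.
jax.vmap(f)(arr)

g = lambda x: jax.shard_map(jnp.cos, in_specs=P('j'), out_specs=P('j'),
                            axis_names={'i', 'j'})(x)

# Error: 'i' is vmapped over, so this shard_map cannot go Manual over it.
jax.vmap(g)(arr)
```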

PiperOrigin-RevId: 822563882
The error for a shape of length 1 or less is now only raised when `in_axis` is `-2`. This allows `_compute_fans` to be used with 1D/0D shapes when custom axes are specified.
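
An illustrative use this enables (a sketch assuming this revision; `variance_scaling` forwards `in_axis`/`out_axis` to `_compute_fans`):

```python
import jax
from jax.nn.initializers import variance_scaling

# A 1D parameter with custom axes: shapes of length <= 1 used to be
# rejected outright; now the error fires only for the default in_axis=-2.
init = variance_scaling(1.0, "fan_in", "truncated_normal",
                        in_axis=0, out_axis=0)
w = init(jax.random.key(0), (16,))
```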

PiperOrigin-RevId: 822616111
hawkinsp and others added 30 commits October 30, 2025 09:45
`k` does not affect the shape, so it does not need to be static.

Fixes jax-ml#32994

PiperOrigin-RevId: 826072062
This change updates the Bazel version used in the TensorFlow, JAX, and XLA projects from 7.4.1 to 7.7.0 in `.bazelversion` files and build scripts.

PiperOrigin-RevId: 826075658
PiperOrigin-RevId: 826115707
…f to infer the input type (and other metadata like layouts) at which to trace f. Because `duck_stuff` is flexible, it could comprise actual values, but we don't need to hold onto those values. Yet `traced` currently does hold references to any such values!

Create an internal `LowerType` that can be passed to `_resolve_and_lower` via `Traced` and the normal jit non-AOT path, instead of the actual arguments. We still hold onto consts in `Traced`, though.

Currently, for HiJAX, we still hold onto real arguments, but we'll fix that in a follow-up.

PiperOrigin-RevId: 826122272
…eing held in Traced without converting it to MetaTys.

This change makes MetaTys work with HiTypes.

Co-authored-by: Matthew Johnson <[email protected]>
PiperOrigin-RevId: 826153440
This change introduces an optional `name` string argument to the `pallas_call` primitive and its associated lowering and transformation rules. The `name` is currently passed through but not yet used by the backend-specific lowering implementations for TPU, GPU, and Triton.

A follow-up CL will use it to extend the name stack on TPU.
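
An illustrative call (a sketch that assumes the user-facing `pl.pallas_call` wrapper forwards the new argument):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    name="copy_kernel",  # new optional name; passed through, not yet used
)(jnp.ones((8, 128), jnp.float32))
```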

PiperOrigin-RevId: 826174027