Skip to content

Commit 31e8ee3

Browse files
authored
iris, fray: reject TPU requests whose chip count doesn't match VM shape (#4791)
## Summary

A TPU VM is the atomic scheduling unit, but neither the Fray client nor the Iris controller rejected requests where the per-replica chip count differed from the variant's `chips_per_vm`.

The concrete failure mode from the user report:

> I keep getting scheduled to nodes that already have something running on them... This particular job only needs 4 chips and says so in the request, but I included v6e-8 as a possible TPU type to claim.

A job submitted with `device_alternatives=["v6e-4", "v6e-8"]` passed the old `vm_count`-only check (both are vm_count=1), reserved 4 chips per replica against the primary, and then landed on a v6e-8 worker that advertises 8 chips. The scheduler saw 4 free chips and co-scheduled a second 4-chip job onto the same indivisible VM — two tenants colliding on one JAX host.

The diagram:

```
with_tpu(["v6e-4", "v6e-8"]) ← old check passes (vm_count both = 1)
primary = v6e-4 → reserve chips_per_vm = 4
↓
lands on v6e-8 worker (advertises 8 chips, 1 VM)
reserved=4, free=4 ✘ scheduler thinks VM is half-free
→ second 4-chip job co-scheduled onto the same VM → collision
```

Tighten validation at both ends:

- **fray** (`ResourceConfig.with_tpu`): candidates must share both `vm_count` _and_ `chips_per_vm`. `[v4-8, v5p-8]` (both 1×4) still works; `[v6e-4, v6e-8]` (1×4 vs 1×8) now fails fast with a clear message.
- **iris** (`launch_job` ingestion): new `validate_tpu_request()` helper runs right after constraint injection and rejects chip-count / VM-shape mismatches with `INVALID_ARGUMENT`, so older or hand-rolled clients can't bypass the fray-side check.
1 parent 913e579 commit 31e8ee3

File tree

5 files changed

+178
-6
lines changed

5 files changed

+178
-6
lines changed

lib/fray/src/fray/v2/types.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,11 @@ def with_tpu(tpu_type: str | Sequence[str], *, slice_count: int = 1, **kwargs: A
358358
359359
When ``tpu_type`` is a list, the first entry is canonical (used for
360360
chip_count, env_vars, resource sizing) and the rest are alternatives.
361-
All types in a list must share the same ``vm_count``.
361+
All types in a list must share both ``vm_count`` and ``chips_per_vm``:
362+
a TPU VM is the atomic scheduling unit, so mixing variants with
363+
different per-VM chip counts (e.g. ``v6e-4`` + ``v6e-8``) would let
364+
the scheduler co-locate two partial-VM jobs onto a VM that cannot
365+
actually be shared.
362366
"""
363367
if isinstance(tpu_type, str):
364368
tpu_types = [tpu_type]
@@ -368,9 +372,16 @@ def with_tpu(tpu_type: str | Sequence[str], *, slice_count: int = 1, **kwargs: A
368372
if not tpu_types:
369373
raise ValueError("tpu_type must be non-empty")
370374

371-
vm_counts = {t: get_tpu_topology(t).vm_count for t in tpu_types}
372-
if len(set(vm_counts.values())) != 1:
373-
raise ValueError(f"All TPU types must have the same vm_count for flexible scheduling. Got: {vm_counts}")
375+
topos = {t: get_tpu_topology(t) for t in tpu_types}
376+
vm_counts = {t: topo.vm_count for t, topo in topos.items()}
377+
chips_per_vm = {t: topo.chips_per_vm for t, topo in topos.items()}
378+
if len(set(vm_counts.values())) != 1 or len(set(chips_per_vm.values())) != 1:
379+
raise ValueError(
380+
"All TPU types in a flexible request must share both vm_count and chips_per_vm. "
381+
f"Got vm_count={vm_counts}, chips_per_vm={chips_per_vm}. "
382+
"Single-VM variants like v6e-8 or v5litepod-8 cannot be mixed with smaller "
383+
"single-VM variants because the VM is indivisible and would be shared between jobs."
384+
)
374385

375386
primary = tpu_types[0]
376387
alternatives = list(tpu_types[1:]) or None

lib/fray/tests/test_v2_iris.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,21 @@ def test_multiple_types_sets_alternatives(self):
243243
assert rc.replicas == 1 # both v4-8 and v5p-8 have vm_count=1
244244

245245
def test_mismatched_vm_count_raises(self):
246-
with pytest.raises(ValueError, match="same vm_count"):
246+
with pytest.raises(ValueError, match="vm_count and chips_per_vm"):
247247
ResourceConfig.with_tpu(["v4-8", "v4-16"])
248248

249+
def test_mismatched_chips_per_vm_raises(self):
250+
# v6e-4 and v6e-8 both have vm_count=1 but 4 vs 8 chips per VM;
251+
# the single VM of a v6e-8 is indivisible so these must not mix.
252+
with pytest.raises(ValueError, match="vm_count and chips_per_vm"):
253+
ResourceConfig.with_tpu(["v6e-4", "v6e-8"])
254+
255+
def test_same_chips_per_vm_different_generations_ok(self):
256+
# v4-8 and v5p-8 both have vm_count=1 and chips_per_vm=4.
257+
rc = ResourceConfig.with_tpu(["v4-8", "v5p-8"])
258+
assert rc.device.variant == "v4-8"
259+
assert rc.device_alternatives == ["v5p-8"]
260+
249261
def test_empty_raises(self):
250262
with pytest.raises(ValueError, match="non-empty"):
251263
ResourceConfig.with_tpu([])

lib/iris/src/iris/cluster/constraints.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,88 @@ def constraints_from_resources(resources: job_pb2.ResourceSpecProto) -> list[Con
715715
return constraints
716716

717717

718+
def validate_tpu_request(
719+
resources: job_pb2.ResourceSpecProto,
720+
constraints: Sequence[Constraint],
721+
) -> str | None:
722+
"""Check that a TPU job's chip count matches the VM shape of every candidate variant.
723+
724+
A TPU VM is the atomic scheduling unit: the scheduler reserves chips from a
725+
worker's advertised capacity, but a single-VM slice (e.g. ``v6e-8``) cannot
726+
be shared between two jobs even if their combined chip count fits.
727+
728+
An explicit ``device-variant`` constraint is authoritative for scheduling
729+
(it replaces the auto-generated constraint from the primary variant), so
730+
we validate the requested chip count against every effective candidate —
731+
not just the primary. This rejects submissions where:
732+
733+
- any candidate variant's ``chips_per_vm`` differs from
734+
``resources.device.tpu.count`` (e.g. primary ``v6e-4`` with
735+
``device-variant EQ v6e-8`` would schedule on a single v6e-8 VM while
736+
reserving only 4 of its 8 chips), or
737+
- an IN constraint lists candidates with mismatched VM shapes
738+
(e.g. ``["v6e-4", "v6e-8"]``).
739+
740+
Returns ``None`` if the request is valid, or a human-readable error
741+
message suitable for returning as ``INVALID_ARGUMENT``.
742+
"""
743+
from iris.cluster.types import TpuTopologyInfo, get_tpu_topology
744+
745+
if not resources.HasField("device") or not resources.device.HasField("tpu"):
746+
return None
747+
748+
primary = resources.device.tpu.variant
749+
if not primary or primary == "auto":
750+
return None
751+
752+
chips_requested = resources.device.tpu.count
753+
754+
# Effective candidates: an explicit device-variant constraint overrides
755+
# the primary. Fall back to the primary when no such constraint exists.
756+
variants: list[str] = [primary]
757+
for c in constraints:
758+
if c.key != WellKnownAttribute.DEVICE_VARIANT:
759+
continue
760+
if c.op == ConstraintOp.IN:
761+
variants = [str(av.value) for av in c.values if av.value]
762+
break
763+
if c.op == ConstraintOp.EQ and c.values:
764+
variants = [str(c.values[0].value)]
765+
break
766+
767+
topos: dict[str, TpuTopologyInfo] = {}
768+
for v in variants:
769+
try:
770+
topos[v] = get_tpu_topology(v)
771+
except ValueError:
772+
continue # unknown variants fall through to the scheduler
773+
774+
if not topos:
775+
return None
776+
777+
mismatched = {
778+
v: topo.chips_per_vm for v, topo in topos.items() if chips_requested and chips_requested != topo.chips_per_vm
779+
}
780+
if mismatched:
781+
return (
782+
f"TPU chip count mismatch: requested {chips_requested} chips per replica, but "
783+
f"candidate variants have chips_per_vm={mismatched}. A TPU VM is indivisible; "
784+
"the per-replica chip count must equal every candidate variant's chips_per_vm."
785+
)
786+
787+
shapes = {v: (topo.vm_count, topo.chips_per_vm) for v, topo in topos.items()}
788+
if len(set(shapes.values())) > 1:
789+
return (
790+
"TPU variant alternatives have incompatible VM shapes: "
791+
f"{ {v: {'vm_count': s[0], 'chips_per_vm': s[1]} for v, s in shapes.items()} }. "
792+
"All candidates must share vm_count and chips_per_vm; single-VM variants like "
793+
"v6e-8 or v5litepod-8 cannot be mixed with smaller variants because their VM is "
794+
"indivisible and would be shared between co-scheduled jobs."
795+
)
796+
797+
return None
798+
799+
718800
# ---------------------------------------------------------------------------
719801
# Executor heuristic: auto-tag small CPU-only jobs as non-preemptible
720802
# ---------------------------------------------------------------------------

lib/iris/src/iris/cluster/controller/service.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from connectrpc.errors import ConnectError
2222
from connectrpc.request import RequestContext
2323

24-
from iris.cluster.constraints import Constraint, constraints_from_resources, merge_constraints
24+
from iris.cluster.constraints import Constraint, constraints_from_resources, merge_constraints, validate_tpu_request
2525
from iris.cluster.redaction import redact_request_env_vars
2626
from iris.cluster.controller.codec import (
2727
constraints_from_json,
@@ -1135,6 +1135,14 @@ def launch_job(
11351135
# device-variant, etc.) replace auto-generated ones.
11361136
request = _inject_resource_constraints(request)
11371137

1138+
# Reject TPU requests whose chip count doesn't match a single VM, or
1139+
# whose device-variant alternatives mix incompatible VM shapes (e.g.
1140+
# v6e-4 + v6e-8). Co-scheduling jobs onto a single-VM slice like v6e-8
1141+
# would put two tenants on one indivisible VM.
1142+
tpu_error = validate_tpu_request(request.resources, [Constraint.from_proto(c) for c in request.constraints])
1143+
if tpu_error:
1144+
raise ConnectError(Code.INVALID_ARGUMENT, tpu_error)
1145+
11381146
# Reject jobs that can never be scheduled so they fail fast instead
11391147
# of sitting in the pending queue. For coscheduled jobs this also
11401148
# verifies the replica count is compatible with some group's num_vms.

lib/iris/tests/cluster/controller/test_service.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,65 @@ def test_launch_job_bundle_blob_rewrites_to_controller_bundle_id(service, state)
124124
assert len(job.bundle_id) == 64
125125

126126

127+
def test_launch_job_rejects_tpu_chip_count_mismatch(service):
128+
"""A job requesting fewer chips than the variant's chips_per_vm is rejected."""
129+
request = make_job_request("bad-tpu-chip-count")
130+
request.resources.device.CopyFrom(tpu_device("v6e-8", count=4))
131+
132+
with pytest.raises(ConnectError) as exc_info:
133+
service.launch_job(request, None)
134+
135+
assert exc_info.value.code == Code.INVALID_ARGUMENT
136+
assert "chip count mismatch" in exc_info.value.message
137+
138+
139+
def test_launch_job_rejects_mixed_vm_shape_alternatives(service):
140+
"""device-variant IN constraint with mismatched chips_per_vm is rejected."""
141+
request = make_job_request("mixed-tpu-variants")
142+
request.resources.device.CopyFrom(tpu_device("v6e-4"))
143+
# User-provided IN constraint that mixes a 4-chip/VM and an 8-chip/VM variant.
144+
request.constraints.append(device_variant_constraint(["v6e-4", "v6e-8"]).to_proto())
145+
146+
with pytest.raises(ConnectError) as exc_info:
147+
service.launch_job(request, None)
148+
149+
assert exc_info.value.code == Code.INVALID_ARGUMENT
150+
# Mismatched shapes necessarily imply a chip-count mismatch for at least one
151+
# candidate, so the per-candidate count check fires first.
152+
assert "chip count mismatch" in exc_info.value.message
153+
assert "v6e-8" in exc_info.value.message
154+
155+
156+
def test_launch_job_rejects_variant_override_with_smaller_primary(service):
157+
"""Explicit device-variant constraint overrides the primary; chip count must match it.
158+
159+
Regression for Codex review: primary v6e-4 (chips_per_vm=4) with an explicit
160+
`device-variant EQ v6e-8` constraint would schedule onto a single v6e-8 VM
161+
while reserving only 4 of its 8 chips — the exact partial-VM collision we
162+
want to block. The validator must check chip count against every effective
163+
candidate, not just the primary.
164+
"""
165+
request = make_job_request("variant-override-mismatch")
166+
request.resources.device.CopyFrom(tpu_device("v6e-4"))
167+
request.constraints.append(device_variant_constraint(["v6e-8"]).to_proto())
168+
169+
with pytest.raises(ConnectError) as exc_info:
170+
service.launch_job(request, None)
171+
172+
assert exc_info.value.code == Code.INVALID_ARGUMENT
173+
assert "chip count mismatch" in exc_info.value.message
174+
175+
176+
def test_launch_job_accepts_same_shape_alternatives(service):
177+
"""Alternatives sharing vm_count/chips_per_vm (e.g. v4-8 + v5p-8) are accepted."""
178+
request = make_job_request("matched-tpu-variants")
179+
request.resources.device.CopyFrom(tpu_device("v4-8"))
180+
request.constraints.append(device_variant_constraint(["v4-8", "v5p-8"]).to_proto())
181+
182+
response = service.launch_job(request, None)
183+
assert response.job_id == JobName.root("test-user", "matched-tpu-variants").to_wire()
184+
185+
127186
def test_launch_job_rejects_duplicate_name(service):
128187
"""Verify launch_job rejects duplicate job names for running jobs."""
129188
request = make_job_request("duplicate-job")

0 commit comments

Comments
 (0)