Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions lib/fray/src/fray/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,11 @@ def with_tpu(tpu_type: str | Sequence[str], *, slice_count: int = 1, **kwargs: A

When ``tpu_type`` is a list, the first entry is canonical (used for
chip_count, env_vars, resource sizing) and the rest are alternatives.
All types in a list must share the same ``vm_count``.
All types in a list must share both ``vm_count`` and ``chips_per_vm``:
a TPU VM is the atomic scheduling unit, so mixing variants with
different per-VM chip counts (e.g. ``v6e-4`` + ``v6e-8``) would let
the scheduler co-locate two partial-VM jobs onto a VM that cannot
actually be shared.
"""
if isinstance(tpu_type, str):
tpu_types = [tpu_type]
Expand All @@ -368,9 +372,16 @@ def with_tpu(tpu_type: str | Sequence[str], *, slice_count: int = 1, **kwargs: A
if not tpu_types:
raise ValueError("tpu_type must be non-empty")

vm_counts = {t: get_tpu_topology(t).vm_count for t in tpu_types}
if len(set(vm_counts.values())) != 1:
raise ValueError(f"All TPU types must have the same vm_count for flexible scheduling. Got: {vm_counts}")
topos = {t: get_tpu_topology(t) for t in tpu_types}
vm_counts = {t: topo.vm_count for t, topo in topos.items()}
chips_per_vm = {t: topo.chips_per_vm for t, topo in topos.items()}
if len(set(vm_counts.values())) != 1 or len(set(chips_per_vm.values())) != 1:
raise ValueError(
"All TPU types in a flexible request must share both vm_count and chips_per_vm. "
f"Got vm_count={vm_counts}, chips_per_vm={chips_per_vm}. "
"Single-VM variants like v6e-8 or v5litepod-8 cannot be mixed with smaller "
"single-VM variants because the VM is indivisible and would be shared between jobs."
)

primary = tpu_types[0]
alternatives = list(tpu_types[1:]) or None
Expand Down
14 changes: 13 additions & 1 deletion lib/fray/tests/test_v2_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,21 @@ def test_multiple_types_sets_alternatives(self):
assert rc.replicas == 1 # both v4-8 and v5p-8 have vm_count=1

def test_mismatched_vm_count_raises(self):
    """Mixing TPU types that the test expects to differ in vm_count raises ValueError."""
    # The validation error reads "... must share both vm_count and chips_per_vm ...",
    # so the match pattern targets that phrase; the earlier "same vm_count"
    # wording is no longer part of the message and would never match.
    with pytest.raises(ValueError, match="vm_count and chips_per_vm"):
        ResourceConfig.with_tpu(["v4-8", "v4-16"])

def test_mismatched_chips_per_vm_raises(self):
    """Types with equal vm_count but different chips_per_vm must be rejected."""
    # Both variants are single-VM (vm_count=1) but expose 4 vs 8 chips per VM.
    # The one VM behind a v6e-8 cannot be split, so the pair must not be mixed.
    mixed_variants = ["v6e-4", "v6e-8"]
    with pytest.raises(ValueError, match="vm_count and chips_per_vm"):
        ResourceConfig.with_tpu(mixed_variants)

def test_same_chips_per_vm_different_generations_ok(self):
    """Different TPU generations with the same VM shape are accepted together."""
    # v4-8 and v5p-8 both have vm_count=1 and chips_per_vm=4.
    primary, alternative = "v4-8", "v5p-8"
    config = ResourceConfig.with_tpu([primary, alternative])
    # The first entry becomes the device variant; the rest are alternatives.
    assert config.device.variant == primary
    assert config.device_alternatives == [alternative]

def test_empty_raises(self):
with pytest.raises(ValueError, match="non-empty"):
ResourceConfig.with_tpu([])
Expand Down
82 changes: 82 additions & 0 deletions lib/iris/src/iris/cluster/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,88 @@ def constraints_from_resources(resources: job_pb2.ResourceSpecProto) -> list[Con
return constraints


def validate_tpu_request(
    resources: job_pb2.ResourceSpecProto,
    constraints: Sequence[Constraint],
) -> str | None:
    """Validate a TPU request's chip count against every candidate variant's VM shape.

    The design premise is that a TPU VM is the atomic scheduling unit: a
    single-VM slice (e.g. ``v6e-8``) must not be shared between two jobs even
    when their combined chip counts would fit.

    Candidate variants are resolved as follows: the first ``device-variant``
    constraint with an ``IN`` op, or with an ``EQ`` op and at least one value,
    replaces the primary variant from ``resources``; otherwise the primary is
    the only candidate. Candidates whose topology lookup raises ``ValueError``
    are ignored here.

    Two conditions are rejected:

    - ``resources.device.tpu.count`` is non-zero and differs from the
      ``chips_per_vm`` of any candidate (e.g. primary ``v6e-4`` combined with
      ``device-variant EQ v6e-8``).
    - the candidates do not all share the same ``(vm_count, chips_per_vm)``
      pair (e.g. ``["v6e-4", "v6e-8"]``).

    Args:
        resources: Resource spec of the job being submitted.
        constraints: Constraints attached to the request.

    Returns:
        ``None`` when the request passes (including non-TPU requests, an empty
        or ``"auto"`` variant, and the case where no candidate has a known
        topology); otherwise a human-readable message suitable for returning
        as ``INVALID_ARGUMENT``.
    """
    from iris.cluster.types import TpuTopologyInfo, get_tpu_topology

    is_tpu_request = resources.HasField("device") and resources.device.HasField("tpu")
    if not is_tpu_request:
        return None

    primary_variant = resources.device.tpu.variant
    if not primary_variant or primary_variant == "auto":
        return None

    chips_requested = resources.device.tpu.count

    # Start from the primary; the first usable device-variant constraint
    # (IN, or EQ with a value) overrides it and ends the search.
    candidates: list[str] = [primary_variant]
    for constraint in constraints:
        if constraint.key == WellKnownAttribute.DEVICE_VARIANT:
            if constraint.op == ConstraintOp.IN:
                candidates = [str(av.value) for av in constraint.values if av.value]
                break
            if constraint.op == ConstraintOp.EQ and constraint.values:
                candidates = [str(constraint.values[0].value)]
                break

    known_topologies: dict[str, TpuTopologyInfo] = {}
    for name in candidates:
        try:
            topology = get_tpu_topology(name)
        except ValueError:
            # Unknown variants are left for the scheduler to deal with.
            continue
        else:
            known_topologies[name] = topology

    if not known_topologies:
        return None

    # A zero/unset chip count skips the per-candidate comparison entirely.
    if chips_requested:
        mismatched = {
            name: topology.chips_per_vm
            for name, topology in known_topologies.items()
            if chips_requested != topology.chips_per_vm
        }
        if mismatched:
            return (
                f"TPU chip count mismatch: requested {chips_requested} chips per replica, but "
                f"candidate variants have chips_per_vm={mismatched}. A TPU VM is indivisible; "
                "the per-replica chip count must equal every candidate variant's chips_per_vm."
            )

    shapes = {
        name: (topology.vm_count, topology.chips_per_vm)
        for name, topology in known_topologies.items()
    }
    distinct_shapes = set(shapes.values())
    if len(distinct_shapes) <= 1:
        return None

    return (
        "TPU variant alternatives have incompatible VM shapes: "
        f"{ {v: {'vm_count': s[0], 'chips_per_vm': s[1]} for v, s in shapes.items()} }. "
        "All candidates must share vm_count and chips_per_vm; single-VM variants like "
        "v6e-8 or v5litepod-8 cannot be mixed with smaller variants because their VM is "
        "indivisible and would be shared between co-scheduled jobs."
    )


# ---------------------------------------------------------------------------
# Executor heuristic: auto-tag small CPU-only jobs as non-preemptible
# ---------------------------------------------------------------------------
Expand Down
10 changes: 9 additions & 1 deletion lib/iris/src/iris/cluster/controller/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from connectrpc.errors import ConnectError
from connectrpc.request import RequestContext

from iris.cluster.constraints import Constraint, constraints_from_resources, merge_constraints
from iris.cluster.constraints import Constraint, constraints_from_resources, merge_constraints, validate_tpu_request
from iris.cluster.redaction import redact_request_env_vars
from iris.cluster.controller.codec import (
constraints_from_json,
Expand Down Expand Up @@ -1145,6 +1145,14 @@ def launch_job(
# device-variant, etc.) replace auto-generated ones.
request = _inject_resource_constraints(request)

# Reject TPU requests whose chip count doesn't match a single VM, or
# whose device-variant alternatives mix incompatible VM shapes (e.g.
# v6e-4 + v6e-8). Co-scheduling jobs onto a single-VM slice like v6e-8
# would put two tenants on one indivisible VM.
tpu_error = validate_tpu_request(request.resources, [Constraint.from_proto(c) for c in request.constraints])
if tpu_error:
raise ConnectError(Code.INVALID_ARGUMENT, tpu_error)

# Reject jobs that can never be scheduled so they fail fast instead
# of sitting in the pending queue. For coscheduled jobs this also
# verifies the replica count is compatible with some group's num_vms.
Expand Down
59 changes: 59 additions & 0 deletions lib/iris/tests/cluster/controller/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,65 @@ def test_launch_job_bundle_blob_rewrites_to_controller_bundle_id(service, state)
assert len(job.bundle_id) == 64


def test_launch_job_rejects_tpu_chip_count_mismatch(service):
    """launch_job refuses a v6e-8 request that asks for only 4 chips."""
    job_request = make_job_request("bad-tpu-chip-count")
    job_request.resources.device.CopyFrom(tpu_device("v6e-8", count=4))

    with pytest.raises(ConnectError) as raised:
        service.launch_job(job_request, None)

    error = raised.value
    assert error.code == Code.INVALID_ARGUMENT
    assert "chip count mismatch" in error.message


def test_launch_job_rejects_mixed_vm_shape_alternatives(service):
    """A device-variant constraint listing variants with different chips_per_vm is rejected."""
    job_request = make_job_request("mixed-tpu-variants")
    job_request.resources.device.CopyFrom(tpu_device("v6e-4"))
    # User-supplied constraint mixing a 4-chip/VM variant with an 8-chip/VM one.
    mixed_constraint = device_variant_constraint(["v6e-4", "v6e-8"])
    job_request.constraints.append(mixed_constraint.to_proto())

    with pytest.raises(ConnectError) as raised:
        service.launch_job(job_request, None)

    error = raised.value
    assert error.code == Code.INVALID_ARGUMENT
    # With a non-zero chip count, differing shapes mean at least one candidate
    # disagrees with the requested count, so the per-candidate count check is
    # the one that reports the failure.
    assert "chip count mismatch" in error.message
    assert "v6e-8" in error.message


def test_launch_job_rejects_variant_override_with_smaller_primary(service):
    """The chip count is checked against an explicit device-variant override, not only the primary.

    Regression (raised in Codex review): a primary of v6e-4 (chips_per_vm=4)
    combined with an explicit device-variant constraint naming only v6e-8
    would land on a single v6e-8 VM while reserving just 4 of its 8 chips,
    which is the partial-VM collision the validator exists to block. Every
    effective candidate must therefore be compared with the requested count.
    """
    job_request = make_job_request("variant-override-mismatch")
    job_request.resources.device.CopyFrom(tpu_device("v6e-4"))
    override = device_variant_constraint(["v6e-8"])
    job_request.constraints.append(override.to_proto())

    with pytest.raises(ConnectError) as raised:
        service.launch_job(job_request, None)

    error = raised.value
    assert error.code == Code.INVALID_ARGUMENT
    assert "chip count mismatch" in error.message


def test_launch_job_accepts_same_shape_alternatives(service):
    """Alternatives with matching vm_count and chips_per_vm (v4-8 + v5p-8) launch normally."""
    job_request = make_job_request("matched-tpu-variants")
    job_request.resources.device.CopyFrom(tpu_device("v4-8"))
    same_shape = device_variant_constraint(["v4-8", "v5p-8"])
    job_request.constraints.append(same_shape.to_proto())

    response = service.launch_job(job_request, None)

    expected_job_id = JobName.root("test-user", "matched-tpu-variants").to_wire()
    assert response.job_id == expected_job_id


def test_launch_job_rejects_duplicate_name(service):
"""Verify launch_job rejects duplicate job names for running jobs."""
request = make_job_request("duplicate-job")
Expand Down
Loading