diff --git a/.pyrefly-baseline.json b/.pyrefly-baseline.json index bfc50106c6..8d8e2b3ca1 100644 --- a/.pyrefly-baseline.json +++ b/.pyrefly-baseline.json @@ -1,89 +1,5 @@ { "errors": [ - { - "line": 126, - "column": 5, - "stop_line": 126, - "stop_column": 16, - "path": "lib/fray/src/fray/v1/cluster/__init__.py", - "code": -2, - "name": "bad-dunder-all", - "description": "Name `TPUConfig` is listed in `__all__` but is not defined in the module", - "concise_description": "Name `TPUConfig` is listed in `__all__` but is not defined in the module", - "severity": "error" - }, - { - "line": 435, - "column": 9, - "stop_line": 435, - "stop_column": 22, - "path": "lib/fray/src/fray/v1/cluster/base.py", - "code": -2, - "name": "invalid-annotation", - "description": "`Self` cannot be used in a static method", - "concise_description": "`Self` cannot be used in a static method", - "severity": "error" - }, - { - "line": 443, - "column": 9, - "stop_line": 443, - "stop_column": 20, - "path": "lib/fray/src/fray/v1/cluster/base.py", - "code": -2, - "name": "invalid-annotation", - "description": "`Self` cannot be used in a static method", - "concise_description": "`Self` cannot be used in a static method", - "severity": "error" - }, - { - "line": 482, - "column": 9, - "stop_line": 482, - "stop_column": 17, - "path": "lib/fray/src/fray/v1/cluster/base.py", - "code": -2, - "name": "invalid-annotation", - "description": "`Self` cannot be used in a static method", - "concise_description": "`Self` cannot be used in a static method", - "severity": "error" - }, - { - "line": 30, - "column": 5, - "stop_line": 30, - "stop_column": 18, - "path": "lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py", - "code": -2, - "name": "bad-dunder-all", - "description": "Name `TPU_CONFIGS` is listed in `__all__` but is not defined in the module", - "concise_description": "Name `TPU_CONFIGS` is listed in `__all__` but is not defined in the module", - "severity": "error" - }, - { - "line": 36, - "column": 5, - "stop_line": 36, - "stop_column": 16, - "path": "lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py", - "code": -2, - "name": "bad-dunder-all", - "description": "Name `TPUConfig` is listed in `__all__` but is not defined in the module", - "concise_description": "Name `TPUConfig` is listed in `__all__` but is not defined in the module", - "severity": "error" - }, - { - "line": 45, - "column": 5, - "stop_line": 45, - "stop_column": 21, - "path": "lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py", - "code": -2, - "name": "bad-dunder-all", - "description": "Name `get_tpu_config` is listed in `__all__` but is not defined in the module", - "concise_description": "Name `get_tpu_config` is listed in `__all__` but is not defined in the module", - "severity": "error" - }, { "line": 358, "column": 37, @@ -96,54 +12,6 @@ "concise_description": "Cannot set item in `dict[Future[Unknown], GeneratorFuture]`", "severity": "error" }, - { - "line": 35, - "column": 9, - "stop_line": 35, - "stop_column": 13, - "path": "lib/fray/src/fray/v1/queue/base.py", - "code": -2, - "name": "invalid-variance", - "description": "Type variable `T_co` is Covariant but is used in contravariant position", - "concise_description": "Type variable `T_co` is Covariant but is used in contravariant position", - "severity": "error" - }, - { - "line": 61, - "column": 9, - "stop_line": 61, - "stop_column": 13, - "path": "lib/fray/src/fray/v1/queue/base.py", - "code": -2, - "name": "invalid-variance", - "description": "Type variable `T_co` is Covariant but is used in contravariant position", - "concise_description": "Type variable `T_co` is Covariant but is used in contravariant position", - "severity": "error" - }, - { - "line": 508, - "column": 9, - "stop_line": 508, - "stop_column": 22, - "path": "lib/fray/src/fray/v2/types.py", - "code": -2, - "name": "invalid-annotation", - "description": "`Self` cannot be used in a static method", - "concise_description": "`Self` cannot be used in a static method", - "severity": "error" - }, - { - "line": 516, - "column": 9, - "stop_line": 516, - "stop_column": 20, - "path": "lib/fray/src/fray/v2/types.py", - "code": -2, - "name": "invalid-annotation", - "description": "`Self` cannot be used in a static method", - "concise_description": "`Self` cannot be used in a static method", - "severity": "error" - }, { "line": 108, "column": 47, @@ -312,18 +180,6 @@ "concise_description": "Returned type `tuple[tuple[str, ...], list[Unknown]]` is not assignable to declared return type `tuple[AxisSpec, list[Unknown]]`", "severity": "error" }, - { - "line": 360, - "column": 35, - "stop_line": 360, - "stop_column": 47, - "path": "lib/haliax/src/haliax/jax_utils.py", - "code": -2, - "name": "missing-module-attribute", - "description": "Could not import `ensure_tuple` from `haliax.core`", - "concise_description": "Could not import `ensure_tuple` from `haliax.core`", - "severity": "error" - }, { "line": 30, "column": 9, @@ -625,9 +481,21 @@ "severity": "error" }, { - "line": 2294, + "line": 697, + "column": 35, + "stop_line": 697, + "stop_column": 70, + "path": "lib/iris/src/iris/cluster/controller/controller.py", + "code": -2, + "name": "bad-specialization", + "description": "`WorkerSnapshot` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", + "concise_description": "`WorkerSnapshot` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", + "severity": "error" + }, + { + "line": 2297, "column": 26, - "stop_line": 2296, + "stop_line": 2299, "stop_column": 18, "path": "lib/iris/src/iris/cluster/controller/controller.py", "code": -2, @@ -672,30 +540,6 @@ "concise_description": "Returned type `None` is not assignable to declared return type `LevConfig`", "severity": "error" }, - { - "line": 332, - "column": 14, - "stop_line": 332, - "stop_column": 21, - "path": "lib/levanter/src/levanter/data/dataset.py", - "code": -2, - "name": "invalid-type-var", - "description": "Attribute `dataset` cannot depend on type variable `T`, which is not in the scope of class `BatchMappedAsyncDataset`", - "concise_description": "Attribute `dataset` cannot depend on type variable `T`, which is not in the scope of class `BatchMappedAsyncDataset`", - "severity": "error" - }, - { - "line": 534, - "column": 14, - "stop_line": 534, - "stop_column": 16, - "path": "lib/levanter/src/levanter/data/sharded_datasource.py", - "code": -2, - "name": "invalid-type-var", - "description": "Attribute `fn` cannot depend on type variable `T_co`, which is not in the scope of class `_MappedShardedDataSource`", - "concise_description": "Attribute `fn` cannot depend on type variable `T_co`, which is not in the scope of class `_MappedShardedDataSource`", - "severity": "error" - }, { "line": 414, "column": 25, @@ -1488,42 +1332,6 @@ "concise_description": "`Self@WhisperModel` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", "severity": "error" }, - { - "line": 42, - "column": 9, - "stop_line": 42, - "stop_column": 15, - "path": "lib/levanter/src/levanter/optim/model_averaging.py", - "code": -2, - "name": "bad-param-name-override", - "description": "Class member `EmaModelAveraging.update` overrides parent class `ModelAveraging` in an inconsistent manner\n Got parameter name `new_model`, expected `model`", - "concise_description": "Class member `EmaModelAveraging.update` overrides parent class `ModelAveraging` in an inconsistent manner", - "severity": "error" - }, - { - "line": 73, - "column": 9, - "stop_line": 73, - "stop_column": 15, - "path": "lib/levanter/src/levanter/optim/model_averaging.py", - "code": -2, - "name": "bad-param-name-override", - "description": "Class member `EmaDecaySqrtModelAveraging.update` overrides parent class `ModelAveraging` in an inconsistent manner\n Got parameter name `new_model`, expected `model`", - "concise_description": "Class member `EmaDecaySqrtModelAveraging.update` overrides parent class `ModelAveraging` in an inconsistent manner", - "severity": "error" - }, - { - "line": 235, - "column": 14, - "stop_line": 235, - "stop_column": 23, - "path": "lib/levanter/src/levanter/store/cache.py", - "code": -2, - "name": "invalid-type-var", - "description": "Attribute `_exemplar` cannot depend on type variable `T`, which is not in the scope of class `SerialCacheWriter`", - "concise_description": "Attribute `_exemplar` cannot depend on type variable `T`, which is not in the scope of class `SerialCacheWriter`", - "severity": "error" - }, { "line": 527, "column": 13, @@ -1560,138 +1368,6 @@ "concise_description": "`S` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", "severity": "error" }, - { - "line": 129, - "column": 9, - "stop_line": 129, - "stop_column": 16, - "path": "lib/levanter/src/levanter/utils/py_utils.py", - "code": -2, - "name": "bad-param-name-override", - "description": "Class member `FailSafeJSONEncoder.default` overrides parent class `JSONEncoder` in an inconsistent manner\n Got parameter name `obj`, expected `o`", - "concise_description": "Class member `FailSafeJSONEncoder.default` overrides parent class `JSONEncoder` in an inconsistent manner", - "severity": "error" - }, - { - "line": 416, - "column": 52, - "stop_line": 416, - "stop_column": 62, - "path": "lib/marin/src/marin/cluster/config.py", - "code": -2, - "name": "bad-typed-dict-key", - "description": "Invalid key for TypedDict ``, got `dict[str, int]`", - "concise_description": "Invalid key for TypedDict ``, got `dict[str, int]`", - "severity": "error" - }, - { - "line": 416, - "column": 52, - "stop_line": 416, - "stop_column": 62, - "path": "lib/marin/src/marin/cluster/config.py", - "code": -2, - "name": "bad-typed-dict-key", - "description": "Invalid key for TypedDict ``, got `dict[Unknown, Unknown]`", - "concise_description": "Invalid key for TypedDict ``, got `dict[Unknown, Unknown]`", - "severity": "error" - }, - { - "line": 416, - "column": 52, - "stop_line": 416, - "stop_column": 62, - "path": "lib/marin/src/marin/cluster/config.py", - "code": -2, - "name": "bad-typed-dict-key", - "description": "Invalid key for TypedDict ``, got `int`", - "concise_description": "Invalid key for TypedDict ``, got `int`", - "severity": "error" - }, - { - "line": 155, - "column": 52, - "stop_line": 155, - "stop_column": 61, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 634, - "column": 37, - "stop_line": 634, - "stop_column": 46, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 973, - "column": 13, - "stop_line": 973, - "stop_column": 22, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 1137, - "column": 13, - "stop_line": 1137, - "stop_column": 22, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 1138, - "column": 6, - "stop_line": 1138, - "stop_column": 15, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 1209, - "column": 42, - "stop_line": 1209, - "stop_column": 51, - "path": "lib/marin/src/marin/execution/executor.py", - "code": -2, - "name": "not-a-type", - "description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "concise_description": "Expected a type form, got instance of `Overload[\n [_T](cls: type[_T], /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> type[_T]\n (cls: None = None, /, *, init: bool = True, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False) -> [_T](type[_T]) -> type[_T]\n]`", - "severity": "error" - }, - { - "line": 226, - "column": 16, - "stop_line": 226, - "stop_column": 85, - "path": "lib/marin/src/marin/rl/environments/inference_ctx/inflight/worker.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `list[str]` is not assignable to declared return type `str`", - "concise_description": "Returned type `list[str]` is not assignable to declared return type `str`", - "severity": "error" - }, { "line": 214, "column": 39, @@ -1704,18 +1380,6 @@ "concise_description": "`None` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", "severity": "error" }, - { - "line": 200, - "column": 15, - "stop_line": 200, - "stop_column": 49, - "path": "lib/marin/src/marin/rl/rollout_storage.py", - "code": -2, - "name": "unsupported-operation", - "description": "`<` is not supported between `float` and `None`\n Argument `None` is not assignable to parameter `value` with type `float` in function `float.__lt__`", - "concise_description": "`<` is not supported between `float` and `None`", - "severity": "error" - }, { "line": 187, "column": 6, @@ -1765,33 +1429,9 @@ "severity": "error" }, { - "line": 267, - "column": 16, - "stop_line": 267, - "stop_column": 19, - "path": "lib/marin/src/marin/rl/weight_transfer/arrow_flight.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `Literal[123]` is not assignable to declared return type `None`", - "concise_description": "Returned type `Literal[123]` is not assignable to declared return type `None`", - "severity": "error" - }, - { - "line": 270, - "column": 16, - "stop_line": 270, - "stop_column": 33, - "path": "lib/marin/src/marin/rl/weight_transfer/arrow_flight.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `ServerInfo | None` is not assignable to declared return type `ServerInfo`", - "concise_description": "Returned type `ServerInfo | None` is not assignable to declared return type `ServerInfo`", - "severity": "error" - }, - { - "line": 567, + "line": 568, "column": 9, - "stop_line": 567, + "stop_line": 568, "stop_column": 20, "path": "lib/marin/src/marin/rl/weight_transfer/arrow_flight.py", "code": -2, @@ -1801,9 +1441,9 @@ "severity": "error" }, { - "line": 571, + "line": 572, "column": 9, - "stop_line": 571, + "stop_line": 572, "stop_column": 27, "path": "lib/marin/src/marin/rl/weight_transfer/arrow_flight.py", "code": -2, @@ -1884,126 +1524,6 @@ "concise_description": "`object` is not assignable to upper bound `DataclassInstance` of type variable `_DataclassT`", "severity": "error" }, - { - "line": 52, - "column": 5, - "stop_line": 52, - "stop_column": 29, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "invalid-annotation", - "description": "Enum member `SINGLE_COLUMN_MULTI_TURN` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "concise_description": "Enum member `SINGLE_COLUMN_MULTI_TURN` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "severity": "error" - }, - { - "line": 53, - "column": 5, - "stop_line": 53, - "stop_column": 25, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "invalid-annotation", - "description": "Enum member `INSTRUCTION_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "concise_description": "Enum member `INSTRUCTION_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "severity": "error" - }, - { - "line": 54, - "column": 5, - "stop_line": 54, - "stop_column": 29, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "invalid-annotation", - "description": "Enum member `INSTRUCT_COLUMN_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "concise_description": "Enum member `INSTRUCT_COLUMN_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "severity": "error" - }, - { - "line": 55, - "column": 5, - "stop_line": 55, - "stop_column": 26, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "invalid-annotation", - "description": "Enum member `INSTRUCT_MSG_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "concise_description": "Enum member `INSTRUCT_MSG_RESPONSE` may not be annotated directly. Instead, annotate the `_value_` attribute.", - "severity": "error" - }, - { - "line": 101, - "column": 24, - "stop_line": 101, - "stop_column": 28, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `None` is not assignable to declared return type `list[OpenAIChatMessage]`", - "concise_description": "Returned type `None` is not assignable to declared return type `list[OpenAIChatMessage]`", - "severity": "error" - }, - { - "line": 110, - "column": 28, - "stop_line": 110, - "stop_column": 61, - "path": "lib/marin/src/marin/transform/conversation/adapters.py", - "code": -2, - "name": "unsupported-operation", - "description": "`None` is not subscriptable", - "concise_description": "`None` is not subscriptable", - "severity": "error" - }, - { - "line": 46, - "column": 89, - "stop_line": 46, - "stop_column": 93, - "path": "lib/marin/src/marin/transform/fasttext/transform.py", - "code": -2, - "name": "bad-return", - "description": "Function declared to return `bool` but is missing an explicit `return`", - "concise_description": "Function declared to return `bool` but is missing an explicit `return`", - "severity": "error" - }, - { - "line": 186, - "column": 16, - "stop_line": 186, - "stop_column": 20, - "path": "lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `None` is not assignable to declared return type `str`", - "concise_description": "Returned type `None` is not assignable to declared return type `str`", - "severity": "error" - }, - { - "line": 197, - "column": 16, - "stop_line": 197, - "stop_column": 20, - "path": "lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py", - "code": -2, - "name": "bad-return", - "description": "Returned type `None` is not assignable to declared return type `str`", - "concise_description": "Returned type `None` is not assignable to declared return type `str`", - "severity": "error" - }, - { - "line": 18, - "column": 9, - "stop_line": 18, - "stop_column": 16, - "path": "lib/marin/src/marin/utilities/json_encoder.py", - "code": -2, - "name": "bad-param-name-override", - "description": "Class member `CustomJsonEncoder.default` overrides parent class `JSONEncoder` in an inconsistent manner\n Got parameter name `obj`, expected `o`", - "concise_description": "Class member `CustomJsonEncoder.default` overrides parent class `JSONEncoder` in an inconsistent manner", - "severity": "error" - }, { "line": 92, "column": 12, diff --git a/lib/fray/src/fray/v1/cluster/__init__.py b/lib/fray/src/fray/v1/cluster/__init__.py index 7d7314abac..0ebd253680 100644 --- a/lib/fray/src/fray/v1/cluster/__init__.py +++ b/lib/fray/src/fray/v1/cluster/__init__.py @@ -123,7 +123,6 @@ def create_cluster(cluster_spec: str) -> Cluster: "JobRequest", "LocalCluster", "ResourceConfig", - "TPUConfig", "TpuConfig", "TpuType", "create_cluster", diff --git a/lib/fray/src/fray/v1/cluster/base.py b/lib/fray/src/fray/v1/cluster/base.py index ac45c3f57a..81e69f9c5c 100644 --- a/lib/fray/src/fray/v1/cluster/base.py +++ b/lib/fray/src/fray/v1/cluster/base.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass, field from enum import StrEnum -from typing import Any, Literal, NewType, Self +from typing import Any, Literal, NewType logger = logging.getLogger(__name__) @@ -436,11 +436,11 @@ def from_callable( c: Callable[..., Any], args: Sequence[Any] = (), kwargs: dict[str, Any] | None = None, - ) -> Self: + ) -> Entrypoint: return Entrypoint(callable_entrypoint=CallableEntrypoint(callable=c, args=args, kwargs=kwargs or {})) @staticmethod - def from_binary(command: str, args: Sequence[str]) -> Self: + def from_binary(command: str, args: Sequence[str]) -> Entrypoint: return Entrypoint(binary_entrypoint=BinaryEntrypoint(command=command, args=args)) @@ -479,7 +479,7 @@ class JobStatus(StrEnum): STOPPED = "stopped" @staticmethod - def finished(status: Self) -> bool: + def finished(status: JobStatus) -> bool: return status in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.STOPPED) diff --git a/lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py b/lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py index 8ed773f4fc..a2d39e7459 100644 --- a/lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py +++ b/lib/fray/src/fray/v1/cluster/ray/tpu/__init__.py @@ -27,13 +27,11 @@ __all__ = [ "HEALTH_CHECK_TIMEOUT", "START_ACTOR_TIMEOUT", - "TPU_CONFIGS", "MultisliceInfo", "ResourcePoolManager", "SliceActor", "SliceInfo", "SlicePoolManager", - "TPUConfig", "TPUHostActor", "TPUHostInfo", "TpuCancelled", @@ -42,7 +40,6 @@ "TpuRunError", "TpuSuccess", "get_current_tpu_is_preempted", - "get_tpu_config", "run_on_pod", "run_on_pod_multislice", "run_on_pod_ray", diff --git a/lib/fray/src/fray/v1/queue/base.py b/lib/fray/src/fray/v1/queue/base.py index b222aa0a4c..989304bc35 100644 --- a/lib/fray/src/fray/v1/queue/base.py +++ b/lib/fray/src/fray/v1/queue/base.py @@ -11,7 +11,6 @@ from typing import Generic, Protocol, TypeVar T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) @dataclass @@ -29,39 +28,39 @@ class Lease(Generic[T]): timestamp: float -class Queue(Protocol[T_co]): +class Queue(Protocol[T]): """Distributed queue interface with lease-based task acquisition.""" - def push(self, item: T_co) -> None: + def push(self, item: T) -> None: """Add an item to the queue.""" ... - def peek(self) -> T_co | None: + def peek(self) -> T | None: """View the next available item without acquiring a lease.""" ... - def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None: + def pop(self, lease_timeout: float = 60.0) -> Lease[T] | None: """Acquire a lease on the next available item.""" ... - def done(self, lease: Lease[T_co]) -> None: + def done(self, lease: Lease[T]) -> None: """Mark a leased task as successfully completed.""" ... - def release(self, lease: Lease[T_co]) -> None: + def release(self, lease: Lease[T]) -> None: """Release a lease and requeue the item for reprocessing.""" ... -class MemoryQueue(Queue[T_co]): +class MemoryQueue(Queue[T]): def __init__(self): self.queue = [] self.leases = {} # lease_id -> (item, timestamp, timeout) - def push(self, item: T_co) -> None: + def push(self, item: T) -> None: self.queue.append(item) - def peek(self) -> T_co | None: + def peek(self) -> T | None: self._recover_expired_leases() if self.queue: return self.queue[0] @@ -80,7 +79,7 @@ def _recover_expired_leases(self) -> None: self.queue.insert(0, item) del self.leases[lease_id] - def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None: + def pop(self, lease_timeout: float = 60.0) -> Lease[T] | None: self._recover_expired_leases() if self.queue: item = self.queue.pop(0) @@ -91,11 +90,11 @@ def pop(self, lease_timeout: float = 60.0) -> Lease[T_co] | None: return lease return None - def done(self, lease: Lease[T_co]) -> None: + def done(self, lease: Lease[T]) -> None: if lease.lease_id in self.leases: del self.leases[lease.lease_id] - def release(self, lease: Lease[T_co]) -> None: + def release(self, lease: Lease[T]) -> None: if lease.lease_id in self.leases: item, _, _ = self.leases[lease.lease_id] self.queue.insert(0, item) diff --git a/lib/fray/src/fray/v2/types.py b/lib/fray/src/fray/v2/types.py index 30b801f1a0..babebdd25d 100644 --- a/lib/fray/src/fray/v2/types.py +++ b/lib/fray/src/fray/v2/types.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass, field from enum import StrEnum -from typing import Any, Literal, Self +from typing import Any, Literal # --------------------------------------------------------------------------- # TPU topology @@ -509,11 +509,11 @@ def from_callable( c: Callable[..., Any], args: Sequence[Any] = (), kwargs: dict[str, Any] | None = None, - ) -> Self: + ) -> Entrypoint: return Entrypoint(callable_entrypoint=CallableEntrypoint(callable=c, args=args, kwargs=kwargs or {})) @staticmethod - def from_binary(command: str, args: Sequence[str]) -> Self: + def from_binary(command: str, args: Sequence[str]) -> Entrypoint: return Entrypoint(binary_entrypoint=BinaryEntrypoint(command=command, args=args)) diff --git a/lib/haliax/src/haliax/jax_utils.py b/lib/haliax/src/haliax/jax_utils.py index d8542fd544..2fc2f28e6c 100644 --- a/lib/haliax/src/haliax/jax_utils.py +++ b/lib/haliax/src/haliax/jax_utils.py @@ -357,7 +357,8 @@ def _deshape(x): def to_jax_shape(shape): - from haliax.core import Axis, ensure_tuple + from haliax.core import Axis + from haliax.util import ensure_tuple shape = ensure_tuple(shape) return tuple(axis.size if isinstance(axis, Axis) else axis for axis in shape) diff --git a/lib/iris/src/iris/cluster/config.py b/lib/iris/src/iris/cluster/config.py index d4e776d5ed..871f10d945 100644 --- a/lib/iris/src/iris/cluster/config.py +++ b/lib/iris/src/iris/cluster/config.py @@ -25,6 +25,7 @@ from google.protobuf.json_format import MessageToDict, ParseDict from iris.cluster.constraints import WellKnownAttribute +from iris.cluster.controller.db import ControllerDB from iris.cluster.providers.k8s.tasks import K8sTaskProvider from iris.cluster.providers.protocols import WorkerInfraProvider from iris.cluster.controller.worker_provider import WorkerProvider @@ -1215,7 +1216,7 @@ def create_autoscaler( label_prefix: str, base_worker_config: config_pb2.WorkerConfig | None = None, threads: ThreadContainer | None = None, - db: "ControllerDB | None" = None, # noqa: F821, UP037 — circular import + db: ControllerDB | None = None, ): """Create autoscaler from WorkerInfraProvider and explicit config. diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index 73a526c527..2e7862ab90 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -595,7 +595,7 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) - return {task.task_id: task for task in tasks_with_attempts(tasks, attempts)} -def _building_counts(queries: ControllerDB, workers: list[WorkerRow]) -> dict[WorkerId, int]: +def _building_counts(queries: ControllerDB, workers: list[WorkerSnapshot]) -> dict[WorkerId, int]: """Count tasks in BUILDING or ASSIGNED state per worker, excluding reservation-holder jobs.""" if not workers: return {} @@ -672,9 +672,9 @@ def _worker_matches_reservation_entry( def _inject_reservation_taints( - workers: list[WorkerRow], + workers: list[WorkerSnapshot], claims: dict[WorkerId, ReservationClaim], -) -> list[WorkerRow]: +) -> list[WorkerSnapshot]: """Create modified worker copies with reservation taints and prioritization. Claimed workers receive a ``reservation-job`` attribute set to the claiming @@ -687,8 +687,8 @@ def _inject_reservation_taints( if not claims: return workers - claimed: list[WorkerRow] = [] - unclaimed: list[WorkerRow] = [] + claimed: list[WorkerSnapshot] = [] + unclaimed: list[WorkerSnapshot] = [] for worker in workers: claim = claims.get(worker.worker_id) if claim is not None: @@ -1536,6 +1536,9 @@ def _capture_one_profile( duration: int, ) -> None: """Capture a single task profile via RPC and store it in the DB.""" + # Profile loop is only spawned on the non-K8s provider path (see start()). + assert not isinstance(self._provider, K8sTaskProvider) + provider = self._provider try: request = job_pb2.ProfileTaskRequest( target=task_id.to_wire(), @@ -1543,7 +1546,7 @@ def _capture_one_profile( profile_type=profile_type, ) timeout_ms = duration * 1000 + 30000 - resp = self._provider.profile_task(worker.address, request, timeout_ms=timeout_ms) + resp = provider.profile_task(worker.address, request, timeout_ms=timeout_ms) if resp.error: logger.debug("Profile (%s) failed for %s: %s", profile_kind, task_id, resp.error) return diff --git a/lib/iris/src/iris/cluster/controller/db.py b/lib/iris/src/iris/cluster/controller/db.py index 3e5f84173e..2050a63d87 100644 --- a/lib/iris/src/iris/cluster/controller/db.py +++ b/lib/iris/src/iris/cluster/controller/db.py @@ -13,10 +13,10 @@ from dataclasses import dataclass, field, replace as dc_replace from pathlib import Path from threading import Lock, RLock -from typing import Any +from typing import Any, Protocol from iris.cluster.constraints import AttributeValue -from iris.cluster.controller.schema import decode_timestamp_ms, decode_worker_id +from iris.cluster.controller.schema import EndpointRow, decode_timestamp_ms, decode_worker_id from iris.cluster.types import TERMINAL_TASK_STATES, JobName, WorkerId from iris.rpc import job_pb2 from rigging.timing import Deadline, Duration, Timestamp @@ -219,6 +219,25 @@ class EndpointQuery: limit: int | None = None +class EndpointRegistryProtocol(Protocol): + """Structural type for the endpoints registry exposed on ``ControllerDB``. + + Defined here (rather than importing the concrete ``EndpointRegistry``) + because ``endpoint_registry`` imports ``EndpointQuery`` / ``TransactionCursor`` + from this module, creating a real module-level cycle. Per iris conventions + (``AGENTS.md``: avoid ``TYPE_CHECKING``; prefer a Protocol at the boundary). + """ + + def query(self, query: EndpointQuery = ...) -> list[EndpointRow]: ... + def resolve(self, name: str) -> EndpointRow | None: ... + def get(self, endpoint_id: str) -> EndpointRow | None: ... + def all(self) -> list[EndpointRow]: ... + def add(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: ... + def remove(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: ... + def remove_by_task(self, cur: TransactionCursor, task_id: JobName) -> list[str]: ... + def remove_by_job_ids(self, cur: TransactionCursor, job_ids: Sequence[JobName]) -> list[str]: ... + + def _decode_attribute_rows(rows: Sequence[Any]) -> dict[WorkerId, dict[str, AttributeValue]]: attrs_by_worker: dict[WorkerId, dict[str, AttributeValue]] = {} for row in rows: @@ -331,7 +350,7 @@ def __init__(self, db_dir: Path): logger.info("EndpointRegistry initialized in %.2fs", time.monotonic() - t0) @property - def endpoints(self) -> EndpointRegistry: # noqa: F821 + def endpoints(self) -> EndpointRegistryProtocol: """Process-local cache for the ``endpoints`` table; authoritative for reads.""" return self._endpoint_registry diff --git a/lib/levanter/src/levanter/data/dataset.py b/lib/levanter/src/levanter/data/dataset.py index cefacf84df..ea3a8c03ce 100644 --- a/lib/levanter/src/levanter/data/dataset.py +++ b/lib/levanter/src/levanter/data/dataset.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Callable, Generic, Optional, Sequence, TypeAlias, TypeVar +from typing import Any, Callable, Generic, Optional, Sequence, TypeAlias, TypeVar import jax.random import numpy as np @@ -324,7 +324,7 @@ class BatchMappedAsyncDataset(AsyncDataset[U]): def __init__( self, - dataset: AsyncDataset[T], + dataset: AsyncDataset[Any], fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs, diff --git a/lib/levanter/src/levanter/data/sharded_datasource.py b/lib/levanter/src/levanter/data/sharded_datasource.py index 00a6ced1c1..5e857a345c 100644 --- a/lib/levanter/src/levanter/data/sharded_datasource.py +++ b/lib/levanter/src/levanter/data/sharded_datasource.py @@ -529,7 +529,7 @@ class _TransformedDataset: class _MappedShardedDataSource(ShardedDataSource[T], _TransformedDataset): - def __init__(self, source: ShardedDataSource[T_co], fn: Callable[[T_co], T]): + def __init__(self, source: ShardedDataSource[Any], fn: Callable[[Any], T]): self.source = source self.fn = fn self._transform = _MapTransform(fn) diff --git a/lib/levanter/src/levanter/inference/openai.py b/lib/levanter/src/levanter/inference/openai.py index 3ea47b3358..d79d69152f 100644 --- a/lib/levanter/src/levanter/inference/openai.py +++ b/lib/levanter/src/levanter/inference/openai.py @@ -349,14 +349,15 @@ def _batch_processing_loop(self): while not self.shutdown_event.is_set(): try: batch = self.batch_queue.get(timeout=1) + except queue.Empty: + continue + try: with ( self.model_lock, hax.partitioning.set_mesh(self.config.trainer.device_mesh), hax.axis_mapping(self.config.trainer.compute_axis_mapping), ): self._execute_batch(batch) - except queue.Empty: - continue except Exception as e: logger.error(f"Error executing batch: {e}", exc_info=True) # Set exceptions on all futures in the batch diff --git a/lib/levanter/src/levanter/optim/model_averaging.py b/lib/levanter/src/levanter/optim/model_averaging.py index 3dd2b54adf..11d7f1cddc 100644 --- a/lib/levanter/src/levanter/optim/model_averaging.py +++ b/lib/levanter/src/levanter/optim/model_averaging.py @@ -39,10 +39,10 @@ class EmaModelAveraging(ModelAveraging[M]): model: M beta: float = eqx.field(static=True) - def update(self: S, new_model: M, step: int) -> S: + def update(self: S, model: M, step: int) -> S: del step # 1 - beta because increment_update expects the weight of the new model - return dataclasses.replace(self, model=optax.incremental_update(new_model, self.model, 1 - self.beta)) # type: ignore + return dataclasses.replace(self, model=optax.incremental_update(model, self.model, 1 - self.beta)) # type: ignore @property def model_params(self) -> M: @@ -70,11 +70,11 @@ def _raw_weight(self, step: int) -> float: frac = jnp.clip(t / self.decay_steps, 0.0, 1.0) return float(1.0 - jnp.sqrt(frac)) - def update(self, new_model: M, step: int) -> "EmaDecaySqrtModelAveraging[M]": + def update(self, model: M, step: int) -> "EmaDecaySqrtModelAveraging[M]": w = self._raw_weight(step) new_tot_w = self.tot_weight + w alpha = 0.0 if new_tot_w == 0.0 else w / new_tot_w - updated = optax.incremental_update(new_model, self.model, alpha) + updated = optax.incremental_update(model, self.model, alpha) return dataclasses.replace(self, model=updated, tot_weight=new_tot_w) # type: ignore[arg-type] @property diff --git a/lib/levanter/src/levanter/store/cache.py b/lib/levanter/src/levanter/store/cache.py index f92e77b47b..d2e60d970d 100644 --- a/lib/levanter/src/levanter/store/cache.py +++ b/lib/levanter/src/levanter/store/cache.py @@ -11,7 +11,7 @@ import threading import time from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union import deepdiff import jax @@ -217,7 +217,7 @@ def empty(): return CacheMetadata() -class SerialCacheWriter: +class SerialCacheWriter(Generic[T]): """ Writes TreeCache-compatible caches to disk without Ray. Mostly for scripts and debugging. """ diff --git a/lib/levanter/src/levanter/utils/py_utils.py b/lib/levanter/src/levanter/utils/py_utils.py index 64c954a4e0..98405e5d82 100644 --- a/lib/levanter/src/levanter/utils/py_utils.py +++ b/lib/levanter/src/levanter/utils/py_utils.py @@ -126,7 +126,8 @@ def __init__(self, *args, bytes_strategy="base64", **kwargs): super().__init__(*args, **kwargs) self.bytes_strategy = bytes_strategy - def default(self, obj): + def default(self, o): + obj = o # Known clean conversions if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): # ISO 8601; preserves tzinfo if present diff --git a/lib/marin/src/marin/cluster/config.py b/lib/marin/src/marin/cluster/config.py index 5726e26f1a..e92048ef39 100644 --- a/lib/marin/src/marin/cluster/config.py +++ b/lib/marin/src/marin/cluster/config.py @@ -6,6 +6,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import TypedDict import jinja2 import yaml @@ -314,7 +315,14 @@ def list_available_configs() -> list[str]: }, } -GENERATION_CONFIGS = { + +class _GenerationConfig(TypedDict): + runtime_version: str + base_worker: str + slices: list[int] + + +GENERATION_CONFIGS: dict[str, _GenerationConfig] = { "v4": { "runtime_version": "tpu-ubuntu2204-base", "base_worker": "8", diff --git a/lib/marin/src/marin/datakit/download/wikipedia.py b/lib/marin/src/marin/datakit/download/wikipedia.py index 2244c54376..0904a3a5e7 100644 --- a/lib/marin/src/marin/datakit/download/wikipedia.py +++ b/lib/marin/src/marin/datakit/download/wikipedia.py @@ -58,7 +58,11 @@ def process_file(input_file: str, output_path: str) -> Iterable[str]: with open_url(input_file) as f: with tarfile.open(fileobj=f, mode="r:gz") as tr: for info in tr: - with tr.extractfile(info) as file: + extracted = tr.extractfile(info) + if extracted is None: + # Skip non-regular entries (directories, symlinks, etc.) + continue + with extracted as file: file_content = file.read() file_path = os.path.join(output_path, info.name + ".gz") diff --git a/lib/marin/src/marin/execution/executor.py b/lib/marin/src/marin/execution/executor.py index 0891ebcec1..ce8fd7a5a8 100644 --- a/lib/marin/src/marin/execution/executor.py +++ b/lib/marin/src/marin/execution/executor.py @@ -152,7 +152,7 @@ def _get_local_data_browser_port(default: int = 5000) -> int: return default -ConfigT = TypeVar("ConfigT", covariant=True, bound=dataclass) +ConfigT = TypeVar("ConfigT", covariant=True) T_co = TypeVar("T_co", covariant=True) ExecutorFunction = Callable | None @@ -631,7 +631,7 @@ def _maybe_attach_inferred_region_constraint( ) -def asdict_without_description(obj: dataclass) -> dict[str, Any]: +def asdict_without_description(obj: Any) -> dict[str, Any]: """Return the dict form of a dataclass, but remove the `description` field.""" def recurse(value: Any): @@ -970,7 +970,7 @@ class ExecutorStepInfo: fn_name: str """Rendered string of `step.fn`.""" - config: dataclass + config: Any """`step.config`, but concretized (no more `InputName`, `OutputName`, or `VersionedValue`).""" description: str | None @@ -1133,9 +1133,7 @@ def recurse(obj: Any) -> None: return max_budget -def instantiate_config( - config: dataclass, output_path: str, output_paths: dict[ExecutorStep, str], prefix: str -) -> dataclass: +def instantiate_config(config: Any, output_path: str, output_paths: dict[ExecutorStep, str], prefix: str) -> Any: """ Return a "real" config where all the special values (e.g., `InputName`, `OutputName`, and `VersionedValue`) have been replaced with @@ -1206,7 +1204,7 @@ def __init__( self.executor_info_base_path = executor_info_base_path self.description = description - self.configs: dict[ExecutorStep, dataclass] = {} + self.configs: dict[ExecutorStep, Any] = {} self.dependencies: dict[ExecutorStep, list[ExecutorStep]] = {} self.versions: dict[ExecutorStep, dict[str, Any]] = {} # pseudo-dependencies only impact version but don't block execution of descendants diff --git a/lib/marin/src/marin/infra/__init__.py b/lib/marin/src/marin/infra/__init__.py deleted file mode 100644 index 42a0f250ae..0000000000 --- a/lib/marin/src/marin/infra/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Infrastructure helpers.""" - -from .tpu_monitor import TpuMonitor, start_tpu_monitor_on_head - -__all__ = ["TpuMonitor", "start_tpu_monitor_on_head"] diff --git a/lib/marin/src/marin/rl/environments/inference_ctx/inflight/worker.py b/lib/marin/src/marin/rl/environments/inference_ctx/inflight/worker.py index 1b5f0d2edd..d6e42505f8 100644 --- a/lib/marin/src/marin/rl/environments/inference_ctx/inflight/worker.py +++ b/lib/marin/src/marin/rl/environments/inference_ctx/inflight/worker.py @@ -212,7 +212,7 @@ async def _init_engine(self, engine_args): logger.info(f"Engine initialized: {engine}") return engine - def generate(self, prompts: list[str], sampling_params: SamplingParams) -> str: + def generate(self, prompts: list[str], sampling_params: SamplingParams) -> list[str]: """ Synchronous generate method - runs async code under the hood. diff --git a/lib/marin/src/marin/rl/rollout_storage.py b/lib/marin/src/marin/rl/rollout_storage.py index 96eafda6f6..1e43c7a86b 100644 --- a/lib/marin/src/marin/rl/rollout_storage.py +++ b/lib/marin/src/marin/rl/rollout_storage.py @@ -196,8 +196,8 @@ def _get_available_files(self) -> list[str]: def read_batch(self, timeout: float | None = None) -> RolloutBatch | None: """Read a single batch with optional timeout.""" - start_time = time.time() - while time.time() - start_time < timeout: + deadline = time.time() + timeout if timeout is not None else float("inf") + while time.time() < deadline: available_files = self._get_available_files() for file_path in available_files: if file_path not in self._read_files: diff --git a/lib/marin/src/marin/rl/weight_transfer/arrow_flight.py b/lib/marin/src/marin/rl/weight_transfer/arrow_flight.py index 297c16d621..c2e9e312d1 100644 --- a/lib/marin/src/marin/rl/weight_transfer/arrow_flight.py +++ b/lib/marin/src/marin/rl/weight_transfer/arrow_flight.py @@ -264,9 +264,8 @@ def update_server(self, weight_id: int, param_names: list[str], server_locations param_names=param_names, ) logger.info(f"Updated server: weight_id={weight_id}, params={len(param_names)}, servers={len(server_locations)}") - return 123 - def fetch_server(self) -> ServerInfo: + def fetch_server(self) -> ServerInfo | None: return self._server_info @@ -292,6 +291,8 @@ def do_put(self, context, descriptor, reader, writer): def do_get(self, context, ticket): """Serve weight data to inference workers.""" + # Propagate typed FlightUnavailableError without rewrapping — callers + # distinguish "no weights yet" (retry) from internal server errors. try: ticket_data = ticket.ticket.decode("utf-8") @@ -305,12 +306,17 @@ def do_get(self, context, ticket): with self._lock: if weight_id != self._latest_weight_id: logger.debug(f"Requested weight_id {weight_id} stale, returning {self._latest_weight_id}") + if self._latest_weight_id is None: + raise flight.FlightUnavailableError("No weights available yet") weight_id = self._latest_weight_id (schema, batches) = self._weights_store[weight_id][param_name] return flight.RecordBatchStream(pa.RecordBatchReader.from_batches(schema, batches)) + except flight.FlightUnavailableError: + # Typed "retry me" signal; do not rewrap as Internal. + raise except Exception as e: logger.error(f"Error in do_get: {e}") raise flight.FlightInternalError(f"Failed to get weights: {e}") from e diff --git a/lib/marin/src/marin/transform/conversation/adapters.py b/lib/marin/src/marin/transform/conversation/adapters.py index 3ab1af7304..233a1a80ed 100644 --- a/lib/marin/src/marin/transform/conversation/adapters.py +++ b/lib/marin/src/marin/transform/conversation/adapters.py @@ -49,10 +49,10 @@ class InputDatasetFormat(str, Enum): | what's the car's speed ?" }] | time taken. Answer is 375/3 = 125 kmph" | """ - SINGLE_COLUMN_MULTI_TURN: str = "messages" - INSTRUCTION_RESPONSE: str = "instruction_response" - INSTRUCT_COLUMN_RESPONSE: str = "instruct_column_response" - INSTRUCT_MSG_RESPONSE: str = "instruct_msg_response" + SINGLE_COLUMN_MULTI_TURN = "messages" + INSTRUCTION_RESPONSE = "instruction_response" + INSTRUCT_COLUMN_RESPONSE = "instruct_column_response" + INSTRUCT_MSG_RESPONSE = "instruct_msg_response" @dataclass @@ -91,7 +91,7 @@ class TransformAdapter: def transform_conversation_to_openai_format( self, row: dict[str, Any], - ) -> list[OpenAIChatMessage]: + ) -> list[OpenAIChatMessage] | None: if self.dataset_format == InputDatasetFormat.INSTRUCTION_RESPONSE: messages = [] instruction = row[self.instruction_column] @@ -107,6 +107,8 @@ def transform_conversation_to_openai_format( if completion[self.filter_on_key] > best_metric: best_metric = completion[self.filter_on_key] best_completion = completion + if best_completion is None: + return None response = best_completion[self.content_key] messages.append(OpenAIChatMessage(role="user", content=instruction)) messages.append(OpenAIChatMessage(role="assistant", content=response)) diff --git a/lib/marin/src/marin/transform/conversation/transform_conversation.py b/lib/marin/src/marin/transform/conversation/transform_conversation.py index 914cd954b1..7e52565960 100644 --- a/lib/marin/src/marin/transform/conversation/transform_conversation.py +++ b/lib/marin/src/marin/transform/conversation/transform_conversation.py @@ -120,7 +120,7 @@ def _normalize_tool_structures(message: dict) -> dict: def transform_row(row: dict, cfg: TransformSFTDatasetConfig, adapter: TransformAdapter): source = unwrap_versioned_value(cfg.source) - transformed_row_messages: list[OpenAIChatMessage] = adapter.transform_conversation_to_openai_format(row) + transformed_row_messages: list[OpenAIChatMessage] | None = adapter.transform_conversation_to_openai_format(row) if transformed_row_messages is None: logger.warning(f"{source} returning no valid messages") diff --git a/lib/marin/src/marin/transform/fasttext/transform.py b/lib/marin/src/marin/transform/fasttext/transform.py index d99d554fd3..1d577376f4 100644 --- a/lib/marin/src/marin/transform/fasttext/transform.py +++ b/lib/marin/src/marin/transform/fasttext/transform.py @@ -80,6 +80,8 @@ def convert_fasttext_to_dolma_format(input_path: str, output_path: str, source: output_jsonl_gz.write(f"{json.dumps(doc)}\n") + return True + @dataclass class TransformFasttextToDolmaConfig: diff --git a/lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py b/lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py index 9f80ff389e..0181554d60 100644 --- a/lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py +++ b/lib/marin/src/marin/transform/wikipedia/transform_wikipedia.py @@ -177,7 +177,9 @@ def unwrap_eqn(html: str): return str(html) -def postprocess_content(content: str, digit_threshold: int, word_threshold: int, special_char_threshold: float) -> str: +def postprocess_content( + content: str, digit_threshold: int, word_threshold: int, special_char_threshold: float +) -> str | None: """ Postprocesses the content by deleting it if its is mainly digits, words, and special characters. """ diff --git a/lib/marin/src/marin/utilities/json_encoder.py b/lib/marin/src/marin/utilities/json_encoder.py index 7a96e63406..88a04f28e3 100644 --- a/lib/marin/src/marin/utilities/json_encoder.py +++ b/lib/marin/src/marin/utilities/json_encoder.py @@ -15,7 +15,8 @@ class CustomJsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, o): + obj = o if isinstance(obj, timedelta): return {"days": obj.days, "seconds": obj.seconds, "microseconds": obj.microseconds} if isinstance(obj, Path): diff --git a/lib/zephyr/src/zephyr/plan.py b/lib/zephyr/src/zephyr/plan.py index c52acb248f..d92db2ce83 100644 --- a/lib/zephyr/src/zephyr/plan.py +++ b/lib/zephyr/src/zephyr/plan.py @@ -776,6 +776,9 @@ def run_stage( if isinstance(op, Map): if op.needs_shard_context: + # Map.fn is a compose_map pipeline that takes shard_idx/total_shards as kwargs. + # (The ShardInfo dataclass is only used internally by compose_map to hand off + # to user-provided MapShardOp.fn at the logical-op boundary — see line 237.) stream = op.fn(stream, shard_idx=ctx.shard_idx, total_shards=ctx.total_shards) else: stream = op.fn(stream) diff --git a/pyproject.toml b/pyproject.toml index ffb64d4be3..4375e05721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,9 +163,6 @@ missing-attribute = false # Additional error types with >10 occurrences that are hard to fix systematically: -# Deprecated APIs (19 occurrences) - would require code changes across multiple files -deprecated = false - # Unknown name errors (17 occurrences) - often from dynamic code or missing stubs unknown-name = false