lib/levanter/src/levanter/optim/util.py (23 changes: 0 additions & 23 deletions)

@@ -19,7 +19,6 @@
 
 from levanter.models.linear import has_linear_like_marker
 import levanter.tracker
-from levanter.utils.jax_utils import is_inexact_arrayish
 
 
 T = TypeVar("T")
@@ -43,28 +42,6 @@ def label_linear_like_module(module: Any, *, weight_label: str, bias_label: str)
     return dataclasses.replace(module, weight=weight_label, bias=masked_bias)
 
 
-def hvp(f, x, v):
-    """Compute the Hessian-vector product of a function."""
-    return eqx.filter_jvp(eqx.filter_grad(f), (x,), (v,))[1]
-    # grad_f = eqx.filter_grad(f)
-    # _, vjp_fn = eqx.filter_vjp(grad_f, x)
-    # return vjp_fn(v)[0]
-
-
-def tree_gaussian_like(key, tree):
-    """
-    Samples a tree of gaussian noise with the same structure as `tree`, except for leaves which are not inexact arrays,
-    for which it returns None
-    """
-    leaves, structure = jax.tree_util.tree_flatten(tree)
-    keys = jax.random.split(key, len(leaves))
-    rand_n = lambda x, key: jax.random.normal(key, x.shape) if is_inexact_arrayish(x) else None
-    g = jax.tree_util.tree_map(rand_n, leaves, list(keys))
-    g = jax.tree_util.tree_unflatten(structure, g)
-
-    return g
-
-
 def log_norm_passthrough(desc: str) -> GradientTransformation:
     """
     Creates a gradient transformation that logs the L2 norm of the updates
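For reference, the deleted hvp computed a forward-over-reverse Hessian-vector product via Equinox's filtered transforms. A minimal sketch of the same identity in plain JAX (a standalone illustration, not code from this PR):

import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Forward-over-reverse: the JVP of grad(f) at x along v equals H(x) @ v,
    # without ever materializing the Hessian.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

f = lambda x: jnp.sum(x ** 3)   # grad = 3x^2, Hessian = diag(6x)
x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
print(hvp(f, x, v))             # [6. 0.], i.e. diag(6x) @ v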
lib/levanter/src/levanter/shapes.py (17 changes: 0 additions & 17 deletions)

@@ -5,12 +5,9 @@
 from math import prod
 from typing import Optional, Tuple, Type, TypeAlias, Union
 
-import jax
 import numpy as np
 from haliax import Axis
-from haliax.util import is_named_array
 from jax import ShapeDtypeStruct
-from jaxtyping import PyTree
 
 DType = Union[np.dtype, Type[int], Type[float], Type[bool]]
 
@@ -36,17 +33,3 @@ def to_raw_shape(shape: Union[ShapeSpec, NamedShapeSpec]) -> Optional[Tuple[int,
     if raw is None:
         return None
     return tuple(ax.size for ax in raw)
-
-
-def conforms(shape: PyTree[Union[ShapeSpec, NamedShapeSpec]], tree: PyTree) -> bool:
-    """Check if a tree conforms to a shape specification."""
-
-    def _leaf_conforms(shape_spec: Union[ShapeSpec, NamedShapeSpec], leaf):
-        if isinstance(shape_spec, ShapeSpec):  # type: ignore
-            return shape_spec.shape == leaf.shape and shape_spec.dtype == leaf.dtype
-        else:
-            return (shape_spec.shape is None or shape_spec.shape == leaf.axes) and (
-                shape_spec.dtype is None or shape_spec.dtype == leaf.dtype
-            )
-
-    return jax.tree_util.tree_all(jax.tree_util.tree_map(_leaf_conforms, shape, tree, is_leaf=is_named_array))
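The removed conforms walked a PyTree of shape specs against a PyTree of arrays and required every leaf to match. A minimal sketch of the leaf-wise check it performed (hypothetical spec_matches helper, not this module's API), where None acts as a wildcard as in NamedShapeSpec:

import numpy as np

def spec_matches(shape, dtype, leaf) -> bool:
    # None in the spec means "accept anything", mirroring the Optional
    # fields on NamedShapeSpec.
    return (shape is None or shape == leaf.shape) and (dtype is None or dtype == leaf.dtype)

print(spec_matches((2, 3), np.float32, np.zeros((2, 3), np.float32)))  # True
print(spec_matches((2, 3), None, np.zeros((4,), np.float32)))          # False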
lib/levanter/src/levanter/utils/stat_utils.py (1 change: 0 additions & 1 deletion)

@@ -27,7 +27,6 @@ def add(self, x: Arrayish, total: Arrayish) -> "RunningMean":
         new_total = self.total + total
         ratio = hax.where(new_total, total / new_total, 0.0)
         new_mean = self.mean + delta * ratio
-        new_total = self.total + total
         return RunningMean(new_mean, new_total)
 
     def __add__(self, other: "RunningMean"):
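The deleted line recomputed new_total with the identical value, so the fix is pure cleanup; the surviving update is the standard incremental mean. A quick plain-Python check of the identity (hypothetical numbers, with delta = x - mean as in the surrounding method):

mean, total = 1.0, 2.0   # running mean 1.0 over 2 samples
x, n = 4.0, 2.0          # new batch: mean 4.0 over 2 samples
new_total = total + n
new_mean = mean + (x - mean) * (n / new_total)
assert new_mean == (1.0 * 2 + 4.0 * 2) / 4   # 2.5, the pooled mean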
@@ -138,8 +138,7 @@ def transform_row(row: dict, cfg: TransformSFTDatasetConfig, adapter: TransformA
     extra_columns: dict[str, object] = {}
     for source_column, target_column in metadata_remap.items():
         if target_column in _RESERVED_TOP_LEVEL_FIELDS:
-            logging.log(
-                logging.WARNING,
+            logger.warning(
                 f"Skipping remap for column '{source_column}' because target '{target_column}' is reserved.",
             )
             continue
@@ -151,9 +150,7 @@ def transform_row(row: dict, cfg: TransformSFTDatasetConfig, adapter: TransformA
             content = message.get("content")
             if isinstance(content, str):
                 message["content"] = _apply_replacements(content, replacements)
-        transformed_row_messages = [_normalize_tool_structures(message) for message in transformed_row_messages]
-    else:
-        transformed_row_messages = [_normalize_tool_structures(message) for message in transformed_row_messages]
+    transformed_row_messages = [_normalize_tool_structures(message) for message in transformed_row_messages]
     if adapter.extra_metadata_fn:
         extra_from_fn = adapter.extra_metadata_fn(row)
         if extra_from_fn:
@@ -203,7 +200,7 @@ def _get_available_subsets(cfg: TransformSFTDatasetConfig) -> Sequence[str | Non
     try:
         subsets = datasets.get_dataset_config_names(cfg.source)
     except Exception as exc:
-        logging.log(logging.WARNING, f"Unable to fetch dataset configs for {cfg.source}: {exc}")
+        logger.warning(f"Unable to fetch dataset configs for {cfg.source}: {exc}")
         subsets = []
     if not subsets:
         return [None]
@@ -217,7 +214,7 @@ def _get_available_splits(cfg: TransformSFTDatasetConfig, subset: str | None) ->
     try:
         split_names = datasets.get_dataset_split_names(cfg.source, name=subset)
     except Exception as exc:
-        logging.log(logging.WARNING, f"Unable to fetch splits for {cfg.source} (subset={subset}): {exc}")
+        logger.warning(f"Unable to fetch splits for {cfg.source} (subset={subset}): {exc}")
        split_names = ["train"]
     if not split_names:
         return ["train"]
@@ -260,10 +257,10 @@ def get_dataset_tasks(cfg: TransformSFTDatasetConfig):
             requested = set(configured_splits)
             missing = sorted(requested - set(splits))
             if missing:
-                logging.log(logging.WARNING, f"Requested split(s) {missing} for {source} skipped.")
+                logger.warning(f"Requested split(s) {missing} for {source} skipped.")
             splits = [split for split in splits if split in requested]
         if not splits:
-            logging.log(logging.WARNING, f"No splits to process for subset={subset}; skipping.")
+            logger.warning(f"No splits to process for subset={subset}; skipping.")
            continue
 
         # 3. For each split, enumerate shards
@@ -283,7 +280,6 @@ def get_dataset_tasks(cfg: TransformSFTDatasetConfig):
 
             dataset = load_dataset_with_backoff(
                 context=f"{source} subset={subset_name} split={split}",
-                logger=logger,
                 **dataset_kwargs,
             )
             num_shards = dataset.num_shards
@@ -319,7 +315,7 @@ def process_shard_task(task: ShardTask) -> dict:
     # If output already exists, skip the work to let Zephyr resume cleanly without sentinels.
     fs, _ = url_to_fs(output_filename)
     if fs.exists(output_filename):
-        logging.info(
+        logger.info(
             f"Skipping subset={subset_name} split={task.split} shard={task.shard_idx} "
             f"because output exists: {output_filename}"
         )
@@ -343,7 +339,6 @@ def process_shard_task(task: ShardTask) -> dict:
 
     dataset = load_dataset_with_backoff(
         context=f"{task.source} subset={subset_name} split={task.split} shard={task.shard_idx}",
-        logger=logger,
        **dataset_kwargs,
    )
     shard_dataset = dataset.shard(num_shards=task.num_shards, index=task.shard_idx)
@@ -357,7 +352,7 @@ def transform_records():
 
     result = write_jsonl_file(transform_records(), output_filename)
 
-    logging.info(
+    logger.info(
         f"Wrote {result['count']} rows to {result['path']} "
         f"for subset={subset_name} split={task.split} shard={task.shard_idx}"
     )
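The recurring change in this file swaps root-logger calls (logging.log(logging.WARNING, ...), logging.info(...)) for the module-level logger, so records carry the module's name and respect per-module log configuration. A minimal sketch of the pattern, assuming the module defines its logger the usual way:

import logging

logger = logging.getLogger(__name__)   # named after the module, configurable per module

logger.warning("Unable to fetch dataset configs for %s: %s", "some/dataset", "timeout")
# vs. logging.log(logging.WARNING, ...) or logging.info(...), which route
# through the root logger and lose the originating module's name.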
lib/zephyr/src/zephyr/execution.py (13 changes: 1 addition & 12 deletions)

@@ -52,6 +52,7 @@
     StageType,
     compute_plan,
 )
+from zephyr.shuffle import ListShard, MemChunk, _write_scatter
 from zephyr.writers import INTERMEDIATE_CHUNK_SIZE, ensure_parent_dir
 
 logger = logging.getLogger(__name__)
@@ -113,18 +114,6 @@ def read(self) -> list:
         return pickle.load(f)
 
 
-# ---------------------------------------------------------------------------
-# Scatter Parquet support (imported from shuffle.py)
-# ---------------------------------------------------------------------------
-
-from zephyr.shuffle import (  # noqa: E402
-    ListShard,
-    MemChunk,
-    ScatterReader,  # noqa: F401 — re-exported for plan.py and external callers
-    ScatterWriter,  # noqa: F401 — re-exported for external callers
-    _write_scatter,
-)
-
 # ---------------------------------------------------------------------------
 # Task result
 # ---------------------------------------------------------------------------
lib/zephyr/src/zephyr/plan.py (2 changes: 1 addition & 1 deletion)

@@ -838,7 +838,7 @@ def run_stage(
         elif isinstance(op, Reduce):
             # Build ScatterReader directly from per-mapper sidecars, then
             # merge sorted chunks and reduce per key.
-            from zephyr.execution import ScatterReader
+            from zephyr.shuffle import ScatterReader
 
             shard = ctx.shard
             if not isinstance(shard, ScatterReader):
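Taken together, the execution.py and plan.py changes make zephyr.shuffle the single defining module for the scatter types: the mid-file import block and its ScatterReader/ScatterWriter re-exports are gone, and callers now import from the source. The resulting import layout, as implied by the diff:

# zephyr/execution.py: one top-level import, only the names it actually uses
from zephyr.shuffle import ListShard, MemChunk, _write_scatter

# zephyr/plan.py: consumers import ScatterReader from its defining module
from zephyr.shuffle import ScatterReader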