diff --git a/pyproject.toml b/pyproject.toml index 60bab574..88c47694 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin readme = "README.md" requires-python = ">=3.10" dependencies = [ - "torch==2.5.1", + "torch==2.6.0", "numpy", "setuptools", "transformers>=4.44.2", diff --git a/src/zeroband/checkpoint.py b/src/zeroband/checkpoint.py index a5b6a391..18047108 100644 --- a/src/zeroband/checkpoint.py +++ b/src/zeroband/checkpoint.py @@ -152,6 +152,7 @@ def non_error_barrier(): dist.barrier() except Exception as e: from zeroband.utils.logging import get_logger + get_logger().info(f"Error in data checkpointing barrier: {e}, continuing training") @@ -174,8 +175,8 @@ def __init__( self, config: CkptConfig, model: nn.Module, - optimizer: Optimizer, - scheduler: LambdaLR, + optimizer: list[Optimizer], + scheduler: list[LambdaLR], dataloader: StatefulDataLoader, training_progress: TrainingProgress, data_rank: int | None, diff --git a/src/zeroband/config.py b/src/zeroband/config.py index ee2d615e..38da2a2c 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -22,8 +22,11 @@ class DataConfig(BaseConfig): reverse_data_files: bool = False split_by_data_rank: bool = True + class AdamConfig(BaseConfig): - type: Literal["adam"] = "adam" # the literal is used to distinguish between the different optimizers configuration in the union type + type: Literal["adam"] = ( + "adam" # the literal is used to distinguish between the different optimizers configuration in the union type + ) lr: float = 4e-4 weight_decay: float = 0.1 betas1: float = 0.9 @@ -41,11 +44,34 @@ class SoapConfig(BaseConfig): precondition_frequency: int = 100 -OptimizersConfig: TypeAlias = AdamConfig | SoapConfig +class MuonConfig(BaseConfig): + type: Literal["muon"] = "muon" + ns_steps: int = 5 + lr: float = 0.02 + momentum: float = 0.95 + nesterov: bool = True + compression_ratio: float | None = None + compression_step_start: int = 0 + lie_compression: bool = False + + @model_validator(mode="after") + def calidate_compression(self): + if self.compression_ratio is not None: + assert 0 < self.compression_ratio <= 1, "compression_ratio must be between 0 and 1" + return self + + +OptimizersConfig: TypeAlias = AdamConfig | SoapConfig | MuonConfig + + +class PowerSGDConfig(BaseConfig): + rank: int = 1 + warmup_steps: int = 1000 class OptimConfig(BaseConfig): optim: OptimizersConfig = AdamConfig() + power_sgd: PowerSGDConfig | None = None lr: float = 4e-4 weight_decay: float = 0.1 @@ -70,6 +96,7 @@ class DilocoConfig(BaseConfig): retry_all_reduce: int = 3 + class MemoryProfilerConfig(BaseConfig): freq: int = 10 snapshot_dir: str @@ -231,7 +258,8 @@ def get_env_config(config: Config | None, item: str | None, default: Any | None return cfg -def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool: + +def get_env_config_bool(config: Config | None, item: str | None, default: bool | None = None) -> bool: """ Call get_env_config and convert strings to bools where makes sense. @@ -248,4 +276,3 @@ def get_env_config_bool(config: Config | None, item: str | None, default: bool if isinstance(val, str): return val.lower() == "true" or val.lower() == "1" return bool(val) - diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index cb767790..3a3b96a7 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -23,7 +23,8 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE from torch.nn.attention import SDPBackend, sdpa_kernel -_flex_attention_compiled = torch.compile(flex_attention, dynamic=False) +_flex_attention_compiled = torch.compile(flex_attention, dynamic=True) +# _flex_attention_compiled = flex_attention # copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27 diff --git a/src/zeroband/muon.py b/src/zeroband/muon.py new file mode 100644 index 00000000..e13b5f28 --- /dev/null +++ b/src/zeroband/muon.py @@ -0,0 +1,273 @@ +# copied from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py +import torch +from torch import Tensor +import torch.distributed as dist + + +@torch.compile +def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert ( + G.ndim >= 2 + ) # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.mT + B = ( + b * A + c * A @ A + ) # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(-2) > G.size(-1): + X = X.mT + return X + + +@torch.compile +def low_rank_approximation_zeropower_via_newtonschulz5(G: Tensor, rank: int, steps: int = 5) -> tuple[Tensor, Tensor]: + """ + Compute a low-rank approximation of matrix G using Newton-Schulz iteration. + Returns the expanded approximated matrix directly. + + Args: + G: Input tensor of shape (..., m, n) + rank: Target rank for the approximation + steps: Number of Newton-Schulz iterations + + Returns: + G_approx: Low rank approximation of G with the same shape as G + """ + assert G.ndim >= 2 + assert rank > 0 and rank <= min(G.size(-2), G.size(-1)) + + # Constants for quintic iteration + a, b, c = (3.4445, -4.7750, 2.0315) + + # Convert to bfloat16 + G = G.bfloat16() + + # Initialize random projection matrix Q in bfloat16 + n = G.size(-1) + Q = torch.randn((*G.shape[:-2], n, rank), device=G.device).bfloat16() + Q = Q / (Q.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Power iteration to find approximate range + Y = G @ Q + + # Normalize Y + Y = Y / (Y.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Newton-Schulz iterations for orthogonalization + for _ in range(steps): + A = Y @ Y.mT + B = b * A + c * A @ A + Y = a * Y + B @ Y + + # Compute factors and immediately expand + U = Y + V = G.mT @ U + G_approx = U @ V.mT + + return G_approx + + +@torch.compile +def low_rank_approximation_via_newtonschulz_lie(G: Tensor, rank: int, steps: int = 5) -> Tensor: + """ + Compute a low-rank approximation of matrix G using Newton-Schulz iteration with Lie group structure. + Returns the approximated matrix in the form Q = (I + UVᵀ)diag(d). + + Args: + G: Input tensor of shape (..., m, n) + rank: Target rank for the approximation + steps: Number of Newton-Schulz iterations + + Returns: + G_approx: Low rank approximation of G with the same shape as G + """ + assert G.ndim >= 2 + assert rank > 0 and rank <= min(G.size(-2), G.size(-1)) + + # Constants for quintic iteration + a, b, c = (3.4445, -4.7750, 2.0315) + + # Convert to bfloat16 + G = G.bfloat16() + m, n = G.size(-2), G.size(-1) + + # Initialize random projection matrix Q + Q = torch.randn((*G.shape[:-2], n, rank), device=G.device).bfloat16() + Q = Q / (Q.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Power iteration to find approximate range + Y = G @ Q + Y = Y / (Y.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Newton-Schulz iterations for orthogonalization + for _ in range(steps): + A = Y @ Y.mT + B = b * A + c * A @ A + Y = a * Y + B @ Y + + # Compute U factor + U = Y + + # Compute V factor through projection + V = G.mT @ U + + # Normalize U and V to have unit norm + U_norms = torch.sum(U * U, dim=-1, keepdim=True).sqrt() + V_norms = torch.sum(V * V, dim=-1, keepdim=True).sqrt() + + U = U / (U_norms + 1e-7) + V = V / (V_norms + 1e-7) + + # Create identity matrix of appropriate size + Id = torch.eye(m, device=G.device, dtype=G.dtype) + Id = Id.expand(*G.shape[:-2], m, m) + + # Compute diagonal scaling + d = torch.diagonal(G @ G.mT, dim1=-2, dim2=-1).sqrt() + d = d / (d.norm(dim=-1, keepdim=True) + 1e-7) + + # Construct final approximation Q = (I + UVᵀ)diag(d) + G_approx = U @ V.mT + + # Scale the approximation to match G's magnitude + G_norms = torch.sum(G * G, dim=(-2, -1), keepdim=True).sqrt() + G_approx_norms = torch.sum(G_approx * G_approx, dim=(-2, -1), keepdim=True).sqrt() + scale = G_norms / (G_approx_norms + 1e-7) + G_approx = G_approx * scale + + return G_approx + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - This optimizer assumes that all parameters passed in are 2D. + - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D + parameters; those should all be optimized by a standard method (e.g., AdamW). + - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions. + - We believe it is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven"t tested this. + - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M). + + Arguments: + lr: The learning rate used by the internal SGD. + momentum: The momentum used by the internal SGD. + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iteration steps to use. + """ + + def __init__( + self, + params, + lr=0.02, + momentum=0.95, + nesterov=True, + ns_steps=5, + rank=0, + world_size=1, + compression_ratio: float | None = None, + compression_step_start: int = 0, + lie_compression: bool = False, + ): + self.rank = rank + self.world_size = world_size + self.compression_ratio = compression_ratio + self.compression_step_start = compression_step_start + self.lie_compression = lie_compression + self._step_count = 0 # Add step counter + + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps) + params: list[Tensor] = [*params] + assert all(isinstance(p, Tensor) for p in params) + sizes = {p.numel() for p in params} + + def create_update_buffer(size: int): + b = torch.empty(self.world_size, size, dtype=torch.bfloat16, device="cuda") + return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(self.world_size)]) + + param_groups = [ + dict(params=[p for p in params if p.numel() == size], **create_update_buffer(size)) for size in sizes + ] + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + self._step_count += 1 # Increment step counter + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + nesterov = group["nesterov"] + ns_steps = group["ns_steps"] + update_buffer = group["update_buffer"] + update_buffer_views: list[Tensor] = group["update_buffer_views"] + params: list[Tensor] = group["params"] + handle = None + params_world = None + + def update_prev(): + if params_world is None: + return + assert handle is not None + handle.wait() + for p_world, g_world in zip(params_world, update_buffer_views): + p_world.add_( + g_world.view_as(p_world), + alpha=-lr * max(1, p_world.size(-2) / p_world.size(-1)) ** 0.5, + ) + + for base_i in range(len(params))[:: self.world_size]: + if base_i + self.rank < len(params): + p = params[base_i + self.rank] + g = p.grad + assert g is not None + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf: Tensor = state["momentum_buffer"] + buf.lerp_(g, 1 - momentum) + g = g.lerp_(buf, momentum) if nesterov else buf + + # Only apply compression if we've reached the start step and compression ratio is set + if self.compression_ratio is not None and self._step_count >= self.compression_step_start: + mat_rank = int(g.shape[0] * self.compression_ratio) + if self.lie_compression: + g = low_rank_approximation_via_newtonschulz_lie(g, mat_rank, steps=ns_steps).flatten() + else: + g = low_rank_approximation_zeropower_via_newtonschulz5( + g, mat_rank, steps=ns_steps + ).flatten() + else: + g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten() + else: + g = update_buffer_views[self.rank] + update_prev() + handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True) + params_world = params[base_i : base_i + self.world_size] + update_prev() diff --git a/src/zeroband/optimizers.py b/src/zeroband/optimizers.py index 6faf5c01..a6dd702d 100644 --- a/src/zeroband/optimizers.py +++ b/src/zeroband/optimizers.py @@ -1,37 +1,71 @@ -from typing import Iterable import torch from distributed_shampoo import ( DefaultEigenvalueCorrectedShampooConfig, DistributedShampoo, - FullyShardShampooConfig, + DDPShampooConfig, ShampooPT2CompileConfig, ) -from zeroband.config import AdamConfig, SoapConfig, OptimizersConfig +from zeroband.models.llama.model import Transformer +from zeroband.muon import Muon +from zeroband.config import AdamConfig, SoapConfig, OptimizersConfig, MuonConfig +from zeroband.utils.world_info import get_world_info -def get_optimizer(params: Iterable[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer: + +def get_optimizer(model: Transformer, config: OptimizersConfig) -> list[torch.optim.Optimizer]: if isinstance(config, AdamConfig): - return torch.optim.AdamW( - params, - lr=config.lr, - weight_decay=config.weight_decay, - betas=(config.betas1, config.betas2), - ) + return [ + torch.optim.AdamW( + model.parameters(), + lr=config.lr, + weight_decay=config.weight_decay, + betas=(config.betas1, config.betas2), + ) + ] elif isinstance(config, SoapConfig): - return DistributedShampoo( - params, + return [ + DistributedShampoo( + model.parameters(), + lr=config.lr, + betas=(config.betas1, config.betas2), + epsilon=1e-12, + weight_decay=config.weight_decay, + max_preconditioner_dim=config.max_preconditioner_dim, + precondition_frequency=config.precondition_frequency, + use_decoupled_weight_decay=True, + # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is + # less expensive and might thereby allow for a smaller `precondition_frequency`. + preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, + distributed_config=DDPShampooConfig(), + shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False), + ) + ] + elif isinstance(config, MuonConfig): + world_info = get_world_info() + hidden_matrix_params = [p for n, p in model.layers.named_parameters() if p.ndim >= 2 and "embed" not in n] + embed_params = [p for n, p in model.named_parameters() if "embed" in n] + scalar_params = [p for p in model.parameters() if p.ndim < 2] + head_params = [model.output.weight] + + # init the optimizer(s) + adam_params = [ + dict(params=head_params, lr=0.008), + dict(params=embed_params, lr=0.6), + dict(params=scalar_params, lr=0.04), + ] + # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence + # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 + optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True) + optimizer2 = Muon( + hidden_matrix_params, lr=config.lr, - betas=(config.betas1, config.betas2), - epsilon=1e-12, - weight_decay=config.weight_decay, - max_preconditioner_dim=config.max_preconditioner_dim, - precondition_frequency=config.precondition_frequency, - use_decoupled_weight_decay=True, - # This can also be set to `DefaultSOAPConfig` which uses QR decompositions, hence is - # less expensive and might thereby allow for a smaller `precondition_frequency`. - preconditioner_config=DefaultEigenvalueCorrectedShampooConfig, - distributed_config=FullyShardShampooConfig(), - shampoo_pt2_compile_config=ShampooPT2CompileConfig(enable_shampoo_pt2_dynamic_shape=False), + momentum=config.momentum, + rank=world_info.rank, + world_size=world_info.world_size, + compression_ratio=config.compression_ratio, + compression_step_start=config.compression_step_start, + lie_compression=config.lie_compression, ) + return [optimizer2, optimizer1] else: raise ValueError(f"Unknown optimizer {config.optimizer}") diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 32f9cdd0..ef434366 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -1,14 +1,18 @@ import os import time from typing import TYPE_CHECKING -from multiprocessing.process import _children # type: ignore +from multiprocessing.process import _children # type: ignore import torch import torch.distributed as dist -from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy # type: ignore + +# from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy # type: ignore +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook + from torch.autograd.profiler import record_function -from zeroband.checkpoint import CkptManager, TrainingProgress +from zeroband.checkpoint import TrainingProgress from zeroband.comms import ElasticDeviceMesh from zeroband.config import Config from zeroband.data import TEST_VOCAB_SIZE, get_dataloader @@ -26,7 +30,7 @@ get_tensor_list_signature, get_peak_flops, get_num_params, - get_num_flop_per_token + get_num_flop_per_token, ) from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger from zeroband.utils.monitor import HttpMonitor @@ -65,7 +69,7 @@ def log_hash_training_state( if config.diloco is not None and diloco is not None: outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer) - outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) # type: ignore + outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu) # type: ignore logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}") logger.debug(f"outer diloco model hash {id} : {outer_model_hash}") @@ -148,62 +152,78 @@ def train(config: Config): enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src ) - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None + # mp_policy = MixedPrecisionPolicy( + # param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None + # ) + + # for layer_id, transformer_block in model.layers.items(): + # if config.train.reshard_after_forward: + # reshard_after_forward = int(layer_id) < len(model.layers) - 1 + # else: + # reshard_after_forward = False + # fully_shard( + # transformer_block, + # mp_policy=mp_policy, + # mesh=elastic_device_mesh.cuda_local_mesh, + # reshard_after_forward=reshard_after_forward, + # ) + # fully_shard( + # model, + # mp_policy=mp_policy, + # mesh=elastic_device_mesh.cuda_local_mesh, + # reshard_after_forward=config.train.reshard_after_forward, + # ) + model: DDP = DDP( + model, device_ids=[world_info.local_rank], broadcast_buffers=False, gradient_as_bucket_view=True ) - for layer_id, transformer_block in model.layers.items(): - if config.train.reshard_after_forward: - reshard_after_forward = int(layer_id) < len(model.layers) - 1 - else: - reshard_after_forward = False - fully_shard( - transformer_block, - mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, - reshard_after_forward=reshard_after_forward, + if config.optim.power_sgd is not None: + state = powerSGD_hook.PowerSGDState( + process_group=None, # Default process group + matrix_approximation_rank=config.optim.power_sgd.rank, # Adjust rank based on compression needs + start_powerSGD_iter=config.optim.power_sgd.warmup_steps, # When to start compression ) - fully_shard( - model, - mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, - reshard_after_forward=config.train.reshard_after_forward, - ) - logger.debug("model fsdped") + + model.register_comm_hook(state, powerSGD_hook.powerSGD_hook) + + logger.debug("model ddped") # Setup optimizers with record_function("Set up Optimizers"): - inner_optimizer = get_optimizer(model.parameters(), config.optim.optim) + inner_optimizers = get_optimizer(model.module, config.optim.optim) diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None - scheduler = get_scheduler( - sched_type=config.optim.sched_type, - optimizer=inner_optimizer, - num_warmup_steps=config.optim.warmup_steps, - num_stable_steps=config.optim.stable_steps, - num_training_steps=config.optim.total_steps, - ) + schedulers = [ + get_scheduler( + sched_type=config.optim.sched_type, + optimizer=opt, + num_warmup_steps=config.optim.warmup_steps, + num_stable_steps=config.optim.stable_steps, + num_training_steps=config.optim.total_steps, + ) + for opt in inner_optimizers + ] training_progress = TrainingProgress(total_tokens=0, outer_step=0, step=0) - ckpt_manager = CkptManager( - config=config.ckpt, - model=model, - optimizer=inner_optimizer, - scheduler=scheduler, - dataloader=train_dataloader, - training_progress=training_progress, - data_rank=config.data.data_rank, - diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, # type: ignore - diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore - ) + # ckpt_manager = CkptManager( + # config=config.ckpt, + # model=model, + # optimizer=inner_optimizers, + # scheduler=schedulers, + # dataloader=train_dataloader, + # training_progress=training_progress, + # data_rank=config.data.data_rank, + # diloco_offloaded_optimizer=diloco.outer_optimizer if config.diloco is not None else None, # type: ignore + # diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None, # type: ignore + # ) if world_info.rank == 0: logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger metric_logger = logger_cls( project=config.project, - config={"config": config.model_dump(), "world_info": world_info.json()}, + logger_config={"config": config.model_dump(), "world_info": world_info.json()}, resume=config.wandb_resume, ) else: @@ -215,17 +235,17 @@ def train(config: Config): model = torch.compile(model) if not TYPE_CHECKING else model logger.debug("model compiled") - with record_function("Resume checkpoint"): - if config.ckpt.resume is not None: - # all is inplace - ckpt_manager.load( - resume_ckpt_path=config.ckpt.resume, - skip_dataloader=config.ckpt.skip_dataloader, - data_path=config.ckpt.data_path, - ) - log_hash_training_state( - config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" - ) + # with record_function("Resume checkpoint"): + # if config.ckpt.resume is not None: + # # all is inplace + # ckpt_manager.load( + # resume_ckpt_path=config.ckpt.resume, + # skip_dataloader=config.ckpt.skip_dataloader, + # data_path=config.ckpt.data_path, + # ) + # log_hash_training_state( + # config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume" + # ) if config.train.memory_profiler is not None: memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir) @@ -239,7 +259,7 @@ def train(config: Config): logger.info("starting training") - need_live_recovery = config.ckpt.live_recovery_rank_src is not None + # need_live_recovery = config.ckpt.live_recovery_rank_src is not None while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -247,45 +267,45 @@ def train(config: Config): time_start_outer = time.perf_counter() - if config.diloco is not None: - assert diloco is not None - # this is a patch for now to allow live recovery worker to not affect the all reduce at all + # if config.diloco is not None: + # assert diloco is not None + # # this is a patch for now to allow live recovery worker to not affect the all reduce at all - if not need_live_recovery: - elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True) + # if not need_live_recovery: + # elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True) - maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() - if maybe_dest_rank is not None: - logger.info(f"Start live recovery to rank {maybe_dest_rank}") - ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) + # maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() + # if maybe_dest_rank is not None: + # logger.info(f"Start live recovery to rank {maybe_dest_rank}") + # ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) - elastic_device_mesh.live_recovery.reset() - else: - ## receiving - time_start_live_recovery = time.perf_counter() - logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}") + # elastic_device_mesh.live_recovery.reset() + # else: + # ## receiving + # time_start_live_recovery = time.perf_counter() + # logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}") - ## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it + # ## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it - diloco.outer_optimizer.step() # need to step to init the DTensor stats + # diloco.outer_optimizer.step() # need to step to init the DTensor stats - ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) + # ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) - log_hash_training_state( - config, - model, - inner_optimizer, - diloco, - metric_logger, - step=training_progress.step, - id="live_reco_recv", - ) - need_live_recovery = False + # log_hash_training_state( + # config, + # model, + # inner_optimizer, + # diloco, + # metric_logger, + # step=training_progress.step, + # id="live_reco_recv", + # ) + # need_live_recovery = False - if config.ckpt.remote_data_load: - ckpt_manager.remote_data_load() + # if config.ckpt.remote_data_load: + # ckpt_manager.remote_data_load() - logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery) + # logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery) # at the beginning of the inner steps we allow joiner to arrive. # We maybe reinit before the all reduce but only to allow leaving, not to join anymore @@ -300,7 +320,8 @@ def train(config: Config): for grad_acc_step in range(gradient_accumulation_steps): is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 # no sync if we are accumulating gradients - model.set_requires_gradient_sync(not is_accumulating) + # model.set_requires_gradient_sync(not is_accumulating) + model.require_backward_grad_sync = not is_accumulating with record_function("Load batch"): # TODO/NOTE: We could overlap sending the batch with communication @@ -315,7 +336,9 @@ def train(config: Config): block_mask = None with record_function("Run model"): - logits = model(tokens=input_ids, block_mask=block_mask).contiguous() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = model(tokens=input_ids, block_mask=block_mask).contiguous() + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") flatten_labels = rearrange(labels, "b seq -> (b seq)") @@ -324,7 +347,7 @@ def train(config: Config): flatten_logits, flatten_labels, z_weight=config.optim.z_loss_weight if config.optim.z_loss else None, - num_chunks=config.optim.num_chunks + num_chunks=config.optim.num_chunks, ) del logits @@ -356,13 +379,14 @@ def train(config: Config): torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) with record_function("Optimizer step"): - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + for inner_optimizer, scheduler in zip(inner_optimizers, schedulers): + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 - inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] + inner_lr = [group["lr"] for group in inner_optimizers[0].param_groups][0] # syncing loss across all data parallel rank within a nodes @@ -431,18 +455,18 @@ def train(config: Config): training_progress.outer_step += 1 - if ( - config.ckpt.interval is not None - and training_progress.step > 0 - and training_progress.step % config.ckpt.interval == 0 - ): - # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway + # if ( + # config.ckpt.interval is not None + # and training_progress.step > 0 + # and training_progress.step % config.ckpt.interval == 0 + # ): + # # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway - do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0 - ckpt_manager.save(remote=do_remote) - log_hash_training_state( - config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save" - ) + # do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0 + # ckpt_manager.save(remote=do_remote) + # log_hash_training_state( + # config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save" + # ) if config.diloco: tokens_per_second = ( @@ -478,7 +502,7 @@ def train(config: Config): if config.monitor is not None: monitor.finish() - ckpt_manager.wait_for_blocking_job() + # ckpt_manager.wait_for_blocking_job() del elastic_device_mesh # allow to clean up for smoother tests transition @@ -491,7 +515,7 @@ def train(config: Config): if __name__ == "__main__": # Allow eager fallback during production so that that the training runs dont die # However, in development, we want to know that we broke torch compile - torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ # type: ignore + torch._dynamo.config.suppress_errors = "ZERO_BAND_DEV" not in os.environ # type: ignore torch.set_float32_matmul_precision("high") torch.manual_seed(42) @@ -514,21 +538,20 @@ def pretty_dict(d, indent=2): try: if config.train.torch_profiler and world_info.rank == 0: - # NOTE(apaz-cli): I cannot seem to get the memory profiler to work. # Running into this issue: https://github.com/pytorch/pytorch/issues/64345 # In the meantime, we can use the memory snapshotter. logger.debug("Running train() with profiler.") prof = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - #profile_memory=True, - #with_stack=True, - ) + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + # profile_memory=True, + # with_stack=True, + ) try: prof.__enter__() train(config) @@ -546,8 +569,8 @@ def pretty_dict(d, indent=2): logger.info("\n" + "*" * width + " GPU MEM " + "*" * width) logger.info(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)) - #logger.info("Exporting memory timeline.") - #prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0") + # logger.info("Exporting memory timeline.") + # prof.export_memory_timeline(f"logs/mem_timeline.html", device="cuda:0") else: train(config) except Exception as e: diff --git a/src/zeroband/utils/metric_logger.py b/src/zeroband/utils/metric_logger.py index 0a47dc3f..85847925 100644 --- a/src/zeroband/utils/metric_logger.py +++ b/src/zeroband/utils/metric_logger.py @@ -2,11 +2,9 @@ from typing import Any, Protocol import importlib.util -from zeroband.config import get_env_config - class MetricLogger(Protocol): - def __init__(self, project, config): ... + def __init__(self, project, logger_config): ... def log(self, metrics: dict[str, Any]): ... @@ -14,16 +12,16 @@ def finish(self): ... class WandbMetricLogger(MetricLogger): - def __init__(self, project, config, resume: bool): + def __init__(self, project, logger_config, resume: bool): if importlib.util.find_spec("wandb") is None: raise ImportError("wandb is not installed. Please install it to use WandbMonitor.") import wandb - run_name = get_env_config(config, "run_name") + run_name = logger_config["config"]["run_name"] wandb.init( - project=project, config=config, name=run_name, resume="auto" if resume else None + project=project, config=logger_config, name=run_name, resume="auto" if resume else None ) # make wandb reuse the same run id if possible def log(self, metrics: dict[str, Any]): @@ -38,9 +36,9 @@ def finish(self): class DummyMetricLogger(MetricLogger): - def __init__(self, project, config, *args, **kwargs): + def __init__(self, project, logger_config, *args, **kwargs): self.project = project - self.config = config + self.logger_config = logger_config open(self.project, "a").close() # Create an empty file to append to self.data = [] diff --git a/uv.lock b/uv.lock index cc6fdf7a..7a75071c 100644 --- a/uv.lock +++ b/uv.lock @@ -5,11 +5,7 @@ resolution-markers = [ "python_full_version < '3.11' and sys_platform != 'linux'", "python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform != 'linux'", - "python_full_version < '3.11' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and sys_platform == 'linux'", "python_full_version == '3.12.*' and sys_platform == 'linux'", - "python_full_version < '3.11' and sys_platform != 'linux'", - "python_full_version == '3.11.*' and sys_platform != 'linux'", "python_full_version == '3.12.*' and sys_platform != 'linux'", "python_full_version >= '3.13' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'linux'", @@ -328,7 +324,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -1084,6 +1080,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824 }, { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519 }, { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741 }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628 }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351 }, ] [[package]] @@ -1247,7 +1245,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, - { url = "https://files.pythonhosted.org/packages/e2/2a/4f27ca96232e8b5269074a72e03b4e0d43aa68c9b965058b1684d07c6ff8/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc", size = 396895858 }, ] [[package]] @@ -1257,7 +1254,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, - { url = "https://files.pythonhosted.org/packages/f3/79/8cf313ec17c58ccebc965568e5bcb265cdab0a1df99c4e674bb7a3b99bfe/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922", size = 9938035 }, ] [[package]] @@ -1267,7 +1263,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, - { url = "https://files.pythonhosted.org/packages/7c/30/8c844bfb770f045bcd8b2c83455c5afb45983e1a8abf0c4e5297b481b6a5/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec", size = 19751955 }, ] [[package]] @@ -1277,7 +1272,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, - { url = "https://files.pythonhosted.org/packages/a8/8b/450e93fab75d85a69b50ea2d5fdd4ff44541e0138db16f9cd90123ef4de4/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e", size = 878808 }, ] [[package]] @@ -1285,11 +1279,10 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -1297,12 +1290,11 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, - { url = "https://files.pythonhosted.org/packages/f6/ee/3f3f8e9874f0be5bbba8fb4b62b3de050156d159f8b6edc42d6f1074113b/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b", size = 210576476 }, ] [[package]] @@ -1312,7 +1304,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, - { url = "https://files.pythonhosted.org/packages/1c/22/2573503d0d4e45673c263a313f79410e110eb562636b0617856fdb2ff5f6/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771", size = 55799918 }, ] [[package]] @@ -1320,14 +1311,13 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, - { url = "https://files.pythonhosted.org/packages/f2/be/d435b7b020e854d5d5a682eb5de4328fd62f6182507406f2818280e206e2/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c", size = 125224015 }, ] [[package]] @@ -1335,12 +1325,20 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, - { url = "https://files.pythonhosted.org/packages/a2/e0/3155ca539760a8118ec94cc279b34293309bcd14011fc724f87f31988843/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f", size = 204684315 }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/8e/675498726c605c9441cf46653bd29cb1b8666da1fb1469ffa25f67f20c58/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8", size = 149422781 }, + { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 }, ] [[package]] @@ -1358,7 +1356,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, - { url = "https://files.pythonhosted.org/packages/81/19/0babc919031bee42620257b9a911c528f05fb2688520dcd9ca59159ffea8/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1", size = 95336325 }, ] [[package]] @@ -1368,7 +1365,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, - { url = "https://files.pythonhosted.org/packages/54/1b/f77674fbb73af98843be25803bbd3b9a4f0a96c75b8d33a2854a5c7d2d77/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485", size = 66307 }, ] [[package]] @@ -1477,7 +1473,7 @@ name = "portalocker" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7e/57/b969aed128768558255822e75b402a19530bd63321f637d42f4724abc1ed/portalocker-3.0.0.tar.gz", hash = "sha256:21f535de2e7a82c94c130c054adb5c7421d480d5619d61073996e2f89bcb879b", size = 41961 } wheels = [ @@ -1532,8 +1528,6 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, @@ -2373,44 +2367,48 @@ wheels = [ [[package]] name = "torch" -version = "2.5.1" +version = "2.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:71328e1bbe39d213b8721678f9dcac30dfc452a46d586f1d514a6aa0a99d4744", size = 906446312 }, - { url = "https://files.pythonhosted.org/packages/69/f0/46e74e0d145f43fa506cb336eaefb2d240547e4ce1f496e442711093ab25/torch-2.5.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:34bfa1a852e5714cbfa17f27c49d8ce35e1b7af5608c4bc6e81392c352dbc601", size = 91919522 }, - { url = "https://files.pythonhosted.org/packages/a5/13/1eb674c8efbd04d71e4a157ceba991904f633e009a584dd65dccbafbb648/torch-2.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:32a037bd98a241df6c93e4c789b683335da76a2ac142c0973675b715102dc5fa", size = 203088048 }, - { url = "https://files.pythonhosted.org/packages/a9/9d/e0860474ee0ff8f6ef2c50ec8f71a250f38d78a9b9df9fd241ad3397a65b/torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86", size = 63877046 }, - { url = "https://files.pythonhosted.org/packages/d1/35/e8b2daf02ce933e4518e6f5682c72fd0ed66c15910ea1fb4168f442b71c4/torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:de5b7d6740c4b636ef4db92be922f0edc425b65ed78c5076c43c42d362a45457", size = 906474467 }, - { url = "https://files.pythonhosted.org/packages/40/04/bd91593a4ca178ece93ca55f27e2783aa524aaccbfda66831d59a054c31e/torch-2.5.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:340ce0432cad0d37f5a31be666896e16788f1adf8ad7be481196b503dad675b9", size = 91919450 }, - { url = "https://files.pythonhosted.org/packages/0d/4a/e51420d46cfc90562e85af2fee912237c662ab31140ab179e49bd69401d6/torch-2.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:603c52d2fe06433c18b747d25f5c333f9c1d58615620578c326d66f258686f9a", size = 203098237 }, - { url = "https://files.pythonhosted.org/packages/d0/db/5d9cbfbc7968d79c5c09a0bc0bc3735da079f2fd07cc10498a62b320a480/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c", size = 63884466 }, - { url = "https://files.pythonhosted.org/packages/8b/5c/36c114d120bfe10f9323ed35061bc5878cc74f3f594003854b0ea298942f/torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ed231a4b3a5952177fafb661213d690a72caaad97d5824dd4fc17ab9e15cec03", size = 906389343 }, - { url = "https://files.pythonhosted.org/packages/6d/69/d8ada8b6e0a4257556d5b4ddeb4345ea8eeaaef3c98b60d1cca197c7ad8e/torch-2.5.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:3f4b7f10a247e0dcd7ea97dc2d3bfbfc90302ed36d7f3952b0008d0df264e697", size = 91811673 }, - { url = "https://files.pythonhosted.org/packages/5f/ba/607d013b55b9fd805db2a5c2662ec7551f1910b4eef39653eeaba182c5b2/torch-2.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:73e58e78f7d220917c5dbfad1a40e09df9929d3b95d25e57d9f8558f84c9a11c", size = 203046841 }, - { url = "https://files.pythonhosted.org/packages/57/6c/bf52ff061da33deb9f94f4121fde7ff3058812cb7d2036c97bc167793bd1/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1", size = 63858109 }, - { url = "https://files.pythonhosted.org/packages/69/72/20cb30f3b39a9face296491a86adb6ff8f1a47a897e4d14667e6cf89d5c3/torch-2.5.1-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:9b61edf3b4f6e3b0e0adda8b3960266b9009d02b37555971f4d1c8f7a05afed7", size = 906393265 }, + { url = "https://files.pythonhosted.org/packages/37/81/aa9ab58ec10264c1abe62c8b73f5086c3c558885d6beecebf699f0dbeaeb/torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961", size = 766685561 }, + { url = "https://files.pythonhosted.org/packages/86/86/e661e229df2f5bfc6eab4c97deb1286d598bbeff31ab0cdb99b3c0d53c6f/torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab", size = 95751887 }, + { url = "https://files.pythonhosted.org/packages/20/e0/5cb2f8493571f0a5a7273cd7078f191ac252a402b5fb9cb6091f14879109/torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341", size = 204165139 }, + { url = "https://files.pythonhosted.org/packages/e5/16/ea1b7842413a7b8a5aaa5e99e8eaf3da3183cc3ab345ad025a07ff636301/torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628", size = 66520221 }, + { url = "https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1", size = 766715424 }, + { url = "https://files.pythonhosted.org/packages/6d/fa/134ce8f8a7ea07f09588c9cc2cea0d69249efab977707cf67669431dcf5c/torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d", size = 95759416 }, + { url = "https://files.pythonhosted.org/packages/11/c5/2370d96b31eb1841c3a0883a492c15278a6718ccad61bb6a649c80d1d9eb/torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7", size = 204164970 }, + { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 }, + { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563 }, + { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867 }, + { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469 }, + { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 }, + { url = "https://files.pythonhosted.org/packages/24/85/ead1349fc30fe5a32cadd947c91bda4a62fbfd7f8c34ee61f6398d38fb48/torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf", size = 766626191 }, + { url = "https://files.pythonhosted.org/packages/dd/b0/26f06f9428b250d856f6d512413e9e800b78625f63801cbba13957432036/torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b", size = 95611439 }, + { url = "https://files.pythonhosted.org/packages/c2/9c/fc5224e9770c83faed3a087112d73147cd7c7bfb7557dcf9ad87e1dda163/torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc", size = 204126475 }, + { url = "https://files.pythonhosted.org/packages/88/8b/d60c0491ab63634763be1537ad488694d316ddc4a20eaadd639cedc53971/torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2", size = 66536783 }, ] [[package]] @@ -2447,7 +2445,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ @@ -2490,15 +2488,13 @@ wheels = [ [[package]] name = "triton" -version = "3.1.0" +version = "3.2.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "filelock", marker = "python_full_version < '3.13'" }, -] wheels = [ - { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 }, - { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424 }, - { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, + { url = "https://files.pythonhosted.org/packages/01/65/3ffa90e158a2c82f0716eee8d26a725d241549b7d7aaf7e4f44ac03ebd89/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62", size = 253090354 }, + { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 }, + { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 }, + { url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 }, ] [[package]] @@ -2796,7 +2792,7 @@ requires-dist = [ { name = "requests", marker = "extra == 'all'", specifier = ">=2.32.3" }, { name = "setuptools" }, { name = "toposolve" }, - { name = "torch", specifier = "==2.5.1" }, + { name = "torch", specifier = "==2.6.0" }, { name = "torch-shampoo", git = "https://github.com/facebookresearch/optimizers.git?rev=main" }, { name = "torchdata", specifier = ">=0.8.0" }, { name = "transformers", specifier = ">=4.44.2" },