diff --git a/gpt2_module_v3/README.md b/gpt2_module_v3/README.md
new file mode 100644
index 0000000..a5a4e9b
--- /dev/null
+++ b/gpt2_module_v3/README.md
@@ -0,0 +1,180 @@
+# GPT-2 Module V3 for MAX serving
+
+This module provides a custom GPT-2 architecture for serving with `max serve` using the Module V3 API.
+
+## Status
+
+**Current Status: Work in Progress**
+
+The model compiles and serves successfully, but produces incorrect (gibberish) output. The issue is still being investigated. The standalone model in `main.py` works correctly.
+
+## How to Run
+
+### Prerequisites
+
+```bash
+pixi install
+```
+
+### Running the Server
+
+```bash
+pixi run max serve \
+  --model openai-community/gpt2 \
+  --custom-architectures gpt2_module_v3 \
+  --port 8888
+```
+
+> Note: We do NOT use `--use-module-v3` here because we're registering a **new** architecture. The `--use-module-v3` flag is only needed when adding a new version of an existing MAX-registered architecture (it automatically appends `_ModuleV3` to the architecture name).
+
+### Testing the API
+
+GPT-2 is a base language model (not a chat model), so use the completions API:
+
+```bash
+curl http://localhost:8888/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "openai-community/gpt2",
+    "prompt": "The future of AI",
+    "max_tokens": 30
+  }' | jq .
+```
+
+Note: Do NOT use `/v1/chat/completions`; GPT-2 does not have a chat template.
+
+## Files
+
+| File                 | Description                                              |
+|----------------------|----------------------------------------------------------|
+| `__init__.py`        | Exports `ARCHITECTURES` list for custom arch discovery   |
+| `arch.py`            | Defines `SupportedArchitecture` for GPT-2                |
+| `model.py`           | `GPT2Model` class extending `PipelineModel`              |
+| `model_config.py`    | `GPT2Config` dataclass for model configuration           |
+| `gpt2.py`            | Neural network module definitions (attention, MLP, etc.) |
| +| `weight_adapters.py` | Converts HuggingFace safetensor weights to MAX format | + +## Architecture registration + +Key requirements for custom architecture registration: + +1. **Export `ARCHITECTURES`**: The `__init__.py` must export an `ARCHITECTURES` list: + + ```python + ARCHITECTURES = [gpt2] + ``` + +2. **Weight Adapter**: Register a weight adapter for safetensor format: + + ```python + weight_adapters={ + WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict, + } + ``` + +## Changes from Original Implementation + +The following changes were required to adapt the model from `main.py` for serving: + +### 1. Input Format Change + +**Original (main.py)**: 2D input `[batch_size, seq_length]` + +```python +batch_size, seq_length = input_ids.shape +``` + +**Serving**: 1D ragged input `[total_tokens]` with `input_row_offsets` + +```python +seq_length = tokens.shape[0] # Flattened tokens +# input_row_offsets tells where each sequence starts/ends +``` + +### 2. Position Embeddings + +Both implementations use `Tensor.arange`: + +```python +positions = Tensor.arange(seq_length, dtype=tokens.dtype, device=tokens.device) +pos_embeds = self.wpe(positions) +``` + +### 3. Weight Transposition + +GPT-2 uses Conv1D layers which store weights as `[in_features, out_features]`, but MAX's Linear expects `[out_features, in_features]`. Required transposition for: + +- `.c_attn.weight` +- `.c_proj.weight` +- `.c_fc.weight` + +### 4. Weight Adapter Output Format + +**Important**: The weight adapter must return **raw numpy arrays**, not `WeightData` objects: + +```python +# Correct: +new_state_dict[max_name] = arr + +# Wrong (causes issues): +new_state_dict[max_name] = WeightData.from_numpy(arr, max_name) +``` + +### 5. Contiguous Arrays + +Transposed arrays must be made contiguous: + +```python +arr = np.ascontiguousarray(arr.T) +``` + +### 6. 
Tied Embeddings
+
+GPT-2 ties `lm_head.weight` to `wte.weight`:
+
+```python
+if "language_model.lm_head.weight" not in new_state_dict:
+    new_state_dict["language_model.lm_head.weight"] = wte_array.copy()
+```
+
+## Developer Experience Notes
+
+### Issues Encountered with MAX Experimental APIs
+
+1. **`F.range` vs `Tensor.arange`**: The functional `F.range` API was deprecated, so this module uses `Tensor.arange` instead.
+
+2. **DLPack Conversion**: Weight data from safetensors required careful conversion:
+
+   ```python
+   arr = np.array(np.from_dlpack(weight_data), copy=True)
+   ```
+
+3. **Non-Contiguous Tensor Errors**: MAX doesn't support non-contiguous tensors. Error message:
+
+   ```output
+   ValueError: Max does not currently support executing non-contiguous tensors.
+   ```
+
+   Solution: Always use `np.ascontiguousarray()` after transposing.
+
+4. **Weight Adapter Return Type**: Despite the type hint `dict[str, WeightData]`, the actual return must be raw data (numpy arrays), following the pattern in `gpt_oss_module_v3`.
+
+5. **Chat Template Error**: GPT-2 is a base model without a chat template. Using `/v1/chat/completions` results in:
+
+   ```output
+   ValueError: Cannot use chat template functions because tokenizer.chat_template is not set
+   ```
+
+   Solution: Use `/v1/completions` instead.
+
+### Known Issue: Incorrect Output
+
+The model currently produces gibberish output when served, despite:
+
+- Weights loading correctly (verified via logging)
+- Weight shapes being correct
+- Transposition being applied
+- Tied embeddings being handled
+
+A standalone test comparing the `main.py` model with the serving model shows identical output when weights are loaded via PyTorch's `load_state_dict`, suggesting the issue may lie in how weights flow through the compile/serve pipeline.
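The weight-conversion rules described above — transpose Conv1D weights to Linear layout, keep every array contiguous, and tie `lm_head.weight` to `wte.weight` when it is absent — can be sanity-checked with plain numpy. This is an illustrative sketch only; the helper name `convert_gpt2_weight` is hypothetical and this is not the module's actual `convert_safetensor_state_dict`:

```python
import numpy as np

def convert_gpt2_weight(name: str, arr: np.ndarray) -> np.ndarray:
    """Transpose Conv1D-style weights to Linear layout and force contiguity."""
    transpose_suffixes = (".c_attn.weight", ".c_proj.weight", ".c_fc.weight")
    if name.endswith(transpose_suffixes):
        # Conv1D stores [in_features, out_features]; Linear expects the reverse.
        arr = arr.T
    # MAX rejects non-contiguous tensors, so always normalize the memory layout.
    return np.ascontiguousarray(arr)

# Toy Conv1D weight: in_features=4, out_features=8.
conv1d_w = np.arange(32, dtype=np.float32).reshape(4, 8)
linear_w = convert_gpt2_weight("h.0.mlp.c_fc.weight", conv1d_w)

assert linear_w.shape == (8, 4)        # now [out_features, in_features]
assert linear_w.flags["C_CONTIGUOUS"]  # safe to hand to the runtime

# Tied embeddings: fall back to wte.weight when lm_head.weight is missing.
state = {"language_model.wte.weight": np.zeros((50257, 768), dtype=np.float32)}
state.setdefault(
    "language_model.lm_head.weight",
    state["language_model.wte.weight"].copy(),
)
assert state["language_model.lm_head.weight"].shape == (50257, 768)
```

A check like this, run against a handful of converted tensors, is a cheap way to rule out layout mistakes before debugging the compile/serve path itself.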
diff --git a/gpt2_module_v3/__init__.py b/gpt2_module_v3/__init__.py new file mode 100644 index 0000000..80c6bc1 --- /dev/null +++ b/gpt2_module_v3/__init__.py @@ -0,0 +1,19 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + + +from .arch import gpt2_module_v3_arch + +ARCHITECTURES = [gpt2_module_v3_arch] + +__all__ = ["gpt2_module_v3_arch", "ARCHITECTURES"] diff --git a/gpt2_module_v3/arch.py b/gpt2_module_v3/arch.py new file mode 100644 index 0000000..6040005 --- /dev/null +++ b/gpt2_module_v3/arch.py @@ -0,0 +1,50 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + + +from max.graph.weights import WeightsFormat +from max.interfaces import PipelineTask +from max.nn.kv_cache import KVCacheStrategy +from max.pipelines.lib import ( + RopeType, + SupportedArchitecture, + SupportedEncoding, + TextTokenizer, +) + +from . import weight_adapters +from .model import GPT2Model + +gpt2_module_v3_arch = SupportedArchitecture( + name="GPT2LMHeadModel", + example_repo_ids=[ + "openai-community/gpt2", + "openai-community/gpt2-medium", + "openai-community/gpt2-large", + "openai-community/gpt2-xl", + ], + default_encoding=SupportedEncoding.float32, + supported_encodings={ + SupportedEncoding.float32: [KVCacheStrategy.PAGED], + SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED], + }, + pipeline_model=GPT2Model, + task=PipelineTask.TEXT_GENERATION, + tokenizer=TextTokenizer, + default_weights_format=WeightsFormat.safetensors, + multi_gpu_supported=False, + rope_type=RopeType.none, + weight_adapters={ + WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict, + }, +) diff --git a/gpt2_module_v3/gpt2.py b/gpt2_module_v3/gpt2.py new file mode 100644 index 0000000..10f9381 --- /dev/null +++ b/gpt2_module_v3/gpt2.py @@ -0,0 +1,275 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +"""Implements the GPT-2 model for MAX serving.""" + +from __future__ import annotations + +import math +from collections.abc import Sequence + +from max.dtype import DType +from max.driver import Device +from max.experimental import functional as F +from max.experimental.tensor import Tensor +from max.graph import BufferValue, Dim, DimLike, TensorValue +from max.kv_cache import NullKVCacheManager, PagedKVCacheManager +from max.nn.kv_cache import PagedCacheValues +from max.nn.module_v3 import Module +from max.nn.module_v3.embedding import Embedding +from max.nn.module_v3.linear import Linear +from max.nn.module_v3.sequential import Sequential + +from .model_config import GPT2Config + + +@F.functional +def causal_mask( + sequence_length: DimLike, + num_tokens: DimLike, + *, + dtype: DType, + device: Device, +): + """Create a causal attention mask.""" + n = Dim(sequence_length) + num_tokens + mask = Tensor.constant(float("-inf"), dtype=dtype, device=device) + mask = F.broadcast_to(mask, shape=(sequence_length, n)) + return F.band_part(mask, num_lower=None, num_upper=0, exclude=True) + + +class LayerNorm(Module): + """Layer normalization module.""" + + def __init__(self, dim: DimLike, *, eps: float = 1e-5): + self.eps = eps + self.weight = Tensor.ones([dim]) + self.bias = Tensor.zeros([dim]) + + def __call__(self, x: Tensor) -> Tensor: + return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps) + + +class GPT2Attention(Module): + """GPT-2 attention matching HuggingFace structure.""" + + def __init__(self, config: GPT2Config): + super().__init__() + self.embed_dim = config.n_embd + self.num_heads = config.n_head + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + + self.c_attn = Linear(self.embed_dim, 3 * self.embed_dim, bias=True) + self.c_proj = Linear(self.embed_dim, self.embed_dim, bias=True) + + def _attn(self, query, key, 
value): + attn_weights = query @ key.transpose(-1, -2) + attn_weights = attn_weights / math.sqrt(int(value.shape[-1])) + + seq_len = query.shape[-2] + mask = causal_mask(seq_len, 0, dtype=query.dtype, device=query.device) + attn_weights = attn_weights + mask + + attn_weights = F.softmax(attn_weights) + attn_output = attn_weights @ value + + return attn_output + + def _split_heads(self, tensor, num_heads, attn_head_size): + """Split the last dimension into (num_heads, head_size).""" + new_shape = tensor.shape[:-1] + [num_heads, attn_head_size] + tensor = tensor.reshape(new_shape) + return tensor.transpose(-3, -2) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """Merge attention heads back.""" + tensor = tensor.transpose(-3, -2) + new_shape = tensor.shape[:-2] + [num_heads * attn_head_size] + return tensor.reshape(new_shape) + + def __call__(self, hidden_states): + query, key, value = F.split( + self.c_attn(hidden_states), + [self.split_size, self.split_size, self.split_size], + axis=2, + ) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + attn_output = self._attn(query, key, value) + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + + return attn_output + + +class GPT2MLP(Module): + """GPT-2 MLP structure.""" + + def __init__(self, intermediate_size: int, config: GPT2Config): + super().__init__() + embed_dim = config.n_embd + self.c_fc = Linear(embed_dim, intermediate_size, bias=True) + self.c_proj = Linear(intermediate_size, embed_dim, bias=True) + + def __call__(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = F.gelu(hidden_states, approximate="tanh") + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(Module): + """GPT-2 transformer block.""" + + def 
__init__(self, config: GPT2Config): + super().__init__() + hidden_size = config.n_embd + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config) + self.ln_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def __call__(self, hidden_states): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states) + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = residual + feed_forward_hidden_states + + return hidden_states + + +class GPT2TextModel(Module): + """The GPT-2 language model.""" + + def __init__(self, config: GPT2Config) -> None: + super().__init__() + self.devices = config.devices + + self.wte = Embedding(config.vocab_size, dim=config.n_embd) + self.wpe = Embedding(config.n_positions, dim=config.n_embd) + self.h = Sequential(*(GPT2Block(config) for _ in range(config.n_layer))) + self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.lm_head = Linear(config.n_embd, config.vocab_size, bias=False) + + self.n_embd = config.n_embd + self.kv_params = config.kv_params + self.return_logits = config.return_logits + + def __call__( + self, + tokens: Tensor, + kv_collection: PagedCacheValues, + return_n_logits: Tensor, + input_row_offsets: Tensor, + ) -> tuple[Tensor, ...]: + # Get sequence length from tokens + seq_length = tokens.shape[0] + + # Token embeddings + tok_embeds = self.wte(tokens) + + # Position embeddings using Tensor.arange like the original implementation + positions = Tensor.arange( + seq_length, dtype=tokens.dtype, device=tokens.device + ) + pos_embeds = self.wpe(positions) + + # Combine embeddings + h = tok_embeds + pos_embeds + + # Add batch dimension for transformer layers 
(they expect batch dim) + h = h.reshape([1, seq_length, self.n_embd]) + + # Run through transformer layers + h = self.h(h) + h = self.ln_f(h) + + # Remove batch dimension + h = h.reshape([seq_length, self.n_embd]) + + # Get last token logits per sequence using input_row_offsets + last_token_indices = input_row_offsets[1:] - 1 + last_token_h = F.gather(h, last_token_indices, axis=0) + last_logits = F.cast(self.lm_head(last_token_h), DType.float32) + + return (last_logits,) + + +class GPT2(Module): + """The GPT-2 model wrapper for serving.""" + + def __init__( + self, + config: GPT2Config, + kv_manager: PagedKVCacheManager | NullKVCacheManager, + ) -> None: + super().__init__() + self.language_model = GPT2TextModel(config) + self.config = config + self.kv_manager = kv_manager + + def __call__( + self, + tokens: Tensor, + return_n_logits: Tensor, + input_row_offsets: Tensor, + *variadic_args, + ) -> tuple[Tensor, ...]: + kv_collection = _unflatten_kv_inputs( + self.config, self.kv_manager, variadic_args + ) + return self.language_model( + tokens, kv_collection[0], return_n_logits, input_row_offsets + ) + + +def _unflatten_kv_inputs( + config: GPT2Config, + kv_manager: PagedKVCacheManager | NullKVCacheManager, + kv_inputs_flat: Sequence[Tensor], +) -> list[PagedCacheValues]: + """Unflatten KV cache inputs from variadic args.""" + kv_params = config.kv_params + n_devices = kv_params.n_devices + fetch_types = kv_manager.get_symbolic_inputs()[0] + len_of_kv_tuple_per_dev = len(list(fetch_types)) + kv_caches_per_dev: list[PagedCacheValues] = [] + + for i in range(n_devices): + start_idx = i * len_of_kv_tuple_per_dev + + kv_block = kv_inputs_flat[start_idx] + cache_lengths = kv_inputs_flat[start_idx + 1] + lookup_table = kv_inputs_flat[start_idx + 2] + max_lengths = kv_inputs_flat[start_idx + 3] + + kv_caches_per_dev.append( + PagedCacheValues( + kv_blocks=BufferValue(kv_block), + cache_lengths=TensorValue(cache_lengths), + lookup_table=TensorValue(lookup_table), + 
max_lengths=TensorValue(max_lengths), + ) + ) + return kv_caches_per_dev diff --git a/gpt2_module_v3/model.py b/gpt2_module_v3/model.py new file mode 100644 index 0000000..6d02ce7 --- /dev/null +++ b/gpt2_module_v3/model.py @@ -0,0 +1,326 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +import logging +import time +from collections.abc import Callable, Sequence +from typing import Any, cast + +import numpy as np +from max.driver import Device, Tensor +from max.dtype import DType +from max.engine import InferenceSession +from max.graph import DeviceRef, TensorType +from max.graph.weights import Weights, WeightsAdapter +from max.kv_cache import ( + NullKVCacheManager, + PagedKVCacheManager, + estimate_kv_cache_size, + load_kv_manager, +) +from max.nn import ReturnLogits +from max.nn.kv_cache import KVCacheInputs, KVCacheInputsSequence, KVCacheParams +from max.pipelines.core import TextContext +from max.pipelines.lib import ( + KVCacheConfig, + KVCacheMixin, + ModelInputs, + ModelOutputs, + PipelineConfig, + PipelineModel, + SupportedEncoding, +) +from transformers import AutoConfig + +from .gpt2 import GPT2 +from .model_config import GPT2Config + +logger = logging.getLogger("max.pipelines") + + +class GPT2Inputs(ModelInputs): + """A class representing inputs for the GPT-2 model.""" + + tokens: Tensor + """Tensor containing the input token 
IDs.""" + + input_row_offsets: Tensor + """Tensor containing the offsets for each row in the ragged input sequence.""" + + def __init__( + self, + tokens: Tensor, + input_row_offsets: Tensor, + return_n_logits: Tensor, + kv_cache_inputs: KVCacheInputs | None = None, + ) -> None: + self.tokens = tokens + self.input_row_offsets = input_row_offsets + self.kv_cache_inputs = kv_cache_inputs + self.return_n_logits = return_n_logits + + +class GPT2Model(PipelineModel[TextContext], KVCacheMixin): + """A GPT-2 pipeline model for text generation.""" + + def __init__( + self, + pipeline_config: PipelineConfig, + session: InferenceSession, + huggingface_config: AutoConfig, + encoding: SupportedEncoding, + devices: list[Device], + kv_cache_config: KVCacheConfig, + weights: Weights, + adapter: WeightsAdapter | None = None, + return_logits: ReturnLogits = ReturnLogits.LAST_TOKEN, + ) -> None: + super().__init__( + pipeline_config, + session, + huggingface_config, + encoding, + devices, + kv_cache_config, + weights, + adapter, + return_logits, + ) + + self.model = self.load_model() + + @staticmethod + def calculate_max_seq_len( + pipeline_config: PipelineConfig, huggingface_config: AutoConfig + ) -> int: + """Calculates the maximum sequence length for the GPT-2 model.""" + max_seq_len = pipeline_config.max_length + if max_seq_len: + return max_seq_len + return huggingface_config.n_positions + + @classmethod + def get_kv_params( + cls, + huggingface_config: AutoConfig, + n_devices: int, + kv_cache_config: KVCacheConfig, + cache_dtype: DType, + ) -> KVCacheParams: + """Gets the parameters required to configure the KV cache for GPT-2.""" + return GPT2Config.get_kv_params( + huggingface_config, n_devices, kv_cache_config, cache_dtype + ) + + @classmethod + def get_num_layers(cls, huggingface_config: AutoConfig) -> int: + """Gets the number of hidden layers from the HuggingFace configuration.""" + return GPT2Config.get_num_layers(huggingface_config) + + @classmethod + def 
estimate_kv_cache_size( + cls, + pipeline_config: PipelineConfig, + available_cache_memory: int, + devices: list[Device], + huggingface_config: AutoConfig, + kv_cache_config: KVCacheConfig, + cache_dtype: DType, + ) -> int: + """Estimates the size of the KV cache required for the GPT-2 model in bytes.""" + return estimate_kv_cache_size( + params=GPT2Config.get_kv_params( + huggingface_config=huggingface_config, + n_devices=len(devices), + kv_cache_config=kv_cache_config, + cache_dtype=cache_dtype, + ), + max_batch_size=pipeline_config.max_batch_size, + max_seq_len=cls.calculate_max_seq_len( + pipeline_config, huggingface_config=huggingface_config + ), + available_cache_memory=available_cache_memory, + ) + + def load_model(self) -> Callable[..., Any]: + """Loads the compiled GPT-2 model into the MAX Engine session.""" + + assert self.pipeline_config.max_batch_size, ( + "Expected max_batch_size to be set" + ) + self._input_row_offsets_prealloc = Tensor.from_numpy( + np.arange(self.pipeline_config.max_batch_size + 1, dtype=np.uint32) + ).to(self.devices[0]) + + logger.info("Building and compiling model...") + before = time.perf_counter() + + device0 = self.devices[0] + device_ref = DeviceRef(device0.label, device0.id) + tokens_type = TensorType( + DType.int64, shape=["total_seq_len"], device=device_ref + ) + input_row_offsets_type = TensorType( + DType.uint32, + shape=["input_row_offsets_len"], + device=device0, + ) + return_n_logits_type = TensorType( + DType.int64, shape=["return_n_logits"], device=DeviceRef.CPU() + ) + + huggingface_config = self.huggingface_config + if self.adapter: + state_dict = self.adapter( + dict(self.weights.items()), + huggingface_config=huggingface_config, + pipeline_config=self.pipeline_config, + ) + else: + state_dict = { + key: value.data() for key, value in self.weights.items() + } + model_config = GPT2Config.generate( + pipeline_config=self.pipeline_config, + huggingface_config=huggingface_config, + state_dict=state_dict, + 
dtype=self.dtype, + n_devices=len(self.devices), + cache_dtype=self.encoding.cache_dtype, + kv_cache_config=self.kv_cache_config, + return_logits=self.return_logits, + ) + nn_model = GPT2(model_config, self.kv_manager) + nn_model.to(self.devices[0]) + + kv_inputs = self.kv_manager.get_symbolic_inputs() + flattened_kv_types = [ + kv_type for sublist in kv_inputs for kv_type in sublist + ] + + compiled_model = nn_model.compile( + tokens_type, + return_n_logits_type, + input_row_offsets_type, + *flattened_kv_types, + weights=state_dict, + ) + after = time.perf_counter() + + logger.info( + f"Building and compiling model took {after - before:.6f} seconds" + ) + + return compiled_model + + def execute(self, model_inputs: ModelInputs) -> ModelOutputs: + """Executes the GPT-2 model with the prepared inputs.""" + model_inputs = cast(GPT2Inputs, model_inputs) + curr_kv_cache_inputs = model_inputs.kv_cache_inputs or () + + if isinstance(model_inputs.input_row_offsets, np.ndarray): + tensor = Tensor.from_numpy(model_inputs.input_row_offsets) + input_row_offsets = tensor.to(self.devices[0]) + else: + input_row_offsets = model_inputs.input_row_offsets + + model_outputs = self.model( + model_inputs.tokens, + model_inputs.return_n_logits, + input_row_offsets, + *curr_kv_cache_inputs, + ) + if len(model_outputs) == 3: + return ModelOutputs( + logits=cast(Tensor, model_outputs[1].driver_tensor), + next_token_logits=cast(Tensor, model_outputs[0].driver_tensor), + logit_offsets=cast(Tensor, model_outputs[2].driver_tensor), + ) + else: + return ModelOutputs( + logits=cast(Tensor, model_outputs[0].driver_tensor), + next_token_logits=cast(Tensor, model_outputs[0].driver_tensor), + ) + + def prepare_initial_token_inputs( + self, + replica_batches: Sequence[Sequence[TextContext]], + kv_cache_inputs: KVCacheInputs | None = None, + return_n_logits: int = 1, + ) -> ModelInputs: + """Prepares the initial inputs for the first execution pass of the GPT-2 model.""" + if len(replica_batches) > 1: 
+ raise ValueError("Model does not support DP>1") + + context_batch = replica_batches[0] + assert kv_cache_inputs is not None + kv_cache_inputs = cast(KVCacheInputsSequence, kv_cache_inputs) + + input_row_offsets = np.cumsum( + [0] + [ctx.active_length for ctx in context_batch], dtype=np.uint32 + ) + + tokens = np.concatenate([ctx.next_tokens for ctx in context_batch]) + + input_row_offsets_tensor = Tensor.from_numpy(input_row_offsets).to( + self.devices[0] + ) + + return GPT2Inputs( + tokens=Tensor.from_numpy(tokens).to(self.devices[0]), + input_row_offsets=input_row_offsets_tensor, + return_n_logits=Tensor.from_numpy( + np.array([return_n_logits], dtype=np.int64) + ), + kv_cache_inputs=kv_cache_inputs, + ) + + def prepare_next_token_inputs( + self, next_tokens: Tensor, prev_model_inputs: ModelInputs + ) -> ModelInputs: + """Prepares the inputs for subsequent execution steps in a multi-step generation.""" + prev_model_inputs = cast(GPT2Inputs, prev_model_inputs) + row_offsets_size = prev_model_inputs.input_row_offsets.shape[0] + + next_row_offsets = self._input_row_offsets_prealloc[ + :row_offsets_size + ].to(self.devices[0]) + + return GPT2Inputs( + tokens=next_tokens, + input_row_offsets=next_row_offsets, + return_n_logits=prev_model_inputs.return_n_logits, + kv_cache_inputs=prev_model_inputs.kv_cache_inputs, + ) + + def load_kv_manager( + self, session: InferenceSession, available_cache_memory: int | None + ) -> PagedKVCacheManager | NullKVCacheManager: + """Loads and initializes the KVCacheManager for the GPT-2 model.""" + return load_kv_manager( + params=GPT2Config.get_kv_params( + huggingface_config=self.huggingface_config, + n_devices=len(self.devices), + kv_cache_config=self.kv_cache_config, + cache_dtype=self.encoding.cache_dtype, + ), + max_batch_size=self.pipeline_config.max_batch_size, + max_seq_len=self.calculate_max_seq_len( + self.pipeline_config, huggingface_config=self.huggingface_config + ), + devices=self.devices, + 
available_cache_memory=available_cache_memory, + session=session, + ) diff --git a/gpt2_module_v3/model_config.py b/gpt2_module_v3/model_config.py new file mode 100644 index 0000000..4154e10 --- /dev/null +++ b/gpt2_module_v3/model_config.py @@ -0,0 +1,145 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from dataclasses import dataclass + +from max.dtype import DType +from max.graph import DeviceRef +from max.graph.weights import WeightData, WeightsFormat, weights_format +from max.nn import ReturnLogits +from max.nn.kv_cache import KVCacheParams +from max.pipelines.lib import ( + KVCacheConfig, + MAXModelConfig, + MAXModelConfigBase, + PipelineConfig, +) +from transformers import AutoConfig + + +@dataclass +class GPT2ConfigBase(MAXModelConfigBase): + """Base configuration for GPT-2 models.""" + + vocab_size: int + """Vocabulary size of the GPT-2 model.""" + + n_positions: int + """Maximum sequence length the model can handle.""" + + n_embd: int + """Dimension of the hidden representations.""" + + n_layer: int + """Number of hidden layers in the Transformer decoder.""" + + n_head: int + """Number of attention heads for each attention layer.""" + + n_inner: int | None + """Dimension of the MLP representations. 
If None, defaults to 4 * n_embd.""" + + layer_norm_epsilon: float + """The epsilon used by the layer normalization layers.""" + + # MAX-specific config parameters + dtype: DType + """DType of the model weights and input.""" + + devices: list[DeviceRef] + """Devices to run the model with.""" + + return_logits: ReturnLogits + """Whether to return the last token, all logits, or a variable number of logits.""" + + kv_params: KVCacheParams + """KV cache parameters.""" + + +@dataclass +class GPT2Config(MAXModelConfig, GPT2ConfigBase): + """Represents the complete MAX Engine configuration for GPT-2 models.""" + + @staticmethod + def get_kv_params( + huggingface_config: AutoConfig, + n_devices: int, + kv_cache_config: KVCacheConfig, + cache_dtype: DType, + ) -> KVCacheParams: + """Constructs the KV cache parameters from configuration objects.""" + return KVCacheParams( + dtype=cache_dtype, + num_layers=GPT2Config.get_num_layers(huggingface_config), + n_kv_heads=huggingface_config.n_head, + head_dim=huggingface_config.n_embd // huggingface_config.n_head, + page_size=kv_cache_config.kv_cache_page_size, + cache_strategy=kv_cache_config.cache_strategy, + enable_prefix_caching=kv_cache_config.enable_prefix_caching, + enable_kvcache_swapping_to_host=kv_cache_config.enable_kvcache_swapping_to_host, + host_kvcache_swap_space_gb=kv_cache_config.host_kvcache_swap_space_gb, + n_devices=n_devices, + ) + + @staticmethod + def get_num_layers(huggingface_config: AutoConfig) -> int: + """Retrieves the number of hidden layers from the HuggingFace configuration.""" + return huggingface_config.n_layer + + @staticmethod + def calculate_max_seq_len( + pipeline_config: PipelineConfig, huggingface_config: AutoConfig + ) -> int: + """Calculates the maximum sequence length for the model.""" + max_seq_len = pipeline_config.max_length + if max_seq_len: + return max_seq_len + return huggingface_config.n_positions + + @staticmethod + def generate( + pipeline_config: PipelineConfig, + 
huggingface_config: AutoConfig, + state_dict: dict[str, WeightData], + dtype: DType, + n_devices: int, + cache_dtype: DType, + kv_cache_config: KVCacheConfig, + return_logits: ReturnLogits, + ) -> GPT2Config: + """Generates a GPT2Config instance from various configuration sources.""" + device_refs = [ + DeviceRef(spec.device_type, spec.id) + for spec in pipeline_config.model_config.device_specs + ] + + return GPT2Config( + vocab_size=huggingface_config.vocab_size, + n_positions=huggingface_config.n_positions, + n_embd=huggingface_config.n_embd, + n_layer=huggingface_config.n_layer, + n_head=huggingface_config.n_head, + n_inner=getattr(huggingface_config, "n_inner", None), + layer_norm_epsilon=huggingface_config.layer_norm_epsilon, + dtype=dtype, + devices=device_refs, + return_logits=return_logits, + kv_params=GPT2Config.get_kv_params( + huggingface_config=huggingface_config, + n_devices=n_devices, + kv_cache_config=kv_cache_config, + cache_dtype=cache_dtype, + ), + ) diff --git a/gpt2_module_v3/weight_adapters.py b/gpt2_module_v3/weight_adapters.py new file mode 100644 index 0000000..6c0bd22 --- /dev/null +++ b/gpt2_module_v3/weight_adapters.py @@ -0,0 +1,92 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== #
+
+from __future__ import annotations
+
+import numpy as np
+from max.graph.weights import Weights
+
+# Mapping from HuggingFace GPT-2 safetensor names to MAX format
+# Note: GPT-2 safetensors don't have a 'transformer.' prefix
+GPT2_SAFETENSOR_MAP: dict[str, str] = {
+    "wte.": "language_model.wte.",
+    "wpe.": "language_model.wpe.",
+    "ln_f.": "language_model.ln_f.",
+    "h.": "language_model.h.",
+}
+
+# Weights that need to be transposed (Conv1D -> Linear)
+# GPT-2 uses Conv1D which stores weights as [in_features, out_features]
+# Linear expects [out_features, in_features]
+TRANSPOSE_WEIGHTS = [
+    ".c_attn.weight",
+    ".c_proj.weight",
+    ".c_fc.weight",
+]
+
+
+def convert_safetensor_state_dict(
+    state_dict: dict[str, Weights], **kwargs
+) -> dict[str, np.ndarray]:
+    """Convert safetensor state dict to MAX format.
+
+    Args:
+        state_dict: Dictionary of weight tensors
+
+    Returns:
+        Dictionary of raw numpy arrays keyed by MAX weight names
+
+    Note:
+        Raw numpy arrays are returned instead of WeightData objects;
+        this follows the pattern used in gpt_oss_module_v3.
+    """
+    new_state_dict: dict[str, np.ndarray] = {}
+    wte_array = None  # Keep track of wte array for tying
+
+    for weight_name, value in state_dict.items():
+        max_name: str = weight_name
+
+        # Skip attention bias buffers (causal masks) - we generate these dynamically
+        if weight_name.endswith(".attn.bias"):
+            continue
+
+        # Remap weight names from HuggingFace to MAX format
+        for before, after in GPT2_SAFETENSOR_MAP.items():
+            max_name = max_name.replace(before, after)
+
+        # Get the weight data and convert to a numpy array
+        weight_data = value.data()
+        arr = np.array(np.from_dlpack(weight_data), copy=True)
+
+        # Transpose Conv1D weights to Linear format
+        needs_transpose = any(pat in weight_name for pat in TRANSPOSE_WEIGHTS)
+        if needs_transpose:
+            # Conv1D: [in_features, out_features] -> Linear: [out_features, in_features]
+            arr = np.ascontiguousarray(arr.T)
+        else:
+            # Ensure all arrays are contiguous
+            arr = np.ascontiguousarray(arr)
+
+        # Keep wte array for tying embeddings
+        if max_name == "language_model.wte.weight":
+            wte_array = arr
+
+        # Store the raw array (like gpt_oss_module_v3), not a WeightData object
+        new_state_dict[max_name] = arr
+
+    # Handle tied embeddings - if lm_head.weight is missing, copy from wte.weight
+    if "language_model.lm_head.weight" not in new_state_dict:
+        if wte_array is not None:
+            new_state_dict["language_model.lm_head.weight"] = wte_array.copy()
+
+    return new_state_dict
diff --git a/pixi.lock b/pixi.lock index 818053d..285f273 100644 --- a/pixi.lock +++ b/pixi.lock @@ -45,10 +45,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/click-8.3.0-pyh707e725_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.13.7-py313hd8ed1ab_100.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/cyclopts-4.3.0-pyhcf101f3_0.conda - conda:
https://conda.anaconda.org/conda-forge/noarch/datasets-4.2.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/deprecated-1.2.18-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/dill-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/dnspython-2.8.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/docstring_parser-0.17.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/docutils-0.22.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/email-validator-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/email_validator-2.3.0-hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda @@ -149,17 +152,17 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.10.0-h5888daf_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-4.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.3-py313h3dea7bd_0.conda - - conda: https://conda.modular.com/max-nightly/linux-64/max-25.7.0.dev2025100805-3.13release.conda - - conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/mblack-25.7.0.dev2025100805-release.conda + - conda: https://conda.modular.com/max-nightly/linux-64/max-26.1.0.dev2025121105-3.13release.conda + - conda: https://conda.modular.com/max-nightly/linux-64/max-core-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/mblack-26.1.0.dev2025121105-release.conda - conda: 
https://conda.anaconda.org/conda-forge/linux-64/mdbook-0.4.52-hb17b654_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mkl-2024.2.2-ha770c72_17.conda - - conda: https://conda.modular.com/max-nightly/noarch/modular-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/linux-64/mojo-0.25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/linux-64/mojo-compiler-0.25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.25.7.0.dev2025100805-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/modular-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/linux-64/mojo-0.26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/linux-64/mojo-compiler-0.26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.26.1.0.dev2025121105-release.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-h24ddda3_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/mpfr-4.2.1-h90cbb55_3.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mpmath-1.3.0-pyhd8ed1ab_1.conda @@ -223,6 +226,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/regex-2025.9.18-py313h07c4f96_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.32.5-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/rich-14.2.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/rich-rst-1.3.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/rich-toolkit-0.15.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.14.2-ha3a3aed_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/s2n-1.5.26-h5ac9029_0.conda @@ 
-312,10 +316,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/click-8.3.0-pyh707e725_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/cpython-3.13.7-py313hd8ed1ab_100.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/cyclopts-4.3.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/datasets-4.2.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/deprecated-1.2.18-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/dill-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/dnspython-2.8.0-pyhcf101f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/docstring_parser-0.17.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/docutils-0.22.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/email-validator-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/email_validator-2.3.0-hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda @@ -410,16 +417,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lz4-c-1.10.0-h286801f_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-4.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/markupsafe-3.0.3-py313h7d74516_0.conda - - conda: https://conda.modular.com/max-nightly/osx-arm64/max-25.7.0.dev2025100805-3.13release.conda - - conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/mblack-25.7.0.dev2025100805-release.conda + - conda: 
https://conda.modular.com/max-nightly/osx-arm64/max-26.1.0.dev2025121105-3.13release.conda + - conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/mblack-26.1.0.dev2025121105-release.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/mdbook-0.4.52-hcdef695_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda - - conda: https://conda.modular.com/max-nightly/noarch/modular-25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-0.25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-compiler-0.25.7.0.dev2025100805-release.conda - - conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.25.7.0.dev2025100805-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/modular-26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-0.26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-compiler-0.26.1.0.dev2025121105-release.conda + - conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.26.1.0.dev2025121105-release.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/mpc-1.3.1-h8f1351a_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/mpfr-4.2.1-hb693164_3.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mpmath-1.3.0-pyhd8ed1ab_1.conda @@ -484,6 +491,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/regex-2025.9.18-py313h6535dbc_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.32.5-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/rich-14.2.0-pyhcf101f3_0.conda + - conda: 
https://conda.anaconda.org/conda-forge/noarch/rich-rst-1.3.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/rich-toolkit-0.15.1-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.14.2-h492a034_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/safetensors-0.6.2-py313h6e3aefc_1.conda @@ -1295,6 +1303,22 @@ packages: license: Python-2.0 size: 48174 timestamp: 1756909387263 +- conda: https://conda.anaconda.org/conda-forge/noarch/cyclopts-4.3.0-pyhcf101f3_0.conda + sha256: d5fec265be19ebe74c675156cc0646ad16317000544ea6edb7db277fbbd33b1f + md5: 5ad57b23d6e3e1b1a2839223f90fc0e2 + depends: + - python >=3.10 + - attrs >=23.1.0 + - rich >=13.6.0 + - docstring_parser >=0.15,<4.0 + - rich-rst >=1.3.1,<2.0.0 + - typing_extensions >=4.8.0 + - tomli >=2.0.0 + - python + license: Apache-2.0 + license_family: APACHE + size: 151341 + timestamp: 1764622199978 - conda: https://conda.anaconda.org/conda-forge/noarch/datasets-4.2.0-pyhcf101f3_0.conda sha256: 341667142604009a0f18bfeda809fce74e107620f89eecc21135cc0e99196c37 md5: 1cfa29cb97dfd0280959f70454b8b225 @@ -1358,6 +1382,23 @@ packages: license: ISC size: 196500 timestamp: 1757292856922 +- conda: https://conda.anaconda.org/conda-forge/noarch/docstring_parser-0.17.0-pyhd8ed1ab_0.conda + sha256: 3069a555097f084d3b7bc8f9efbb42f9907ecbfa24d310c63df9814a8df491af + md5: ce49d3e5a7d20be2ba57a2c670bdd82e + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 31742 + timestamp: 1753195731224 +- conda: https://conda.anaconda.org/conda-forge/noarch/docutils-0.22.3-pyhd8ed1ab_0.conda + sha256: ab77ee201665dc654248e3a250bd6fe05db0a1892716a6feb8da4a3162518624 + md5: abbe8c85619c87c4f4f61b44173434af + depends: + - python >=3.10 + license: CC-PDDC AND BSD-3-Clause AND BSD-2-Clause AND ZPL-2.1 + size: 436965 + timestamp: 1762425841874 - conda: https://conda.anaconda.org/conda-forge/noarch/email-validator-2.3.0-pyhd8ed1ab_0.conda sha256: 
c37320864c35ef996b0e02e289df6ee89582d6c8e233e18dc9983375803c46bb md5: 3bc0ac31178387e8ed34094d9481bfe8 @@ -3560,28 +3601,34 @@ packages: license_family: BSD size: 25778 timestamp: 1759055530601 -- conda: https://conda.modular.com/max-nightly/linux-64/max-25.7.0.dev2025100805-3.13release.conda - sha256: bde8bc66223e89fa6e91fd64b592702384482e26ad0e42bcfb8fa31553bc9a86 +- conda: https://conda.modular.com/max-nightly/linux-64/max-26.1.0.dev2025121105-3.13release.conda + sha256: af609f394371cd9dceaad6426d0fe2e005790c0dc676525f3de3636331c8a426 depends: - numpy >=1.18 - typing-extensions >=4.12.2 + - pyyaml >=6.0.1 - python-gil - - max-core ==25.7.0.dev2025100805 release + - max-core ==26.1.0.dev2025121105 release - python_abi 3.13.* *_cp313 constrains: - click >=8.0.0 + - cyclopts >=4.2.5 - gguf >=0.17.1 - hf-transfer >=0.1.9 - huggingface_hub >=0.28.0 - jinja2 >=3.1.6 - - llguidance >=1.0.1 + - llguidance >=0.7.30 - pillow >=11.0.0 - psutil >=6.1.1 + - pydantic-settings >=2.7.1 + - pydantic - requests >=2.32.3 - rich >=13.0.1 - sentencepiece >=0.2.0 + - taskgroup >=0.2.2 + - tomli >=2.0.0 - tqdm >=4.67.1 - - transformers >=4.55.0 + - transformers >=4.57.0,<5.0.0 - uvicorn >=0.34.0 - uvloop >=0.21.0 - aiofiles >=24.1.0 @@ -3596,8 +3643,6 @@ packages: - opentelemetry-sdk >=1.29.0,<1.36.0 - prometheus_client >=0.21.0 - protobuf >=6.31.1,<6.32.0 - - pydantic-settings >=2.7.1 - - pydantic - pyinstrument >=5.0.1 - python-json-logger >=2.0.7 - pyzmq >=26.3.0 @@ -3605,33 +3650,38 @@ packages: - scipy >=1.13.0 - sse-starlette >=2.1.2 - starlette >=0.47.2 - - taskgroup >=0.2.2 - tokenizers >=0.19.0 license: LicenseRef-Modular-Proprietary - size: 6417323 - timestamp: 1759900825814 -- conda: https://conda.modular.com/max-nightly/osx-arm64/max-25.7.0.dev2025100805-3.13release.conda - sha256: 4d20aa324d81495c7a5c8dcf3daf9b82b867cca2f7198c98c91e9cb9ca260787 + size: 5904071 + timestamp: 1765430656228 +- conda: 
https://conda.modular.com/max-nightly/osx-arm64/max-26.1.0.dev2025121105-3.13release.conda + sha256: bcb3917ab1c4888863fa434a95997cc7e90d71454bf9ca2aefbfaa10b0f0df36 depends: - numpy >=1.18 - typing-extensions >=4.12.2 + - pyyaml >=6.0.1 - python-gil - - max-core ==25.7.0.dev2025100805 release + - max-core ==26.1.0.dev2025121105 release - python_abi 3.13.* *_cp313 constrains: - click >=8.0.0 + - cyclopts >=4.2.5 - gguf >=0.17.1 - hf-transfer >=0.1.9 - huggingface_hub >=0.28.0 - jinja2 >=3.1.6 - - llguidance >=1.0.1 + - llguidance >=0.7.30 - pillow >=11.0.0 - psutil >=6.1.1 + - pydantic-settings >=2.7.1 + - pydantic - requests >=2.32.3 - rich >=13.0.1 - sentencepiece >=0.2.0 + - taskgroup >=0.2.2 + - tomli >=2.0.0 - tqdm >=4.67.1 - - transformers >=4.55.0 + - transformers >=4.57.0,<5.0.0 - uvicorn >=0.34.0 - uvloop >=0.21.0 - aiofiles >=24.1.0 @@ -3646,8 +3696,6 @@ packages: - opentelemetry-sdk >=1.29.0,<1.36.0 - prometheus_client >=0.21.0 - protobuf >=6.31.1,<6.32.0 - - pydantic-settings >=2.7.1 - - pydantic - pyinstrument >=5.0.1 - python-json-logger >=2.0.7 - pyzmq >=26.3.0 @@ -3655,42 +3703,46 @@ packages: - scipy >=1.13.0 - sse-starlette >=2.1.2 - starlette >=0.47.2 - - taskgroup >=0.2.2 - tokenizers >=0.19.0 license: LicenseRef-Modular-Proprietary - size: 9224286 - timestamp: 1759901122330 -- conda: https://conda.modular.com/max-nightly/linux-64/max-core-25.7.0.dev2025100805-release.conda - sha256: aadaccc24f36ee82a36733c3575e5ea2c89d3c4370465828e9d89e84508cf559 + size: 8402143 + timestamp: 1765431174146 +- conda: https://conda.modular.com/max-nightly/linux-64/max-core-26.1.0.dev2025121105-release.conda + sha256: 3cccb6c921636813014b7f9d859e08dcefd9bb63359a9f7f42c691dcc11d6ebd depends: - - mojo-compiler ==0.25.7.0.dev2025100805 release + - mojo-compiler ==0.26.1.0.dev2025121105 release license: LicenseRef-Modular-Proprietary - size: 77595859 - timestamp: 1759900825813 -- conda: 
https://conda.modular.com/max-nightly/osx-arm64/max-core-25.7.0.dev2025100805-release.conda - sha256: 1cf47887f83373b9953fc3e7581fee841a5169896fb1ca991f049cf962a077be + size: 119835262 + timestamp: 1765430656228 +- conda: https://conda.modular.com/max-nightly/osx-arm64/max-core-26.1.0.dev2025121105-release.conda + sha256: a787c25efbf9b41956d6ca062f1d5238b19c97c9b1dd8989ba2822d3393de559 depends: - - mojo-compiler ==0.25.7.0.dev2025100805 release + - mojo-compiler ==0.26.1.0.dev2025121105 release license: LicenseRef-Modular-Proprietary - size: 72139580 - timestamp: 1759901122330 -- conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-25.7.0.dev2025100805-release.conda + size: 80084970 + timestamp: 1765431174145 +- conda: https://conda.modular.com/max-nightly/noarch/max-pipelines-26.1.0.dev2025121105-release.conda noarch: python - sha256: 371bd0489d6822f3e497ff75b4333c27a71a2c54f1df6541923905139026a5a1 + sha256: 1533efb14826ed804da614da12861ca3b43b17d07c973f95ad37c5d5cac972a9 depends: - click >=8.0.0 + - cyclopts >=4.2.5 - gguf >=0.17.1 - hf-transfer >=0.1.9 - huggingface_hub >=0.28.0 - jinja2 >=3.1.6 - - llguidance >=1.0.1 + - llguidance >=0.7.30 - pillow >=11.0.0 - psutil >=6.1.1 + - pydantic-settings >=2.7.1 + - pydantic - requests >=2.32.3 - rich >=13.0.1 - sentencepiece >=0.2.0 + - taskgroup >=0.2.2 + - tomli >=2.0.0 - tqdm >=4.67.1 - - transformers >=4.55.0 + - transformers >=4.57.0,<5.0.0 - uvicorn >=0.34.0 - uvloop >=0.21.0 - aiofiles >=24.1.0 @@ -3705,8 +3757,6 @@ packages: - opentelemetry-sdk >=1.29.0,<1.36.0 - prometheus_client >=0.21.0 - protobuf >=6.31.1,<6.32.0 - - pydantic-settings >=2.7.1 - - pydantic - pyinstrument >=5.0.1 - python-json-logger >=2.0.7 - pyzmq >=26.3.0 @@ -3714,15 +3764,14 @@ packages: - scipy >=1.13.0 - sse-starlette >=2.1.2 - starlette >=0.47.2 - - taskgroup >=0.2.2 - tokenizers >=0.19.0 - - max >=25.7.0.dev2025100805,<26.0a0 + - max >=26.1.0.dev2025121105,<27.0a0 license: LicenseRef-Modular-Proprietary - size: 9976 - 
timestamp: 1759900825813 -- conda: https://conda.modular.com/max-nightly/noarch/mblack-25.7.0.dev2025100805-release.conda + size: 16853 + timestamp: 1765430656228 +- conda: https://conda.modular.com/max-nightly/noarch/mblack-26.1.0.dev2025121105-release.conda noarch: python - sha256: c1100f036b452171b8f8e28032b7aca60a0547da1e27317d21a6498c0ddee76a + sha256: 14f7095b631e39fa486ccf42de4098ebc4d898016baad04e24ec335af39cd101 depends: - python >=3.10 - click >=8.0.0 @@ -3734,8 +3783,8 @@ packages: - typing_extensions >=v4.12.2 - python license: MIT - size: 131951 - timestamp: 1759900825814 + size: 138237 + timestamp: 1765430656228 - conda: https://conda.anaconda.org/conda-forge/linux-64/mdbook-0.4.52-hb17b654_0.conda sha256: 324d8270a6bc20f6b6ece1d05507a1415585dc84b1f980671328df29aad2c2ff md5: 50207b127e50bc000a4243355a88980a @@ -3780,57 +3829,57 @@ packages: license_family: Proprietary size: 124988693 timestamp: 1753975818422 -- conda: https://conda.modular.com/max-nightly/noarch/modular-25.7.0.dev2025100805-release.conda +- conda: https://conda.modular.com/max-nightly/noarch/modular-26.1.0.dev2025121105-release.conda noarch: python - sha256: 2f371369d6b21afa31029625a856bf76c0a9501360cc00819aac925b76d3880f + sha256: d11bbc6fdb58cdcf64515a1a889263ee57e709e3c29fe8dc03d63066b1a2e2c2 depends: - - max-pipelines ==25.7.0.dev2025100805 release - - mojo ==0.25.7.0.dev2025100805 release + - max-pipelines ==26.1.0.dev2025121105 release + - mojo ==0.26.1.0.dev2025121105 release license: LicenseRef-Modular-Proprietary - size: 9441 - timestamp: 1759900825814 -- conda: https://conda.modular.com/max-nightly/linux-64/mojo-0.25.7.0.dev2025100805-release.conda - sha256: e479169432a90b74b1719ba93824b86d7e1343257ff1502f0e2c77583511e16e + size: 16284 + timestamp: 1765430656228 +- conda: https://conda.modular.com/max-nightly/linux-64/mojo-0.26.1.0.dev2025121105-release.conda + sha256: 1ad48bf92c77edb3261de1758fa380b519e60e69c066a127b91e2934cb4112f4 depends: - python >=3.10 - - mojo-compiler 
==0.25.7.0.dev2025100805 release - - mblack ==25.7.0.dev2025100805 release + - mojo-compiler ==0.26.1.0.dev2025121105 release + - mblack ==26.1.0.dev2025121105 release - jupyter_client >=8.6.2,<8.7 license: LicenseRef-Modular-Proprietary - size: 91228895 - timestamp: 1759900825814 -- conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-0.25.7.0.dev2025100805-release.conda - sha256: 4a0f7af1b50cb67997e8b5ff5b36b544dfacfa2ea5ca950b6c4c05af6ca3c4e2 + size: 87665164 + timestamp: 1765430656228 +- conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-0.26.1.0.dev2025121105-release.conda + sha256: e4115b3f995d447fa6afe211bda0845a564d51aabe8399fba6293f34b58dd0e0 depends: - python >=3.10 - - mojo-compiler ==0.25.7.0.dev2025100805 release - - mblack ==25.7.0.dev2025100805 release + - mojo-compiler ==0.26.1.0.dev2025121105 release + - mblack ==26.1.0.dev2025121105 release - jupyter_client >=8.6.2,<8.7 license: LicenseRef-Modular-Proprietary - size: 77059509 - timestamp: 1759901122330 -- conda: https://conda.modular.com/max-nightly/linux-64/mojo-compiler-0.25.7.0.dev2025100805-release.conda - sha256: 2432c8589946bd06a1b32bd96efcd51886928a02c41e4f21bca98759e05271ad + size: 74133806 + timestamp: 1765431174145 +- conda: https://conda.modular.com/max-nightly/linux-64/mojo-compiler-0.26.1.0.dev2025121105-release.conda + sha256: d609022164aa88e415e7413c49180161b24c06091313f876087e04dff0d45ef6 depends: - - mojo-python ==0.25.7.0.dev2025100805 release + - mojo-python ==0.26.1.0.dev2025121105 release license: LicenseRef-Modular-Proprietary - size: 87931683 - timestamp: 1759900825813 -- conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-compiler-0.25.7.0.dev2025100805-release.conda - sha256: c19e4b2a53ef60670d994ae0a901577760e0f0aadf4615b165b1e0f392220f2c + size: 84195881 + timestamp: 1765430656228 +- conda: https://conda.modular.com/max-nightly/osx-arm64/mojo-compiler-0.26.1.0.dev2025121105-release.conda + sha256: 
e9b2364c42bdc1b1cffc6c1441e2feb97acc7acc16983ab137338a47a7a24be1 depends: - - mojo-python ==0.25.7.0.dev2025100805 release + - mojo-python ==0.26.1.0.dev2025121105 release license: LicenseRef-Modular-Proprietary - size: 62364223 - timestamp: 1759901122330 -- conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.25.7.0.dev2025100805-release.conda + size: 64840514 + timestamp: 1765431174145 +- conda: https://conda.modular.com/max-nightly/noarch/mojo-python-0.26.1.0.dev2025121105-release.conda noarch: python - sha256: 4c4ed4977272dc997cca836559f5b3c8ba0a33c382ff45e171ecd82e465bf781 + sha256: abad0df247ffe68e214b5a4a0aab729b5f80c05582a8397894ac28db71f434d1 depends: - python license: LicenseRef-Modular-Proprietary - size: 17949 - timestamp: 1759900825813 + size: 24243 + timestamp: 1765430656227 - conda: https://conda.anaconda.org/conda-forge/linux-64/mpc-1.3.1-h24ddda3_1.conda sha256: 1bf794ddf2c8b3a3e14ae182577c624fa92dea975537accff4bc7e5fea085212 md5: aa14b9a5196a6d8dd364164b7ce56acf @@ -5155,6 +5204,17 @@ packages: license_family: MIT size: 200840 timestamp: 1760026188268 +- conda: https://conda.anaconda.org/conda-forge/noarch/rich-rst-1.3.2-pyhd8ed1ab_0.conda + sha256: 202e90d6624abc924e185166f6fcfdd29c6749ec26d60480a0a34c898c0b67fd + md5: cbd84dbdb3f5a7d762b5fb2b0d49e7cd + depends: + - docutils + - python >=3.10 + - rich >=12.0.0 + license: MIT + license_family: MIT + size: 18299 + timestamp: 1760519277784 - conda: https://conda.anaconda.org/conda-forge/noarch/rich-toolkit-0.15.1-pyhcf101f3_0.conda sha256: 7c8ffaa40bf4ba5fc6bb8f0e4b9da77678fe74cdb50ab82041d6a5e4a25f530b md5: 12f69ed6e4115871451a3c7809b4651e diff --git a/pixi.toml b/pixi.toml index d15e1be..d866b94 100644 --- a/pixi.toml +++ b/pixi.toml @@ -31,9 +31,9 @@ test-all = "python tests/test.step_01.py && python tests/test.step_02.py && pyth fmt = "ruff format ." 
 [dependencies]
-modular = "25.7.*"
+modular = ">=26.1.0.dev2025121105,<27"
 transformers = ">=4.57.0,<5"
 pytorch = ">=2.8.0,<3"
 numpy = ">=2.3.3,<3"
 mdbook = ">=0.4.0"
-ruff = ">=0.14.1,<0.15"
\ No newline at end of file
+ruff = ">=0.14.1,<0.15"
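The two transformations in `weight_adapters.py` that are easiest to get wrong (Conv1D-to-Linear transposition and tied embeddings) can be exercised in isolation with plain numpy. The sketch below uses synthetic zero arrays with GPT-2 small shapes (n_embd=768) and skips the HuggingFace-to-MAX name remapping; the dictionary keys are illustrative only:

```python
import numpy as np

# Synthetic stand-ins for two GPT-2 small weights.
# HF's Conv1D stores c_fc as [in_features, out_features] = [768, 3072].
state_dict = {
    "h.0.mlp.c_fc.weight": np.zeros((768, 3072), dtype=np.float32),
    "wte.weight": np.zeros((50257, 768), dtype=np.float32),
}

TRANSPOSE_WEIGHTS = [".c_attn.weight", ".c_proj.weight", ".c_fc.weight"]

converted = {}
for name, arr in state_dict.items():
    if any(pat in name for pat in TRANSPOSE_WEIGHTS):
        # Conv1D [in, out] -> Linear [out, in]. Note that .T alone
        # returns a non-contiguous view, so force a contiguous copy.
        arr = np.ascontiguousarray(arr.T)
    else:
        arr = np.ascontiguousarray(arr)
    converted[name] = arr

# Tied embeddings: GPT-2 checkpoints omit lm_head.weight, so reuse wte.
if "lm_head.weight" not in converted:
    converted["lm_head.weight"] = converted["wte.weight"].copy()

assert converted["h.0.mlp.c_fc.weight"].shape == (3072, 768)
assert converted["h.0.mlp.c_fc.weight"].flags["C_CONTIGUOUS"]
# The bare transpose view would NOT have been C-contiguous:
assert not state_dict["h.0.mlp.c_fc.weight"].T.flags["C_CONTIGUOUS"]
assert converted["lm_head.weight"].shape == (50257, 768)
```

Skipping the `ascontiguousarray` step is exactly the kind of silent shape/stride mismatch that produces gibberish output rather than a hard error, which is why the adapter forces contiguity on every array.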