From 7b3ae19b5e43a63f56cf960b16248954a8bed100 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 9 Apr 2025 16:38:37 -0400 Subject: [PATCH 1/7] add qwen2 support --- contributed/models/qwen2/__init__.py | 0 contributed/models/qwen2/modeling_qwen.py | 1006 +++++++++++++++++++++ 2 files changed, 1006 insertions(+) create mode 100644 contributed/models/qwen2/__init__.py create mode 100644 contributed/models/qwen2/modeling_qwen.py diff --git a/contributed/models/qwen2/__init__.py b/contributed/models/qwen2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/contributed/models/qwen2/modeling_qwen.py b/contributed/models/qwen2/modeling_qwen.py new file mode 100644 index 0000000..705db6e --- /dev/null +++ b/contributed/models/qwen2/modeling_qwen.py @@ -0,0 +1,1006 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model for NXD inference.""" +import copy +import gc +import logging +import math +from typing import List, Optional, Tuple, Type + +import torch +from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 +from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from neuronx_distributed.parallel_layers.utils import get_padding_length +from neuronx_distributed.quantization.quantization_config import QuantizationType, QuantizedDtype +from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + QuantizedColumnParallel, + QuantizedRowParallel, +) +from neuronx_distributed_inference.modules.attention.gqa import GroupQueryAttention_O +from neuronxcc.nki._private_kernels.mlp import ( + mlp_fused_add_isa_kernel, + mlp_isa_kernel, + quant_mlp_fused_add_isa_kernel, + quant_mlp_isa_kernel, +) +from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel +from neuronxcc.starfish.penguin.targets.nki.private_api import vnc +from torch import nn +from torch_neuronx.xla_impl.ops import nki_jit +from transformers import Qwen2ForCausalLM +from transformers.activations import ACT2FN +from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 +from neuronx_distributed_inference.models.model_base import ( # noqa: E402 + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 + BaseGroupQueryAttention, +) +from neuronx_distributed_inference.modules.attention.utils import ( + RotaryEmbedding, + preprocess_quantized_linear_layer, + transpose_parallel_linear_layer, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.flashdecode.utils import calculate_num_cores_per_group +from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module +from neuronx_distributed_inference.utils.distributed import get_tp_group + +_Qwen2_MODULE_MAP = {} + +logger = logging.getLogger("Neuron") + + +def get_rmsnorm_cls(): + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return CustomRMSNorm if parallel_state.model_parallel_is_initialized() else LlamaRMSNorm + + +def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool: + if isinstance(module, (BaseGroupQueryAttention,)): + return module.preshard_hook(model_state_dict, prefix) + + return False + + +def _register_module(key: str, cls: Type[nn.Module]): + _Qwen2_MODULE_MAP[key] = cls + + +def register_module(key: str): + """ + Register a module for use in NeuronQwen2. + + Arguments: + key: String used to identify the module + + Example: + @register_module("NeuronQwen2Attention") + class NeuronQwen2Attention(nn.Module): + ... + """ + + def inner(cls: Type[nn.Module]): + _register_module(key, cls) + return cls + + return inner + + +def convert_state_dict_to_fused_qkv(Qwen2_state_dict, cfg: InferenceConfig): + """ + This function concats the qkv weights to a Wqkv weight for fusedqkv, and deletes the qkv weights. + """ + for l in range(cfg.num_hidden_layers): # noqa: E741 + Qwen2_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + Qwen2_state_dict[f"layers.{l}.self_attn.q_proj.weight"], + Qwen2_state_dict[f"layers.{l}.self_attn.k_proj.weight"], + Qwen2_state_dict[f"layers.{l}.self_attn.v_proj.weight"], + ], + ) + del Qwen2_state_dict[f"layers.{l}.self_attn.q_proj.weight"] + del Qwen2_state_dict[f"layers.{l}.self_attn.k_proj.weight"] + del Qwen2_state_dict[f"layers.{l}.self_attn.v_proj.weight"] + + gc.collect() + + return Qwen2_state_dict + + +class Qwen2InferenceConfig(InferenceConfig): + def add_derived_config(self): + self.num_cores_per_group = 1 + if self.neuron_config.flash_decoding_enabled: + num_attn_heads, num_kv_heads = self.num_attention_heads, self.num_key_value_heads + self.num_cores_per_group = calculate_num_cores_per_group( + num_attn_heads, num_kv_heads, self.neuron_config.tp_degree + ) + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen2MLP(nn.Module): + """ + This class just replace the linear layers (gate_proj, up_proj and down_proj) with column and row parallel layers + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.neuron_config = config.neuron_config + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + self.sequence_parallel_enabled = getattr( + self.neuron_config, "sequence_parallel_enabled", False + ) + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + self.rms_norm_eps = config.rms_norm_eps + self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled + self.quantized_kernel_lower_bound = self.neuron_config.quantized_kernel_lower_bound + self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + mlp_bias = getattr(config, "mlp_bias", False) + if parallel_state.model_parallel_is_initialized(): + if self.quantized_mlp_kernel_enabled: + # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad + tp_degree = self.neuron_config.tp_degree + self.intermediate_size += ( + get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree + ) + logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") + + quantization_type = QuantizationType(self.neuron_config.quantization_type) + quantized_dtype = QuantizedDtype.F8E4M3 + self.gate_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = QuantizedRowParallel( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=mlp_bias, + quantization_type=quantization_type, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + sequence_parallel_enabled=False, + quantization_per_channel_axis=0, + tensor_model_parallel_group=get_tp_group(config), + ) + + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) + + if self.mlp_kernel_enabled: + if self.quantized_mlp_kernel_enabled: + preprocess_quantized_linear_layer(self.gate_proj) + preprocess_quantized_linear_layer(self.up_proj) + preprocess_quantized_linear_layer(self.down_proj) + + else: + # Transpose the weights to the layout expected by kernels + self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight) + self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight) + self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight) + + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) + + def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + grid = (vnc(self.logical_neuron_cores),) + fused_residual = residual is not None + logger.debug( + f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Can't do residual add in the kernel if SP is enabled + if fused_residual: + assert ( + not self.sequence_parallel_enabled + ), "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" + # Using fused residual add + _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) + + # Handle SP RMSnorm + x_orig_dtype = x.dtype + if self.sequence_parallel_enabled: + # This RMSNormQuant kernel will do quantization inside, so we pass the + # lower_bound for clipping. + # If we don't use this kernel, the MLP kernel below will do the + # quantization, so we also pass lower_bound to that kernel. + if self.rmsnorm_quantize_kernel_enabled: + logger.debug( + "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" + ) + _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel) + quant_rmsnorm_out = torch.zeros( + size=( + x.shape[0], # batch size + x.shape[1], # sequence length + x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale + ), + dtype=torch.int8, + device=x.device, + ) + ln_w = rmsnorm.weight.unsqueeze(0) + lower_bound = self.quantized_kernel_lower_bound + _rmsnorm_quant_fwd_call[grid]( + x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + ) + x = gather_from_sequence_parallel_region( + quant_rmsnorm_out, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + + else: + logger.debug( + "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!" + ) + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x_orig_dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + gate_w_scale = self.gate_proj.weight_scale + up_w = self.up_proj.weight.data + up_w_scale = self.up_proj.weight_scale + down_w = self.down_proj.weight.data + down_w_scale = self.down_proj.weight_scale + lower_bound = self.quantized_kernel_lower_bound + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) + + logger.debug(f"Quantized MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + fused_residual = residual is not None + logger.debug( + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Choose which kernel to call + if fused_residual: + assert ( + not self.sequence_parallel_enabled + ), "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" + # Using fused residual add + _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(mlp_isa_kernel) + + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x.dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + up_w = self.up_proj.weight.data + down_w = self.down_proj.weight.data + + grid = (vnc(self.logical_neuron_cores),) + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + up_w, # up_w + down_w, # down_w + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, + up_w, + down_w, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region( + output_tensor, process_group=get_tp_group(self.config) + ) + + logger.debug(f"MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _native_mlp(self, x, rmsnorm, adapter_ids=None): + logger.debug("MLP: native compiler") + # all-gather is done here instead of CPL layers to + # avoid 2 all-gathers from up and gate projections + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + gate_proj_output = ( + self.gate_proj(x) + if not is_lora_module(self.gate_proj) + else self.gate_proj(x, adapter_ids) + ) + up_proj_output = ( + self.up_proj(x) if not is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids) + ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output + output = ( + self.down_proj(down_proj_input) + if not is_lora_module(self.up_proj) + else self.down_proj(down_proj_input, adapter_ids) + ) + logger.debug(f"MLP output shape {output.shape}") + return output + + def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): + """ + If residual is passed in, will fuse its add into the MLP kernel + + Returns a tuple of (output, residual), where residual is the output of the residual add + """ + if self.mlp_kernel_enabled: + fused_rmsnorm = not self.sequence_parallel_enabled + # Quantized MLP kernel + if self.quantized_mlp_kernel_enabled: + return self._kernel_enabled_quantized_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + # MLP kernel + return self._kernel_enabled_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + else: + # No kernel + return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) + + +@register_module("NeuronQwen2Attention") +class NeuronQwen2Attention(NeuronAttentionBase): + """ + Compared with Qwen2Attention, this class just + 1. replaces the q_proj, k_proj, v_proj with column parallel layer + 2. replaces the o_proj with row parallel layer + 3. update self.num_head to be self.num_head / tp_degree + 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree + 5. update forward() method to adjust to changes from self.num_head + """ + + def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): + super().__init__(tensor_model_parallel_group=tensor_model_parallel_group) + + self.config = config + self.neuron_config = config.neuron_config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.padding_side = config.neuron_config.padding_side + self.torch_dtype = config.neuron_config.torch_dtype + self.is_medusa = config.neuron_config.is_medusa + self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled + self.num_cores_per_group = config.num_cores_per_group + self.bias = getattr(config, "attention_bias", True) + self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rms_norm_eps = config.rms_norm_eps + + if parallel_state.model_parallel_is_initialized(): + self.tp_degree = self.config.neuron_config.tp_degree + else: + self.tp_degree = 1 + + self.fused_qkv = config.neuron_config.fused_qkv + self.clip_qkv = None + + self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + logger.debug( + f"Hello from NeuronQwen2Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? {self.sequence_dimension}" + ) + + self.init_gqa_properties() + + self.init_rope() + + self.o_proj = GroupQueryAttention_O( + hidden_size=self.hidden_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + tp_degree=self.tp_degree, + dtype=self.torch_dtype, + bias=False, + input_is_parallel=True, + layer_name=self.o_proj_layer_name, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=self.tensor_model_parallel_group, + rpl_reduce_dtype=self.rpl_reduce_dtype, + ) + + def init_rope(self): + if not hasattr(self.config, "rope_scaling") or self.config.rope_scaling is None: + # TODO(yihsian): Check if we can just use our own implementation + if self.is_medusa: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + self.rotary_emb = RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + rope_type = self.config.rope_scaling.get( + "rope_type", self.config.rope_scaling.get("type", None) + ) + if rope_type == "Qwen2": + self.rotary_emb = Qwen2RotaryEmbedding( + dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + factor=self.config.rope_scaling["factor"], + low_freq_factor=self.config.rope_scaling["low_freq_factor"], + high_freq_factor=self.config.rope_scaling["high_freq_factor"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], + ) + else: + # Qwen2RotaryEmbedding automatically chooses the correct scaling type from config. + # Warning: The HF implementation may have precision issues when run on Neuron. + # We include it here for compatibility with other scaling types. + self.rotary_emb = LlamaRotaryEmbedding(self.config) + + +# TODO: Modularize RotaryEmbedding. See how HF transformers does it in 4.43. +class Qwen2RotaryEmbedding(nn.Module): + """ + Adapted from Qwen2 4.43 impl + * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/Qwen2/modeling_Qwen2.py#L78 + * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345 + + This implementation ensures inv_freq is calculated and stored in fp32. + """ + + def __init__( + self, + dim, + max_position_embeddings=131072, + base=500000.0, + factor=8.0, + low_freq_factor=1.0, + high_freq_factor=4.0, + original_max_position_embeddings=8192, + ): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.factor = factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.old_context_len = original_max_position_embeddings + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + + low_freq_wavelen = self.old_context_len / self.low_freq_factor + high_freq_wavelen = self.old_context_len / self.high_freq_factor + new_freqs = [] + for freq in inv_freq: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / self.factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (self.old_context_len / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) + self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) + + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + with torch.autocast(device_type=x.device.type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class NeuronQwen2DecoderLayer(nn.Module): + """ + Just replace the attention with the NXD version, and MLP with the NXD version + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = _Qwen2_MODULE_MAP[config.neuron_config.attn_cls]( + config=config, tensor_model_parallel_group=get_tp_group(config) + ) + self.mlp = NeuronQwen2MLP(config) + logger.debug( + f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" + ) + self.input_layernorm = None + if ( + not config.neuron_config.is_eagle_draft + or config.neuron_config.enable_eagle_draft_input_norm + ): + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled + self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + # RMSNorm (fused with QKV kernel when SP is disabled) + if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + rmsnorm=self.input_layernorm, + **kwargs, + ) + + if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: + assert ( + not self.sequence_parallel_enabled + ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + # First residual add handled in the MLP kernel + hidden_states, residual = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + residual=residual, + adapter_ids=adapter_ids, + ) + else: + hidden_states = residual + hidden_states + residual = hidden_states + # RMSNorm (fused with QKV kernel when SP is disabled) + if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + adapter_ids=adapter_ids, + ) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache) + return outputs + + +class ResBlock(nn.Module): + """ + A Residual Block module. + + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Qwen2 model + self.act = nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output after the residual connection and activation. + """ + return x + self.act(self.linear(x)) + + +class NeuronQwen2Model(NeuronBaseModel): + """ + The neuron version of the Qwen2Model + """ + + def setup_attr_for_model(self, config: InferenceConfig): + # Needed for init_inference_optimization() + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=not config.neuron_config.vocab_parallel, + sequence_parallel_enabled=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + use_spmd_rank=config.neuron_config.vocab_parallel, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + ) + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + ) + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + ) + + # In the target fp8 checkpoint, the 1st and last + # layers are not using fp8. + updated_configs = [] + for i in range(config.num_hidden_layers): + # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block + if i == 0 or i == config.num_hidden_layers - 1: + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + self.layers = nn.ModuleList([NeuronQwen2DecoderLayer(conf) for conf in updated_configs]) + if not config.neuron_config.is_eagle_draft: + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if config.neuron_config.is_eagle_draft: + fc_bias = getattr(config, "fc_bias", False) + self.fc = ColumnParallelLinear( + config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True + ) + self.is_medusa = config.neuron_config.is_medusa + self.num_medusa_heads = config.neuron_config.num_medusa_heads + self.medusa_speculation_length = config.neuron_config.medusa_speculation_length + + if self.is_medusa: + if parallel_state.model_parallel_is_initialized(): + medusa_head_cls = ColumnParallelLinear + else: + medusa_head_cls = nn.Linear + for i in range(self.num_medusa_heads): + medusa_head = nn.Sequential( + *([ResBlock(config.hidden_size)] * 1), + medusa_head_cls( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ), + ) + setattr(self, f"medusa_head_{i}", medusa_head) + + +class NeuronQwen2ForCausalLM(NeuronBaseForCausalLM): + """ + This class extends Qwen2ForCausalLM create traceable + blocks for Neuron. + + Args: + Qwen2ForCausalLM (_type_): _description_ + """ + + _model_cls = NeuronQwen2Model + + @staticmethod + def load_hf_model(model_path): + return Qwen2ForCausalLM.from_pretrained(model_path) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: + """This function should be over-ridden in child classes as needed""" + neuron_config = config.neuron_config + if neuron_config.fused_qkv: + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) + + if neuron_config.vocab_parallel: + # TODO: this hack can be removed after replication_id is ready to use + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + # to facilitate rank usage in attention + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + # to facilitate rank usage in base model + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @classmethod + def get_config_cls(cls): + return Qwen2InferenceConfig \ No newline at end of file From c6b43cfe0e8be74b162e23d220e12f817c157283 Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 16 Apr 2025 17:31:06 -0400 Subject: [PATCH 2/7] update qwen file and add test nb --- .DS_Store | Bin 0 -> 6148 bytes contributed/.DS_Store | Bin 0 -> 6148 bytes contributed/models/.DS_Store | Bin 0 -> 6148 bytes contributed/models/qwen2/__init__.py | 0 contributed/models/qwen2/modeling_qwen.py | 1 + contributed/models/qwen2/qwen-test.ipynb | 351 ++++++++++++++++++++++ 6 files changed, 352 insertions(+) create mode 100644 .DS_Store create mode 100644 contributed/.DS_Store create mode 100644 contributed/models/.DS_Store delete mode 100644 contributed/models/qwen2/__init__.py create mode 100644 contributed/models/qwen2/qwen-test.ipynb diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2a25ad40e3b4764ef0022a094a3a33461292726e GIT binary patch literal 6148 zcmeHKJ5EC}5S)b+kMWP2?XrftZJ@$IX zmZy087JzL}!vnAeu%tWU%ZIu7zWdDXDq=)B&v?fx_Bi1aFSG351I}IHfDvzGo%4s` z<8a)MBL$?u?I_^ihemhog;QdDIygiNKwK~! z#(DG-#O48FFPst?p;=OiNwsP*Ea{B5%Ik$wV$xyNd|2IV)uCA2&huNO!+N4dDIf(d z6}Zpk+Ux%#{g3|tlBAUskODWQfGu`UyDgtowRQG5ueFW-O!u5`x*O*~;SlAR80DA? fFUOlm%Dm=t?)SneG3bm3ov5Dy*F`1;{#$_`PNfza literal 0 HcmV?d00001 diff --git a/contributed/.DS_Store b/contributed/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d0da439dffcc3253b60c3efbd08401ed8b1d1bf9 GIT binary patch literal 6148 zcmeHKJ8r`;3?&<*FpwoLQn+p#{cTYW6kJiV}5X0@}4hPZxz z1iisDMMS%g{#E24A}!odK5bZ=?VGRcDI*Gm9!oUT*h% z+n`YaDnJFO02QDDn=6ndw%u=jG7qEzRN%iWVBd!VH>`Hg$K#o*^3M>`a#qwbF z{{;V~|6h{0q5@Rlp%l>Ocs_P`Qr6bhP1)427QyNR`;IWSLD5Ku-`MH~|+x{OnLsRaM>Q=sY{#FtDyh=vi`J?8Ni) zEhb|Cw*Eb>fCYeex+^|>7??4h;TwNA;|rgspX23ly\\nOkay, so I need to figure out how many times the letter \\'r\\' appears in the word \"strawberry.\" Let me start by writing down the word and looking at each letter one by one. \\n\\nFirst, I\\'ll spell out \"strawberry\" to make sure I have all the letters right. S-T-R-A-W-B-E-R-R-Y. Wait, let me check that again. Sometimes I might miss a letter. Let me count the letters as I write them:\\n\\n1. S\\n2. T\\n3. R\\n4. A\\n5. W\\n6. B\\n7. E\\n8. R\\n9. R\\n10. Y\\n\\nHmm, so that\\'s 10 letters in total. Now, I need to count how many times \\'R\\' shows up. Let me go through each letter again and note the positions where \\'R\\' is.\\n\\nStarting from the first letter:\\n1. S – not an R\\n2. T – not an R\\n3. R – that\\'s the first R\\n4. A – no\\n5. W – no\\n6. B – no\\n7. E – no\\n8. R – second R\\n9. R – third R\\n10. Y – no\\n\\nWait a second, so after the first R at position 3, the next R is at position 8, and then another at 9? Let me confirm the spelling again because sometimes people might confuse \"strawberry\" with other similar words. Let me think: S-T-R-A-W-B-E-R-R-Y. Yes, that\\'s correct. After the \\'E\\', there are two R\\'s in a row, right? So positions 8 and 9 are both R\\'s. So that would make three R\\'s in total: one at position 3, and two at 8 and 9. \\n\\nBut hold on, maybe I miscounted the letters. Let me write them out again with numbers to be sure:\\n\\n1. S\\n2. T\\n3. R\\n4. A\\n5. W\\n6. B\\n7. E\\n8. R\\n9. R\\n10. Y\\n\\nYes, that\\'s correct. So the letters R are at positions 3, 8, and 9. That\\'s three R\\'s. Wait, but sometimes when I say \"strawberry,\" I might not pronounce the second R as clearly, but spelling-wise, it\\'s definitely there. Let me check another way. Maybe breaking the word into parts. \"Straw\" and \"berry.\" In \"straw,\" there\\'s an R. Then in \"berry,\" which is B-E-R-R-Y. So \"berry\" has two R\\'s. So adding the one from \"straw,\" that\\'s three total. \\n\\nAlternatively, maybe I can think of the word as S-T-R-A-W-B-E-R-R-Y. So breaking it down:\\n\\n- S T R (so first R)\\n- A W B E (no R\\'s here)\\n- R R Y (two more R\\'s)\\n\\nSo that\\'s 1 + 2 = 3 R\\'s. \\n\\nI think that\\'s right, but I want to be absolutely sure. Let me try another approach. Let me write the word and circle each R:\\n\\nS T R A W B E R R Y\\n\\nCircling the R\\'s: the third letter is R, then the eighth and ninth letters are R\\'s. So three in total. \\n\\nAlternatively, maybe I can use a different method. Let me count the letters one by one and tally the R\\'s:\\n\\nStarting with S: 0\\nT: 0\\nR: 1\\nA: 1\\nW:1\\nB:1\\nE:1\\nR: 2\\nR:3\\nY:3\\n\\nWait, no, that\\'s not the right way. Each time I see an R, I should increment the count. Let me try again:\\n\\n1. S – count remains 0\\n2. T – 0\\n3. R – count becomes 1\\n4. A – 1\\n5. W –1\\n6. B –1\\n7. E –1\\n8. R – count becomes 2\\n9. R – count becomes 3\\n10. Y –3\\n\\nYes, so the final count is 3. \\n\\nI think I might have confused myself earlier when I thought maybe two, but upon multiple checks, it\\'s three. Let me see if any sources or examples say otherwise. Wait, maybe I should just confirm by looking up the spelling of \"strawberry\" again. \\n\\nLooking it up in my mind: S-T-R-A-W-B-E-R-R-Y. Yes, that\\'s correct. The standard spelling has three R\\'s. So the answer should be three. \\n\\nAlternatively, maybe I can think of the pronunciation. In some accents, the double R in \"berry\" might be pronounced as a single sound, but that doesn\\'t change the spelling. The question is about the written word, so the letters are what matter, not the pronunciation. \\n\\nTherefore, after carefully going through each letter and multiple methods of counting, I can confidently say there are three R\\'s in \"strawberry.\"\\n\\n\\nThe word \"strawberry\" contains three instances of the letter \\'r\\'. Here\\'s the breakdown:\\n\\n1. **S** \\n2. **T** \\n3. **R** (1st \\'r\\') \\n4. **A** \\n5. **W** \\n6. **B** \\n7. **E** \\n8. **R** (2nd \\'r\\') \\n9. **R** (3rd \\'r\\') \\n10. **Y** \\n\\n**Answer:** There are **3 r\\'s** in the word \"strawberry.\"'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display\n", + "display(\"Generated outputs:\")\n", + "for i, output_token in enumerate(output_tokens):\n", + " display(f\"Output {i}: {output_token}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "model.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "del model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Test Token Output" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "dir = '/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", + "!cp modeling_qwen.py {dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cp {dir}/inference_demo.py ." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Add the following to the inference_demo.py we just copied to our working directory\n", + "\n", + "```\n", + "from .modeling_qwen import NeuronQwen2ForCausalLM\n", + "\n", + "MODEL_TYPES = {\n", + " \"llama\": {\"causal-lm\": NeuronLlamaForCausalLM},\n", + " \"mixtral\": {\"causal-lm\": NeuronMixtralForCausalLM},\n", + " \"dbrx\": {\"causal-lm\": NeuronDbrxForCausalLM},\n", + " \"qwen\": {'causal-lm': NeuronQwen2ForCausalLM} #add this line\n", + "}\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cp ./inference_demo.py {dir}/inference_demo.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Restart your kernel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!inference_demo \\\n", + " --model-type qwen \\\n", + " --task-type causal-lm \\\n", + " run \\\n", + " --model-path /home/ubuntu/model_hf_qwq/qwq/ \\\n", + " --compiled-model-path /home/ubuntu/traced_model_qwq/qwq/ \\\n", + " --torch-dtype bfloat16 \\\n", + " --tp-degree 8 \\\n", + " --batch-size 1 \\\n", + " --max-context-length 32 \\\n", + " --seq-len 64 \\\n", + " --on-device-sampling \\\n", + " --enable-bucketing \\\n", + " --top-k 1 \\\n", + " --do-sample \\\n", + " --pad-token-id 32000 \\\n", + " --prompt \"To be, or not to be\" \\\n", + " --check-accuracy-mode token-matching \\\n", + " --benchmark" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aws_neuronx_venv_pytorch_2_5_nxd_inference", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 0eaad5c6c2b0abab5a844b23cb5f2b49cf670afb Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 13:07:28 -0400 Subject: [PATCH 3/7] add qwen3 --- .DS_Store | Bin 6148 -> 6148 bytes contributed/.DS_Store | Bin 6148 -> 6148 bytes contributed/models/qwen2/modeling_qwen.py | 1007 --------------------- contributed/models/qwen2/qwen-test.ipynb | 351 ------- 4 files changed, 1358 deletions(-) delete mode 100644 contributed/models/qwen2/modeling_qwen.py delete mode 100644 contributed/models/qwen2/qwen-test.ipynb diff --git a/.DS_Store b/.DS_Store index 2a25ad40e3b4764ef0022a094a3a33461292726e..45c0ee21b183e0a5b2f8bc84b377fd3232abf0b8 100644 GIT binary patch literal 6148 zcmeHKPiqrF6o1pM&Bh=GQRqcj@EU4ri&}b#`J)sxg>KY?O5Dwcx^%l!vKw;X_ z_!YeRN&GIJ^!H|_kZcnW4dAnQzq#WXC2$Q$3 z<1~_!rW~hHrhI)p;FjHTzq&o0?jAg@d3%rcW;Jhm@StAv_8;!gW@Y!zz59nJ{j=me zm7f%IB(Og!xnb}UK0{fCJzu~2NHo}wrt0|NsP3otMgF=R3%F_bctFr-c_RG++#Rhl!UI5{UN xKR<_Yvml2U%f<#L=FRLJ{2V~tK!NYfllet-IY9;-0Ahy8HayasV?>rP0|1J+7Z(5k diff --git a/contributed/.DS_Store b/contributed/.DS_Store index d0da439dffcc3253b60c3efbd08401ed8b1d1bf9..4b777fbe381b5e51054e9bf4ad3ea349cf7ba8f7 100644 GIT binary patch delta 379 zcmZoMXfc=|#>B)qu~2NHo}wrV0|Nsi1A_nqLk>f+XHI@{Qcix-#*NDv>p?PX47m*X z3@Hq$$g)6jpmGLBAlCa220#{?+NAR00+2~S@f}G8Ihn;J1_sv{nV4Bv+1NSQIk-7u zgER8WgG&-iN{gKmi=siiko^3dBp5rfJ}E3SwLD%x#5q5&Br!8DwFs;uGbI(MCMG;H zFD1X+DZex?r5LO^7$U*J$-x;fAW>ayXlbFNU~Fz&tD{hDX=I?IU}9!cTg%BIs;qAv z6rY`wo0s1Ob~Xbe(EDJ(3#B0xGXq2S#2y(*lprVzF3QWv&r1g?VcfX!7|UjM4t@?` gU~b&_oq009h^`1oFUY(G2n{xRbBM?W=7|j~05E@GLjV8( delta 92 zcmZoMXfc=|#>CJzu~2NHo}wrt0|NsP3otO`Fcha0C+8&P=jTi;RA*$I%)>0R*@gKq r%jN*)IZT_`Iruq%+BQ2fe`lV|FQUr{(s=-g87ABCNN CustomRMSNorm - # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) - return CustomRMSNorm if parallel_state.model_parallel_is_initialized() else LlamaRMSNorm - - -def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool: - if isinstance(module, (BaseGroupQueryAttention,)): - return module.preshard_hook(model_state_dict, prefix) - - return False - - -def _register_module(key: str, cls: Type[nn.Module]): - _Qwen2_MODULE_MAP[key] = cls - - -def register_module(key: str): - """ - Register a module for use in NeuronQwen2. - - Arguments: - key: String used to identify the module - - Example: - @register_module("NeuronQwen2Attention") - class NeuronQwen2Attention(nn.Module): - ... - """ - - def inner(cls: Type[nn.Module]): - _register_module(key, cls) - return cls - - return inner - - -def convert_state_dict_to_fused_qkv(Qwen2_state_dict, cfg: InferenceConfig): - """ - This function concats the qkv weights to a Wqkv weight for fusedqkv, and deletes the qkv weights. - """ - for l in range(cfg.num_hidden_layers): # noqa: E741 - Qwen2_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( - [ - Qwen2_state_dict[f"layers.{l}.self_attn.q_proj.weight"], - Qwen2_state_dict[f"layers.{l}.self_attn.k_proj.weight"], - Qwen2_state_dict[f"layers.{l}.self_attn.v_proj.weight"], - ], - ) - del Qwen2_state_dict[f"layers.{l}.self_attn.q_proj.weight"] - del Qwen2_state_dict[f"layers.{l}.self_attn.k_proj.weight"] - del Qwen2_state_dict[f"layers.{l}.self_attn.v_proj.weight"] - - gc.collect() - - return Qwen2_state_dict - - -class Qwen2InferenceConfig(InferenceConfig): - def add_derived_config(self): - self.neuron_config.attn_cls = "NeuronQwen2Attention" - self.num_cores_per_group = 1 - if self.neuron_config.flash_decoding_enabled: - num_attn_heads, num_kv_heads = self.num_attention_heads, self.num_key_value_heads - self.num_cores_per_group = calculate_num_cores_per_group( - num_attn_heads, num_kv_heads, self.neuron_config.tp_degree - ) - - def get_required_attributes(self) -> List[str]: - return [ - "hidden_size", - "num_attention_heads", - "num_hidden_layers", - "num_key_value_heads", - "pad_token_id", - "vocab_size", - "max_position_embeddings", - "rope_theta", - "rms_norm_eps", - "hidden_act", - ] - - @classmethod - def get_neuron_config_cls(cls) -> Type[NeuronConfig]: - return NeuronConfig - - -class NeuronQwen2MLP(nn.Module): - """ - This class just replace the linear layers (gate_proj, up_proj and down_proj) with column and row parallel layers - """ - - def __init__(self, config: InferenceConfig): - super().__init__() - self.config = config - self.neuron_config = config.neuron_config - self.tp_degree = config.neuron_config.tp_degree - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.act_fn = ACT2FN[config.hidden_act] - - self.sequence_parallel_enabled = getattr( - self.neuron_config, "sequence_parallel_enabled", False - ) - self.sequence_dimension = 1 if self.sequence_parallel_enabled else None - self.rms_norm_eps = config.rms_norm_eps - self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled - self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled - self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled - self.quantized_kernel_lower_bound = self.neuron_config.quantized_kernel_lower_bound - self.logical_neuron_cores = self.neuron_config.logical_neuron_cores - mlp_bias = getattr(config, "mlp_bias", False) - if parallel_state.model_parallel_is_initialized(): - if self.quantized_mlp_kernel_enabled: - # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad - tp_degree = self.neuron_config.tp_degree - self.intermediate_size += ( - get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree - ) - logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") - - quantization_type = QuantizationType(self.neuron_config.quantization_type) - quantized_dtype = QuantizedDtype.F8E4M3 - self.gate_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = QuantizedColumnParallel( - input_size=self.hidden_size, - output_size=self.intermediate_size, - bias=mlp_bias, - gather_output=False, - sequence_parallel_enabled=False, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - quantization_type=quantization_type, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = QuantizedRowParallel( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=mlp_bias, - quantization_type=quantization_type, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - quantized_dtype=quantized_dtype, - sequence_parallel_enabled=False, - quantization_per_channel_axis=0, - tensor_model_parallel_group=get_tp_group(config), - ) - - else: - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - tensor_model_parallel_group=get_tp_group(config), - ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=mlp_bias, - gather_output=False, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=False, - sequence_dimension=None, - tensor_model_parallel_group=get_tp_group(config), - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=mlp_bias, - input_is_parallel=True, - dtype=config.neuron_config.torch_dtype, - pad=True, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - tensor_model_parallel_group=get_tp_group(config), - reduce_dtype=config.neuron_config.rpl_reduce_dtype, - ) - - if self.mlp_kernel_enabled: - if self.quantized_mlp_kernel_enabled: - preprocess_quantized_linear_layer(self.gate_proj) - preprocess_quantized_linear_layer(self.up_proj) - preprocess_quantized_linear_layer(self.down_proj) - - else: - # Transpose the weights to the layout expected by kernels - self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight) - self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight) - self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight) - - else: - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias) - - def _kernel_enabled_quantized_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): - grid = (vnc(self.logical_neuron_cores),) - fused_residual = residual is not None - logger.debug( - f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" - ) - - # Can't do residual add in the kernel if SP is enabled - if fused_residual: - assert ( - not self.sequence_parallel_enabled - ), "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" - # Using fused residual add - _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) - else: - _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) - - # Handle SP RMSnorm - x_orig_dtype = x.dtype - if self.sequence_parallel_enabled: - # This RMSNormQuant kernel will do quantization inside, so we pass the - # lower_bound for clipping. - # If we don't use this kernel, the MLP kernel below will do the - # quantization, so we also pass lower_bound to that kernel. - if self.rmsnorm_quantize_kernel_enabled: - logger.debug( - "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" - ) - _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel) - quant_rmsnorm_out = torch.zeros( - size=( - x.shape[0], # batch size - x.shape[1], # sequence length - x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale - ), - dtype=torch.int8, - device=x.device, - ) - ln_w = rmsnorm.weight.unsqueeze(0) - lower_bound = self.quantized_kernel_lower_bound - _rmsnorm_quant_fwd_call[grid]( - x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" - ) - x = gather_from_sequence_parallel_region( - quant_rmsnorm_out, - self.sequence_dimension, - process_group=get_tp_group(self.config), - ) - - else: - logger.debug( - "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!" - ) - x = gather_from_sequence_parallel_region( - x, self.sequence_dimension, process_group=get_tp_group(self.config) - ) - - # Build output tensor - output_tensor_seqlen = x.shape[1] - if fused_residual: - # seqlen dim is doubled to store the residual add output - output_tensor_seqlen *= 2 - - output_tensor = torch.zeros( - size=( - x.shape[0], # batch size - output_tensor_seqlen, - self.hidden_size, # hidden size - ), - dtype=x_orig_dtype, - device=x.device, - ) - - # Grab weights - # all weights of the layers are stored in (out, in) shape - # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) - gate_w = self.gate_proj.weight.data - gate_w_scale = self.gate_proj.weight_scale - up_w = self.up_proj.weight.data - up_w_scale = self.up_proj.weight_scale - down_w = self.down_proj.weight.data - down_w_scale = self.down_proj.weight_scale - lower_bound = self.quantized_kernel_lower_bound - - if fused_residual: - _mlp_fwd_call[grid]( - x, # attn_output - residual, # hidden - ln_w, # ln_w - gate_w, # gate_w - gate_w_scale, - up_w, # up_w - up_w_scale, - down_w, # down_w - down_w_scale, - lower_bound, - output_tensor, # out - fused_rmsnorm=fused_rmsnorm, - eps=self.rms_norm_eps, - kernel_name="MLP", - store_add=True, - ) - original_seqlen = x.shape[1] - residual = output_tensor[:, original_seqlen:, :] - output_tensor = output_tensor[:, :original_seqlen, :] - else: - _mlp_fwd_call[grid]( - x, # hidden - # should be fine to pass gamma is as a dummy even if not using fused rmsnorm - ln_w, - gate_w, # gate_w - gate_w_scale, - up_w, # up_w - up_w_scale, - down_w, # down_w - down_w_scale, - lower_bound, - output_tensor, # out - # Run RMSNorm inside the kernel if NOT using SP rmsnorm - fused_rmsnorm=fused_rmsnorm, - eps=self.rms_norm_eps, - kernel_name="MLP", - ) - residual = None - - # All-reduce or reduce-scatter, depending on whether SP is enabled - if self.sequence_parallel_enabled: - output_tensor = reduce_scatter_to_sequence_parallel_region( - output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) - ) - else: - output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) - - logger.debug(f"Quantized MLP output shape {output_tensor.shape}") - return (output_tensor, residual) - - def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): - fused_residual = residual is not None - logger.debug( - f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" - ) - - # Choose which kernel to call - if fused_residual: - assert ( - not self.sequence_parallel_enabled - ), "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" - # Using fused residual add - _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) - else: - _mlp_fwd_call = nki_jit()(mlp_isa_kernel) - - if self.sequence_parallel_enabled: - x = gather_from_sequence_parallel_region( - x, self.sequence_dimension, process_group=get_tp_group(self.config) - ) - - # Build output tensor - output_tensor_seqlen = x.shape[1] - if fused_residual: - # seqlen dim is doubled to store the residual add output - output_tensor_seqlen *= 2 - - output_tensor = torch.zeros( - size=( - x.shape[0], # batch size - output_tensor_seqlen, - self.hidden_size, # hidden size - ), - dtype=x.dtype, - device=x.device, - ) - - # Grab weights - # all weights of the layers are stored in (out, in) shape - # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] - ln_w = rmsnorm.weight.unsqueeze(0) - gate_w = self.gate_proj.weight.data - up_w = self.up_proj.weight.data - down_w = self.down_proj.weight.data - - grid = (vnc(self.logical_neuron_cores),) - - if fused_residual: - _mlp_fwd_call[grid]( - x, # attn_output - residual, # hidden - ln_w, # ln_w - gate_w, # gate_w - up_w, # up_w - down_w, # down_w - output_tensor, # out - fused_rmsnorm=fused_rmsnorm, - eps=self.rms_norm_eps, - kernel_name="MLP", - store_add=True, - ) - original_seqlen = x.shape[1] - residual = output_tensor[:, original_seqlen:, :] - output_tensor = output_tensor[:, :original_seqlen, :] - else: - _mlp_fwd_call[grid]( - x, # hidden - # should be fine to pass gamma is as a dummy even if not using fused rmsnorm - ln_w, - gate_w, - up_w, - down_w, - output_tensor, # out - # Run RMSNorm inside the kernel if NOT using SP rmsnorm - fused_rmsnorm=fused_rmsnorm, - eps=self.rms_norm_eps, - kernel_name="MLP", - ) - residual = None - - # All-reduce or reduce-scatter, depending on whether SP is enabled - if self.sequence_parallel_enabled: - output_tensor = reduce_scatter_to_sequence_parallel_region( - output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config) - ) - else: - output_tensor = reduce_from_tensor_model_parallel_region( - output_tensor, process_group=get_tp_group(self.config) - ) - - logger.debug(f"MLP output shape {output_tensor.shape}") - return (output_tensor, residual) - - def _native_mlp(self, x, rmsnorm, adapter_ids=None): - logger.debug("MLP: native compiler") - # all-gather is done here instead of CPL layers to - # avoid 2 all-gathers from up and gate projections - if self.sequence_parallel_enabled: - x = gather_from_sequence_parallel_region( - x, self.sequence_dimension, process_group=get_tp_group(self.config) - ) - - gate_proj_output = ( - self.gate_proj(x) - if not is_lora_module(self.gate_proj) - else self.gate_proj(x, adapter_ids) - ) - up_proj_output = ( - self.up_proj(x) if not is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids) - ) - down_proj_input = self.act_fn(gate_proj_output) * up_proj_output - output = ( - self.down_proj(down_proj_input) - if not is_lora_module(self.up_proj) - else self.down_proj(down_proj_input, adapter_ids) - ) - logger.debug(f"MLP output shape {output.shape}") - return output - - def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): - """ - If residual is passed in, will fuse its add into the MLP kernel - - Returns a tuple of (output, residual), where residual is the output of the residual add - """ - if self.mlp_kernel_enabled: - fused_rmsnorm = not self.sequence_parallel_enabled - # Quantized MLP kernel - if self.quantized_mlp_kernel_enabled: - return self._kernel_enabled_quantized_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids - ) - # MLP kernel - return self._kernel_enabled_mlp( - x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids - ) - else: - # No kernel - return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) - - -@register_module("NeuronQwen2Attention") -class NeuronQwen2Attention(NeuronAttentionBase): - """ - Compared with Qwen2Attention, this class just - 1. replaces the q_proj, k_proj, v_proj with column parallel layer - 2. replaces the o_proj with row parallel layer - 3. update self.num_head to be self.num_head / tp_degree - 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree - 5. update forward() method to adjust to changes from self.num_head - """ - - def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): - super().__init__(tensor_model_parallel_group=tensor_model_parallel_group) - - self.config = config - self.neuron_config = config.neuron_config - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_attention_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.padding_side = config.neuron_config.padding_side - self.torch_dtype = config.neuron_config.torch_dtype - self.is_medusa = config.neuron_config.is_medusa - self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled - self.num_cores_per_group = config.num_cores_per_group - self.bias = getattr(config, "attention_bias", True) - self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype - self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled - self.rms_norm_eps = config.rms_norm_eps - - if parallel_state.model_parallel_is_initialized(): - self.tp_degree = self.config.neuron_config.tp_degree - else: - self.tp_degree = 1 - - self.fused_qkv = config.neuron_config.fused_qkv - self.clip_qkv = None - - self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled - self.sequence_dimension = 1 if self.sequence_parallel_enabled else None - logger.debug( - f"Hello from NeuronQwen2Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? {self.sequence_dimension}" - ) - - self.init_gqa_properties() - - self.init_rope() - - self.o_proj = GroupQueryAttention_O( - hidden_size=self.hidden_size, - head_dim=self.head_dim, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - tp_degree=self.tp_degree, - dtype=self.torch_dtype, - bias=False, - input_is_parallel=True, - layer_name=self.o_proj_layer_name, - sequence_parallel_enabled=self.sequence_parallel_enabled, - sequence_dimension=self.sequence_dimension, - tensor_model_parallel_group=self.tensor_model_parallel_group, - rpl_reduce_dtype=self.rpl_reduce_dtype, - ) - - def init_rope(self): - if not hasattr(self.config, "rope_scaling") or self.config.rope_scaling is None: - # TODO(yihsian): Check if we can just use our own implementation - if self.is_medusa: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - self.rotary_emb = RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - rope_type = self.config.rope_scaling.get( - "rope_type", self.config.rope_scaling.get("type", None) - ) - if rope_type == "Qwen2": - self.rotary_emb = Qwen2RotaryEmbedding( - dim=self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - factor=self.config.rope_scaling["factor"], - low_freq_factor=self.config.rope_scaling["low_freq_factor"], - high_freq_factor=self.config.rope_scaling["high_freq_factor"], - original_max_position_embeddings=self.config.rope_scaling[ - "original_max_position_embeddings" - ], - ) - else: - # Qwen2RotaryEmbedding automatically chooses the correct scaling type from config. - # Warning: The HF implementation may have precision issues when run on Neuron. - # We include it here for compatibility with other scaling types. - self.rotary_emb = LlamaRotaryEmbedding(self.config) - - -# TODO: Modularize RotaryEmbedding. See how HF transformers does it in 4.43. -class Qwen2RotaryEmbedding(nn.Module): - """ - Adapted from Qwen2 4.43 impl - * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/Qwen2/modeling_Qwen2.py#L78 - * https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/modeling_rope_utils.py#L345 - - This implementation ensures inv_freq is calculated and stored in fp32. - """ - - def __init__( - self, - dim, - max_position_embeddings=131072, - base=500000.0, - factor=8.0, - low_freq_factor=1.0, - high_freq_factor=4.0, - original_max_position_embeddings=8192, - ): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.factor = factor - self.low_freq_factor = low_freq_factor - self.high_freq_factor = high_freq_factor - self.old_context_len = original_max_position_embeddings - self.register_buffer("inv_freq", None, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - if self.inv_freq is None: - inv_freq = 1.0 / ( - self.base - ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - - low_freq_wavelen = self.old_context_len / self.low_freq_factor - high_freq_wavelen = self.old_context_len / self.high_freq_factor - new_freqs = [] - for freq in inv_freq: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / self.factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (self.old_context_len / wavelen - self.low_freq_factor) / ( - self.high_freq_factor - self.low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) - self.inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) - - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - with torch.autocast(device_type=x.device.type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class NeuronQwen2DecoderLayer(nn.Module): - """ - Just replace the attention with the NXD version, and MLP with the NXD version - """ - - def __init__(self, config: InferenceConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = _Qwen2_MODULE_MAP[config.neuron_config.attn_cls]( - config=config, tensor_model_parallel_group=get_tp_group(config) - ) - self.mlp = NeuronQwen2MLP(config) - logger.debug( - f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" - ) - self.input_layernorm = None - if ( - not config.neuron_config.is_eagle_draft - or config.neuron_config.enable_eagle_draft_input_norm - ): - self.input_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - self.post_attention_layernorm = get_rmsnorm_cls()( - config.hidden_size, - eps=config.rms_norm_eps, - ) - self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled - self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled - self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled - self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add - self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled - self.config = config - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - adapter_ids=None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - # RMSNorm (fused with QKV kernel when SP is disabled) - if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm: - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - adapter_ids=adapter_ids, - rmsnorm=self.input_layernorm, - **kwargs, - ) - - if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: - assert ( - not self.sequence_parallel_enabled - ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" - # First residual add handled in the MLP kernel - hidden_states, residual = self.mlp( - hidden_states, - rmsnorm=self.post_attention_layernorm, - residual=residual, - adapter_ids=adapter_ids, - ) - else: - hidden_states = residual + hidden_states - residual = hidden_states - # RMSNorm (fused with QKV kernel when SP is disabled) - if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, _ = self.mlp( - hidden_states, - rmsnorm=self.post_attention_layernorm, - adapter_ids=adapter_ids, - ) - - hidden_states = residual + hidden_states - - outputs = (hidden_states, present_key_value, cos_cache, sin_cache) - return outputs - - -class ResBlock(nn.Module): - """ - A Residual Block module. - - This module performs a linear transformation followed by a SiLU activation, - and then adds the result to the original input, creating a residual connection. - - Args: - hidden_size (int): The size of the hidden layers in the block. - """ - - def __init__(self, hidden_size): - super().__init__() - self.linear = nn.Linear(hidden_size, hidden_size) - # Initialize as an identity mapping - torch.nn.init.zeros_(self.linear.weight) - # Use SiLU activation to keep consistent with the Qwen2 model - self.act = nn.SiLU() - - def forward(self, x): - """ - Forward pass of the ResBlock. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output after the residual connection and activation. - """ - return x + self.act(self.linear(x)) - - -class NeuronQwen2Model(NeuronBaseModel): - """ - The neuron version of the Qwen2Model - """ - - def setup_attr_for_model(self, config: InferenceConfig): - # Needed for init_inference_optimization() - self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None - self.tp_degree = config.neuron_config.tp_degree - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.num_key_value_heads = config.num_key_value_heads - self.max_batch_size = config.neuron_config.max_batch_size - self.buckets = config.neuron_config.buckets - - def init_model(self, config: InferenceConfig): - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - if parallel_state.model_parallel_is_initialized(): - self.embed_tokens = ParallelEmbedding( - config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=config.neuron_config.torch_dtype, - shard_across_embedding=not config.neuron_config.vocab_parallel, - sequence_parallel_enabled=False, - pad=True, - tensor_model_parallel_group=get_tp_group(config), - use_spmd_rank=config.neuron_config.vocab_parallel, - ) - - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - gather_output=not self.on_device_sampling, - bias=False, - pad=True, - tensor_model_parallel_group=get_tp_group(config), - ) - else: - self.embed_tokens = nn.Embedding( - config.vocab_size, - config.hidden_size, - self.padding_idx, - ) - self.lm_head = nn.Linear( - config.hidden_size, - config.vocab_size, - bias=False, - ) - - # In the target fp8 checkpoint, the 1st and last - # layers are not using fp8. - updated_configs = [] - for i in range(config.num_hidden_layers): - # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block - if i == 0 or i == config.num_hidden_layers - 1: - non_quant_config = copy.deepcopy(config) - non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False - updated_configs.append(non_quant_config) - else: - updated_configs.append(config) - self.layers = nn.ModuleList([NeuronQwen2DecoderLayer(conf) for conf in updated_configs]) - if not config.neuron_config.is_eagle_draft: - self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - - if config.neuron_config.is_eagle_draft: - fc_bias = getattr(config, "fc_bias", False) - self.fc = ColumnParallelLinear( - config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True - ) - self.is_medusa = config.neuron_config.is_medusa - self.num_medusa_heads = config.neuron_config.num_medusa_heads - self.medusa_speculation_length = config.neuron_config.medusa_speculation_length - - if self.is_medusa: - if parallel_state.model_parallel_is_initialized(): - medusa_head_cls = ColumnParallelLinear - else: - medusa_head_cls = nn.Linear - for i in range(self.num_medusa_heads): - medusa_head = nn.Sequential( - *([ResBlock(config.hidden_size)] * 1), - medusa_head_cls( - config.hidden_size, - config.vocab_size, - gather_output=not self.on_device_sampling, - bias=False, - ), - ) - setattr(self, f"medusa_head_{i}", medusa_head) - - -class NeuronQwen2ForCausalLM(NeuronBaseForCausalLM): - """ - This class extends Qwen2ForCausalLM create traceable - blocks for Neuron. - - Args: - Qwen2ForCausalLM (_type_): _description_ - """ - - _model_cls = NeuronQwen2Model - - @staticmethod - def load_hf_model(model_path): - return Qwen2ForCausalLM.from_pretrained(model_path) - - @staticmethod - def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: - """This function should be over-ridden in child classes as needed""" - neuron_config = config.neuron_config - if neuron_config.fused_qkv: - state_dict = convert_state_dict_to_fused_qkv(state_dict, config) - - if neuron_config.vocab_parallel: - # TODO: this hack can be removed after replication_id is ready to use - state_dict["embed_tokens.rank_util.rank"] = torch.arange( - 0, neuron_config.local_ranks_size - ) - - # to facilitate rank usage in attention - num_layers = config.num_hidden_layers - tp_degree = neuron_config.tp_degree - for i in range(num_layers): - state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( - 0, tp_degree, dtype=torch.int32 - ) - # to facilitate rank usage in base model - state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) - return state_dict - - @staticmethod - def update_state_dict_for_tied_weights(state_dict): - state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() - - @classmethod - def get_config_cls(cls): - return Qwen2InferenceConfig \ No newline at end of file diff --git a/contributed/models/qwen2/qwen-test.ipynb b/contributed/models/qwen2/qwen-test.ipynb deleted file mode 100644 index 72f4afe..0000000 --- a/contributed/models/qwen2/qwen-test.ipynb +++ /dev/null @@ -1,351 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip list | grep neuron" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "from transformers import AutoTokenizer, GenerationConfig\n", - "from modeling_qwen import Qwen2InferenceConfig, NeuronQwen2ForCausalLM\n", - "from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig\n", - "from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "model_path = \"/home/ubuntu/model_hf_qwq/qwq/\"\n", - "traced_model_path = \"/home/ubuntu/traced_model_qwq/qwq/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import HfFolder\n", - "HfFolder.save_token(\"YOUR TOKEN HERE\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import snapshot_download\n", - "\n", - "snapshot_download(\"Qwen/QwQ-32B\", local_dir=model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from modeling_qwen import Qwen2InferenceConfig, NeuronQwen2ForCausalLM\n", - "\n", - "def run_qwq_compile():\n", - " # Initialize configs and tokenizer.\n", - " tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"right\")\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - "\n", - " generation_config = GenerationConfig.from_pretrained(model_path)\n", - " generation_config_kwargs = {\n", - " \"do_sample\": True,\n", - " \"top_k\": 1,\n", - " \"pad_token_id\": tokenizer.pad_token_id,\n", - " }\n", - " generation_config.update(**generation_config_kwargs)\n", - " \n", - " neuron_config = NeuronConfig(\n", - " tp_degree=8,\n", - " batch_size=1,\n", - " max_context_length=4096,\n", - " seq_len=8096,\n", - " on_device_sampling_config=OnDeviceSamplingConfig(top_k=5),\n", - " enable_bucketing=True,\n", - " context_encoding_buckets=[128, 1024, 4096],\n", - " token_generation_buckets=[128, 1024, 8096],\n", - " flash_decoding_enabled=False,\n", - " torch_dtype=torch.bfloat16,\n", - " fused_qkv=False,\n", - " attn_cls=\"NeuronQwen2Attention\"\n", - " )\n", - " config = Qwen2InferenceConfig(\n", - " neuron_config,\n", - " load_config=load_pretrained_config(model_path),\n", - " )\n", - " \n", - " # Compile and save model.\n", - " print(\"\\nCompiling and saving model...\")\n", - " model = NeuronQwen2ForCausalLM(model_path, config)\n", - " model.compile(traced_model_path)\n", - " tokenizer.save_pretrained(traced_model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "run_qwq_compile()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run inference" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = NeuronQwen2ForCausalLM(traced_model_path)\n", - "model.load(traced_model_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n", - "tokenizer.pad_token = tokenizer.eos_token\n", - "generation_config = GenerationConfig.from_pretrained(model_path)\n", - "generation_config_kwargs = {\n", - " \"do_sample\": True,\n", - " \"temperature\": 0.9,\n", - " \"top_k\": 5,\n", - " \"pad_token_id\": tokenizer.pad_token_id,\n", - "}\n", - "generation_config.update(**generation_config_kwargs)\n", - "generation_model = HuggingFaceGenerationAdapter(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [], - "source": [ - "# Define a list of prompts\n", - "prompts = [\n", - " \"How many r's are in the word \\\"strawberry\\\"\",\n", - "]\n", - "\n", - "# Create messages for each prompt\n", - "messages_list = [\n", - " [{\"role\": \"user\", \"content\": prompt}] for prompt in prompts\n", - "]\n", - "\n", - "# Apply chat template to each set of messages\n", - "texts = [\n", - " tokenizer.apply_chat_template(\n", - " messages,\n", - " tokenize=False,\n", - " add_generation_prompt=True\n", - " ) for messages in messages_list\n", - "]\n", - "\n", - "# Tokenize the batch of texts\n", - "model_inputs = tokenizer(texts, return_tensors=\"pt\", padding=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"\\nGenerating outputs...\")\n", - "outputs = generation_model.generate(\n", - " **model_inputs,\n", - " generation_config=generation_config,\n", - " max_length=model.config.neuron_config.max_length,\n", - ")\n", - "output_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'Generated outputs:'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'Output 0: user\\nHow many r\\'s are in the word \"strawberry\"\\nassistant\\n\\nOkay, so I need to figure out how many times the letter \\'r\\' appears in the word \"strawberry.\" Let me start by writing down the word and looking at each letter one by one. \\n\\nFirst, I\\'ll spell out \"strawberry\" to make sure I have all the letters right. S-T-R-A-W-B-E-R-R-Y. Wait, let me check that again. Sometimes I might miss a letter. Let me count the letters as I write them:\\n\\n1. S\\n2. T\\n3. R\\n4. A\\n5. W\\n6. B\\n7. E\\n8. R\\n9. R\\n10. Y\\n\\nHmm, so that\\'s 10 letters in total. Now, I need to count how many times \\'R\\' shows up. Let me go through each letter again and note the positions where \\'R\\' is.\\n\\nStarting from the first letter:\\n1. S – not an R\\n2. T – not an R\\n3. R – that\\'s the first R\\n4. A – no\\n5. W – no\\n6. B – no\\n7. E – no\\n8. R – second R\\n9. R – third R\\n10. Y – no\\n\\nWait a second, so after the first R at position 3, the next R is at position 8, and then another at 9? Let me confirm the spelling again because sometimes people might confuse \"strawberry\" with other similar words. Let me think: S-T-R-A-W-B-E-R-R-Y. Yes, that\\'s correct. After the \\'E\\', there are two R\\'s in a row, right? So positions 8 and 9 are both R\\'s. So that would make three R\\'s in total: one at position 3, and two at 8 and 9. \\n\\nBut hold on, maybe I miscounted the letters. Let me write them out again with numbers to be sure:\\n\\n1. S\\n2. T\\n3. R\\n4. A\\n5. W\\n6. B\\n7. E\\n8. R\\n9. R\\n10. Y\\n\\nYes, that\\'s correct. So the letters R are at positions 3, 8, and 9. That\\'s three R\\'s. Wait, but sometimes when I say \"strawberry,\" I might not pronounce the second R as clearly, but spelling-wise, it\\'s definitely there. Let me check another way. Maybe breaking the word into parts. \"Straw\" and \"berry.\" In \"straw,\" there\\'s an R. Then in \"berry,\" which is B-E-R-R-Y. So \"berry\" has two R\\'s. So adding the one from \"straw,\" that\\'s three total. \\n\\nAlternatively, maybe I can think of the word as S-T-R-A-W-B-E-R-R-Y. So breaking it down:\\n\\n- S T R (so first R)\\n- A W B E (no R\\'s here)\\n- R R Y (two more R\\'s)\\n\\nSo that\\'s 1 + 2 = 3 R\\'s. \\n\\nI think that\\'s right, but I want to be absolutely sure. Let me try another approach. Let me write the word and circle each R:\\n\\nS T R A W B E R R Y\\n\\nCircling the R\\'s: the third letter is R, then the eighth and ninth letters are R\\'s. So three in total. \\n\\nAlternatively, maybe I can use a different method. Let me count the letters one by one and tally the R\\'s:\\n\\nStarting with S: 0\\nT: 0\\nR: 1\\nA: 1\\nW:1\\nB:1\\nE:1\\nR: 2\\nR:3\\nY:3\\n\\nWait, no, that\\'s not the right way. Each time I see an R, I should increment the count. Let me try again:\\n\\n1. S – count remains 0\\n2. T – 0\\n3. R – count becomes 1\\n4. A – 1\\n5. W –1\\n6. B –1\\n7. E –1\\n8. R – count becomes 2\\n9. R – count becomes 3\\n10. Y –3\\n\\nYes, so the final count is 3. \\n\\nI think I might have confused myself earlier when I thought maybe two, but upon multiple checks, it\\'s three. Let me see if any sources or examples say otherwise. Wait, maybe I should just confirm by looking up the spelling of \"strawberry\" again. \\n\\nLooking it up in my mind: S-T-R-A-W-B-E-R-R-Y. Yes, that\\'s correct. The standard spelling has three R\\'s. So the answer should be three. \\n\\nAlternatively, maybe I can think of the pronunciation. In some accents, the double R in \"berry\" might be pronounced as a single sound, but that doesn\\'t change the spelling. The question is about the written word, so the letters are what matter, not the pronunciation. \\n\\nTherefore, after carefully going through each letter and multiple methods of counting, I can confidently say there are three R\\'s in \"strawberry.\"\\n\\n\\nThe word \"strawberry\" contains three instances of the letter \\'r\\'. Here\\'s the breakdown:\\n\\n1. **S** \\n2. **T** \\n3. **R** (1st \\'r\\') \\n4. **A** \\n5. **W** \\n6. **B** \\n7. **E** \\n8. **R** (2nd \\'r\\') \\n9. **R** (3rd \\'r\\') \\n10. **Y** \\n\\n**Answer:** There are **3 r\\'s** in the word \"strawberry.\"'" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import display\n", - "display(\"Generated outputs:\")\n", - "for i, output_token in enumerate(output_tokens):\n", - " display(f\"Output {i}: {output_token}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "model.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "del model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Test Token Output" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "dir = '/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", - "!cp modeling_qwen.py {dir}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!cp {dir}/inference_demo.py ." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Add the following to the inference_demo.py we just copied to our working directory\n", - "\n", - "```\n", - "from .modeling_qwen import NeuronQwen2ForCausalLM\n", - "\n", - "MODEL_TYPES = {\n", - " \"llama\": {\"causal-lm\": NeuronLlamaForCausalLM},\n", - " \"mixtral\": {\"causal-lm\": NeuronMixtralForCausalLM},\n", - " \"dbrx\": {\"causal-lm\": NeuronDbrxForCausalLM},\n", - " \"qwen\": {'causal-lm': NeuronQwen2ForCausalLM} #add this line\n", - "}\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!cp ./inference_demo.py {dir}/inference_demo.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Restart your kernel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!inference_demo \\\n", - " --model-type qwen \\\n", - " --task-type causal-lm \\\n", - " run \\\n", - " --model-path /home/ubuntu/model_hf_qwq/qwq/ \\\n", - " --compiled-model-path /home/ubuntu/traced_model_qwq/qwq/ \\\n", - " --torch-dtype bfloat16 \\\n", - " --tp-degree 8 \\\n", - " --batch-size 1 \\\n", - " --max-context-length 32 \\\n", - " --seq-len 64 \\\n", - " --on-device-sampling \\\n", - " --enable-bucketing \\\n", - " --top-k 1 \\\n", - " --do-sample \\\n", - " --pad-token-id 32000 \\\n", - " --prompt \"To be, or not to be\" \\\n", - " --check-accuracy-mode token-matching \\\n", - " --benchmark" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "aws_neuronx_venv_pytorch_2_5_nxd_inference", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From e1611a1ea82703ead98e7a4f58c401a82bc8081f Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 13:10:36 -0400 Subject: [PATCH 4/7] lint --- contributed/models/qwen3/modeling_qwen.py | 996 ++++++++++++++++++++++ 1 file changed, 996 insertions(+) create mode 100644 contributed/models/qwen3/modeling_qwen.py diff --git a/contributed/models/qwen3/modeling_qwen.py b/contributed/models/qwen3/modeling_qwen.py new file mode 100644 index 0000000..cc5eb7f --- /dev/null +++ b/contributed/models/qwen3/modeling_qwen.py @@ -0,0 +1,996 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen3 model for NXD inference.""" + +import copy +import gc +import logging +from typing import List, Optional, Tuple, Type + + +from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + move_heads_front, +) + +import torch +from neuronx_distributed.parallel_layers import parallel_state # noqa: E402 +from neuronx_distributed.parallel_layers.layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from neuronx_distributed.parallel_layers.utils import get_padding_length +from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + QuantizedDtype, +) +from neuronx_distributed.quantization.quantization_layers import ( # noqa: E402; noqa: E402; noqa: E402; noqa: E402; noqa: E402 + QuantizedColumnParallel, + QuantizedRowParallel, +) + +from neuronxcc.nki._private_kernels.mlp import ( + mlp_fused_add_isa_kernel, + mlp_isa_kernel, + quant_mlp_fused_add_isa_kernel, + quant_mlp_isa_kernel, +) +from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel +from neuronxcc.nki.language import nc +from torch import nn +from torch_neuronx.xla_impl.ops import nki_jit +from transformers import Qwen3ForCausalLM +from transformers.activations import ACT2FN +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm, Qwen3RotaryEmbedding + +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402 +from neuronx_distributed_inference.models.model_base import ( # noqa: E402 + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402 + BaseGroupQueryAttention, +) +from neuronx_distributed_inference.modules.attention.utils import ( + preprocess_quantized_linear_layer, + transpose_parallel_linear_layer, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.flashdecode.utils import ( + calculate_num_cores_per_group, +) +from neuronx_distributed_inference.modules.lora_serving.lora_module import ( + is_lora_module, +) +from neuronx_distributed_inference.utils.distributed import get_tp_group + +_Qwen3_MODULE_MAP = {} + +logger = logging.getLogger("Neuron") + + +def get_rmsnorm_cls(): + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return ( + CustomRMSNorm + if parallel_state.model_parallel_is_initialized() + else Qwen3RMSNorm + ) + + +def preshard_hook_fn( + module: torch.nn.Module, model_state_dict: dict, prefix: str +) -> bool: + if isinstance(module, (BaseGroupQueryAttention,)): + return module.preshard_hook(model_state_dict, prefix) + + return False + + +def _register_module(key: str, cls: Type[nn.Module]): + _Qwen3_MODULE_MAP[key] = cls + + +def register_module(key: str): + """ + Register a module for use in NeuronQwen3. + Arguments: + key: String used to identify the module + Example: + @register_module("NeuronQwen3Attention") + class NeuronQwen3Attention(nn.Module): + ... + """ + + def inner(cls: Type[nn.Module]): + _register_module(key, cls) + return cls + + return inner + + +def convert_state_dict_to_fused_qkv(Qwen3_state_dict, cfg: InferenceConfig): + """ + This function concats the qkv weights to a Wqkv weight for fusedqkv, and deletes the qkv weights. + """ + for l in range(cfg.num_hidden_layers): # noqa: E741 + Qwen3_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"], + Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"], + Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"], + ], + ) + del Qwen3_state_dict[f"layers.{l}.self_attn.q_proj.weight"] + del Qwen3_state_dict[f"layers.{l}.self_attn.k_proj.weight"] + del Qwen3_state_dict[f"layers.{l}.self_attn.v_proj.weight"] + + gc.collect() + + return Qwen3_state_dict + + +class Qwen3InferenceConfig(InferenceConfig): + def add_derived_config(self): + self.num_cores_per_group = 1 + if self.neuron_config.flash_decoding_enabled: + num_attn_heads, num_kv_heads = ( + self.num_attention_heads, + self.num_key_value_heads, + ) + self.num_cores_per_group = calculate_num_cores_per_group( + num_attn_heads, num_kv_heads, self.neuron_config.tp_degree + ) + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "pad_token_id", + "vocab_size", + "max_position_embeddings", + "rope_theta", + "rms_norm_eps", + "hidden_act", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return NeuronConfig + + +class NeuronQwen3MLP(nn.Module): + """ + This class just replace the linear layers (gate_proj, up_proj and down_proj) with column and row parallel layers + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.config = config + self.neuron_config = config.neuron_config + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + + self.sequence_parallel_enabled = getattr( + self.neuron_config, "sequence_parallel_enabled", False + ) + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + self.rms_norm_eps = config.rms_norm_eps + self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled + self.quantized_mlp_kernel_enabled = ( + self.neuron_config.quantized_mlp_kernel_enabled + ) + self.rmsnorm_quantize_kernel_enabled = ( + self.neuron_config.rmsnorm_quantize_kernel_enabled + ) + self.logical_neuron_cores = self.neuron_config.logical_neuron_cores + mlp_bias = getattr(config, "mlp_bias", False) + if parallel_state.model_parallel_is_initialized(): + if self.quantized_mlp_kernel_enabled: + # Quantized MLP kernels expect intermediate size to be multiple of 128, so we need to pad + tp_degree = self.neuron_config.tp_degree + self.intermediate_size += ( + get_padding_length(self.intermediate_size // tp_degree, 128) + * tp_degree + ) + logger.debug(f"Quantized intermediate_size: {self.intermediate_size}") + + quantization_type = QuantizationType( + self.neuron_config.quantization_type + ) + quantized_dtype = QuantizedDtype.F8E4M3 + self.gate_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = QuantizedColumnParallel( + input_size=self.hidden_size, + output_size=self.intermediate_size, + bias=mlp_bias, + gather_output=False, + sequence_parallel_enabled=False, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + quantization_type=quantization_type, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = QuantizedRowParallel( + input_size=self.intermediate_size, + output_size=self.hidden_size, + bias=mlp_bias, + quantization_type=quantization_type, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + quantized_dtype=quantized_dtype, + sequence_parallel_enabled=False, + quantization_per_channel_axis=0, + tensor_model_parallel_group=get_tp_group(config), + ) + + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=mlp_bias, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=False, + sequence_dimension=None, + tensor_model_parallel_group=get_tp_group(config), + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=mlp_bias, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + pad=True, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=get_tp_group(config), + reduce_dtype=config.neuron_config.rpl_reduce_dtype, + ) + + if self.mlp_kernel_enabled: + if self.quantized_mlp_kernel_enabled: + preprocess_quantized_linear_layer(self.gate_proj) + preprocess_quantized_linear_layer(self.up_proj) + preprocess_quantized_linear_layer(self.down_proj) + + else: + # Transpose the weights to the layout expected by kernels + self.gate_proj.weight = transpose_parallel_linear_layer( + self.gate_proj.weight + ) + self.up_proj.weight = transpose_parallel_linear_layer( + self.up_proj.weight + ) + self.down_proj.weight = transpose_parallel_linear_layer( + self.down_proj.weight + ) + + else: + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=mlp_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=mlp_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=mlp_bias + ) + + def _kernel_enabled_quantized_mlp( + self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids + ): + grid = (nc(self.logical_neuron_cores),) + fused_residual = residual is not None + logger.debug( + f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Can't do residual add in the kernel if SP is enabled + if fused_residual: + assert not self.sequence_parallel_enabled, ( + "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!" + ) + # Using fused residual add + _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel) + + # Handle SP RMSnorm + x_orig_dtype = x.dtype + if self.sequence_parallel_enabled: + # This RMSNormQuant kernel will do quantization inside, so we pass the + # lower_bound for clipping. + # If we don't use this kernel, the MLP kernel below will do the + # quantization, so we also pass lower_bound to that kernel. + if self.rmsnorm_quantize_kernel_enabled: + logger.debug( + "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!" + ) + _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel) + quant_rmsnorm_out = torch.zeros( + size=( + x.shape[0], # batch size + x.shape[1], # sequence length + x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale + ), + dtype=torch.int8, + device=x.device, + ) + ln_w = rmsnorm.weight.unsqueeze(0) + lower_bound = self.quantized_kernel_lower_bound + _rmsnorm_quant_fwd_call[grid]( + x, ln_w, lower_bound, quant_rmsnorm_out, kernel_name="QuantOnly" + ) + x = gather_from_sequence_parallel_region( + quant_rmsnorm_out, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + + else: + logger.debug( + "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!" + ) + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x_orig_dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + gate_w_scale = self.gate_proj.weight_scale + up_w = self.up_proj.weight.data + up_w_scale = self.up_proj.weight_scale + down_w = self.down_proj.weight.data + down_w_scale = self.down_proj.weight_scale + lower_bound = self.quantized_kernel_lower_bound + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, # gate_w + gate_w_scale, + up_w, # up_w + up_w_scale, + down_w, # down_w + down_w_scale, + lower_bound, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region(output_tensor) + + logger.debug(f"Quantized MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual, adapter_ids): + fused_residual = residual is not None + logger.debug( + f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_neuron_cores={self.logical_neuron_cores}" + ) + + # Choose which kernel to call + if fused_residual: + assert not self.sequence_parallel_enabled, ( + "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!" + ) + # Using fused residual add + _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel) + else: + _mlp_fwd_call = nki_jit()(mlp_isa_kernel) + + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + # Build output tensor + output_tensor_seqlen = x.shape[1] + if fused_residual: + # seqlen dim is doubled to store the residual add output + output_tensor_seqlen *= 2 + + output_tensor = torch.zeros( + size=( + x.shape[0], # batch size + output_tensor_seqlen, + self.hidden_size, # hidden size + ), + dtype=x.dtype, + device=x.device, + ) + + # Grab weights + # all weights of the layers are stored in (out, in) shape + # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden] + ln_w = rmsnorm.weight.unsqueeze(0) + gate_w = self.gate_proj.weight.data + up_w = self.up_proj.weight.data + down_w = self.down_proj.weight.data + + grid = (nc(self.logical_neuron_cores),) + + if fused_residual: + _mlp_fwd_call[grid]( + x, # attn_output + residual, # hidden + ln_w, # ln_w + gate_w, # gate_w + up_w, # up_w + down_w, # down_w + output_tensor, # out + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + store_add=True, + ) + original_seqlen = x.shape[1] + residual = output_tensor[:, original_seqlen:, :] + output_tensor = output_tensor[:, :original_seqlen, :] + else: + _mlp_fwd_call[grid]( + x, # hidden + # should be fine to pass gamma is as a dummy even if not using fused rmsnorm + ln_w, + gate_w, + up_w, + down_w, + output_tensor, # out + # Run RMSNorm inside the kernel if NOT using SP rmsnorm + fused_rmsnorm=fused_rmsnorm, + eps=self.rms_norm_eps, + kernel_name="MLP", + ) + residual = None + + # All-reduce or reduce-scatter, depending on whether SP is enabled + if self.sequence_parallel_enabled: + output_tensor = reduce_scatter_to_sequence_parallel_region( + output_tensor, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + else: + output_tensor = reduce_from_tensor_model_parallel_region( + output_tensor, process_group=get_tp_group(self.config) + ) + + logger.debug(f"MLP output shape {output_tensor.shape}") + return (output_tensor, residual) + + def _native_mlp(self, x, rmsnorm, adapter_ids=None): + logger.debug("MLP: native compiler") + # all-gather is done here instead of CPL layers to + # avoid 2 all-gathers from up and gate projections + if self.sequence_parallel_enabled: + x = gather_from_sequence_parallel_region( + x, self.sequence_dimension, process_group=get_tp_group(self.config) + ) + + gate_proj_output = ( + self.gate_proj(x) + if not is_lora_module(self.gate_proj) + else self.gate_proj(x, adapter_ids) + ) + up_proj_output = ( + self.up_proj(x) + if not is_lora_module(self.up_proj) + else self.up_proj(x, adapter_ids) + ) + down_proj_input = self.act_fn(gate_proj_output) * up_proj_output + output = ( + self.down_proj(down_proj_input) + if not is_lora_module(self.up_proj) + else self.down_proj(down_proj_input, adapter_ids) + ) + logger.debug(f"MLP output shape {output.shape}") + return output + + def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None): + """ + If residual is passed in, will fuse its add into the MLP kernel + Returns a tuple of (output, residual), where residual is the output of the residual add + """ + if self.mlp_kernel_enabled: + fused_rmsnorm = not self.sequence_parallel_enabled + # Quantized MLP kernel + if self.quantized_mlp_kernel_enabled: + return self._kernel_enabled_quantized_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + # MLP kernel + return self._kernel_enabled_mlp( + x, fused_rmsnorm, rmsnorm, residual, adapter_ids=adapter_ids + ) + else: + # No kernel + return (self._native_mlp(x, rmsnorm, adapter_ids=adapter_ids), None) + + +@register_module("NeuronQwen3Attention") +class NeuronQwen3Attention(NeuronAttentionBase): + """ + Compared with Qwen3Attention, this class just + 1. replaces the q_proj, k_proj, v_proj with column parallel layer + 2. replaces the o_proj with row parallel layer + 3. update self.num_head to be self.num_head / tp_degree + 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree + 5. update forward() method to adjust to changes from self.num_head + """ + + def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None): + super().__init__(tensor_model_parallel_group=tensor_model_parallel_group) + + self.config = config + self.neuron_config = config.neuron_config + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.padding_side = config.neuron_config.padding_side + self.torch_dtype = config.neuron_config.torch_dtype + self.is_medusa = config.neuron_config.is_medusa + self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled + self.num_cores_per_group = config.num_cores_per_group + self.bias = getattr(config, "attention_bias", False) + self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rms_norm_eps = config.rms_norm_eps + + self.q_norm = Qwen3RMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # unlike olmo, only on the head dim! + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + if parallel_state.model_parallel_is_initialized(): + self.tp_degree = self.config.neuron_config.tp_degree + else: + self.tp_degree = 1 + + self.fused_qkv = config.neuron_config.fused_qkv + self.clip_qkv = None + + self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + logger.debug( + f"Hello from NeuronQwen3Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? {self.sequence_dimension}" + ) + + self.init_gqa_properties() + self.init_rope() + + def prep_qkv_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + ): + """take care of the shape, layout, group query, custom position encoding, etc.""" + Q, K, V = self.qkv_proj( + hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids + ) + + # Divide hidden_dim across heads for MHA + # Change layout: BSHD -> BHSD + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + Q = move_heads_front( + Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=self.q_norm + ) + K = move_heads_front( + K, + bsz, + q_len, + self.num_key_value_heads, + self.head_dim, + layernorm=self.k_norm, + ) + V = move_heads_front( + V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None + ) + + # Rotate Q and K + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + + return Q, K, V, cos_cache, sin_cache + + def init_rope(self): + self.rotary_emb = Qwen3RotaryEmbedding(self.config) + + +class NeuronQwen3DecoderLayer(nn.Module): + """ + Just replace the attention with the NXD version, and MLP with the NXD version + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + # self.self_attn = _Qwen3_MODULE_MAP[config.neuron_config.attn_cls]( + self.self_attn = NeuronQwen3Attention( + config=config, tensor_model_parallel_group=get_tp_group(config) + ) + self.mlp = NeuronQwen3MLP(config) + logger.debug( + f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}" + ) + self.input_layernorm = None + if ( + not config.neuron_config.is_eagle_draft + or config.neuron_config.enable_eagle_draft_input_norm + ): + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled + self.rmsnorm_quantize_kernel_enabled = ( + config.neuron_config.rmsnorm_quantize_kernel_enabled + ) + self.mlp_kernel_fuse_residual_add = ( + config.neuron_config.mlp_kernel_fuse_residual_add + ) + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + adapter_ids=None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + residual = hidden_states + + # RMSNorm (fused with QKV kernel when SP is disabled) + if ( + not self.qkv_kernel_enabled or self.sequence_parallel_enabled + ) and self.input_layernorm: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + adapter_ids=adapter_ids, + rmsnorm=self.input_layernorm, + **kwargs, + ) + + if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add: + assert not self.sequence_parallel_enabled, ( + "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled" + ) + # First residual add handled in the MLP kernel + hidden_states, residual = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + residual=residual, + adapter_ids=adapter_ids, + ) + else: + hidden_states = residual + hidden_states + residual = hidden_states + # RMSNorm (fused with QKV kernel when SP is disabled) + if not self.mlp_kernel_enabled or self.sequence_parallel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp( + hidden_states, + rmsnorm=self.post_attention_layernorm, + adapter_ids=adapter_ids, + ) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache) + return outputs + + +class ResBlock(nn.Module): + """ + A Residual Block module. + This module performs a linear transformation followed by a SiLU activation, + and then adds the result to the original input, creating a residual connection. + Args: + hidden_size (int): The size of the hidden layers in the block. + """ + + def __init__(self, hidden_size): + super().__init__() + self.linear = nn.Linear(hidden_size, hidden_size) + # Initialize as an identity mapping + torch.nn.init.zeros_(self.linear.weight) + # Use SiLU activation to keep consistent with the Qwen3 model + self.act = nn.SiLU() + + def forward(self, x): + """ + Forward pass of the ResBlock. + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output after the residual connection and activation. + """ + return x + self.act(self.linear(x)) + + +class NeuronQwen3Model(NeuronBaseModel): + """ + The neuron version of the Qwen3Model + """ + + def setup_attr_for_model(self, config: InferenceConfig): + # Needed for init_inference_optimization() + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=not config.neuron_config.vocab_parallel, + sequence_parallel_enabled=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + use_spmd_rank=config.neuron_config.vocab_parallel, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + pad=True, + tensor_model_parallel_group=get_tp_group(config), + ) + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + ) + self.lm_head = nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + ) + + # In the target fp8 checkpoint, the 1st and last + # layers are not using fp8. + updated_configs = [] + for i in range(config.num_hidden_layers): + # TODO: Remove hardcoded code to have non-quantized MLPs for first and last decoder block + if i == 0 or i == config.num_hidden_layers - 1: + non_quant_config = copy.deepcopy(config) + non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False + updated_configs.append(non_quant_config) + else: + updated_configs.append(config) + self.layers = nn.ModuleList( + [NeuronQwen3DecoderLayer(conf) for conf in updated_configs] + ) + if not config.neuron_config.is_eagle_draft: + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if config.neuron_config.is_eagle_draft: + fc_bias = getattr(config, "fc_bias", False) + self.fc = ColumnParallelLinear( + config.hidden_size * 2, + config.hidden_size, + bias=fc_bias, + gather_output=True, + ) + self.is_medusa = config.neuron_config.is_medusa + self.num_medusa_heads = config.neuron_config.num_medusa_heads + self.medusa_speculation_length = config.neuron_config.medusa_speculation_length + + if self.is_medusa: + if parallel_state.model_parallel_is_initialized(): + medusa_head_cls = ColumnParallelLinear + else: + medusa_head_cls = nn.Linear + for i in range(self.num_medusa_heads): + medusa_head = nn.Sequential( + *([ResBlock(config.hidden_size)] * 1), + medusa_head_cls( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ), + ) + setattr(self, f"medusa_head_{i}", medusa_head) + + +class NeuronQwen3ForCausalLM(NeuronBaseForCausalLM): + """ + This class extends Qwen3ForCausalLM create traceable + blocks for Neuron. + Args: + Qwen3ForCausalLM (_type_): _description_ + """ + + _model_cls = NeuronQwen3Model + + @staticmethod + def load_hf_model(model_path): + return Qwen3ForCausalLM.from_pretrained(model_path) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """This function should be over-ridden in child classes as needed""" + neuron_config = config.neuron_config + if neuron_config.fused_qkv: + state_dict = convert_state_dict_to_fused_qkv(state_dict, config) + + if neuron_config.vocab_parallel: + # TODO: this hack can be removed after replication_id is ready to use + state_dict["embed_tokens.rank_util.rank"] = torch.arange( + 0, neuron_config.local_ranks_size + ) + + # to facilitate rank usage in attention + num_layers = config.num_hidden_layers + tp_degree = neuron_config.tp_degree + for i in range(num_layers): + state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + # to facilitate rank usage in base model + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + return state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + @classmethod + def get_config_cls(cls): + return Qwen3InferenceConfig From 176ced2c2e102d5eef015314d2f69e1ec2d029ea Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 14:15:28 -0400 Subject: [PATCH 5/7] add inference nb --- contributed/models/qwen3/qwen-3-test.ipynb | 358 +++++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 contributed/models/qwen3/qwen-3-test.ipynb diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb new file mode 100644 index 0000000..cd6cc4e --- /dev/null +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "libneuronxla 2.2.1630.0\n", + "neuronx-cc 2.17.194.0+d312836f\n", + "neuronx-distributed 0.11.0\n", + "neuronx-distributed-inference 0.2.0\n", + "torch-neuronx 2.5.1.2.6.0\n" + ] + } + ], + "source": [ + "!pip list | grep neuron" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoTokenizer, GenerationConfig\n", + "from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig\n", + "from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"/home/ubuntu/model_hf_qwen/qwen/\"\n", + "traced_model_path = \"/home/ubuntu/traced_model_qwen/qwen/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_path,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Qwen3ForCausalLM(\n", + " (model): Qwen3Model(\n", + " (embed_tokens): Embedding(151936, 4096)\n", + " (layers): ModuleList(\n", + " (0-35): 36 x Qwen3DecoderLayer(\n", + " (self_attn): Qwen3Attention(\n", + " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (q_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", + " (k_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", + " )\n", + " (mlp): Qwen3MLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", + " (down_proj): Linear(in_features=12288, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " (post_attention_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " )\n", + " )\n", + " (norm): Qwen3RMSNorm((4096,), eps=1e-06)\n", + " (rotary_emb): Qwen3RotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=151936, bias=False)\n", + ")" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import snapshot_download\n", + "\n", + "snapshot_download(\"Qwen/Qwen3-8B\", local_dir=model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "\n", + "def run_qwen3_compile():\n", + " # Initialize configs and tokenizer.\n", + " tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"right\")\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + " generation_config = GenerationConfig.from_pretrained(model_path)\n", + " generation_config_kwargs = {\n", + " \"do_sample\": True,\n", + " \"top_k\": 1,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + " }\n", + " generation_config.update(**generation_config_kwargs)\n", + " \n", + " neuron_config = NeuronConfig(\n", + " tp_degree=8,\n", + " batch_size=1,\n", + " max_context_length=128,\n", + " seq_len=256,\n", + " on_device_sampling_config=OnDeviceSamplingConfig(top_k=5),\n", + " enable_bucketing=True,\n", + " context_encoding_buckets=[128],\n", + " token_generation_buckets=[256],\n", + " flash_decoding_enabled=False,\n", + " torch_dtype=torch.bfloat16,\n", + " fused_qkv=False,\n", + " attn_kernel_enabled=True,\n", + " attn_cls=\"NeuronQwen3Attention\"\n", + " )\n", + " config = Qwen3InferenceConfig(\n", + " neuron_config,\n", + " load_config=load_pretrained_config(model_path),\n", + " )\n", + " \n", + " # Compile and save model.\n", + " print(\"\\nCompiling and saving model...\")\n", + " model = NeuronQwen3ForCausalLM(model_path, config)\n", + " model.compile(traced_model_path)\n", + " tokenizer.save_pretrained(traced_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run_qwen3_compile()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from modeling_qwen import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n", + "\n", + "model = NeuronQwen3ForCausalLM(traced_model_path)\n", + "model.load(traced_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = model.get_config_cls()\n", + "config.get_neuron_config_cls()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.num_attention_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.num_key_value_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4096" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.config.hidden_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "generation_config = GenerationConfig.from_pretrained(model_path)\n", + "generation_config_kwargs = {\n", + " \"do_sample\": False,\n", + " \"temperature\": 0.9,\n", + " \"top_k\": 5,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + "}\n", + "generation_config.update(**generation_config_kwargs)\n", + "generation_model = HuggingFaceGenerationAdapter(model)\n", + "messages = [{'role': 'user', 'content': \"What's your name?\"}]\n", + "text = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.\n", + ")\n", + "inputs = tokenizer([text], return_tensors=\"pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nGenerating outputs...\")\n", + "outputs = generation_model.generate(\n", + " **inputs,\n", + " max_new_tokens=512\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "thinking content: \n", + "content: My name is Qwen, and I'm a large language model developed by Alibaba Cloud. How can I assist you today?\n" + ] + } + ], + "source": [ + "output_ids = outputs[0][len(inputs.input_ids[0]):].tolist() \n", + "\n", + "# parsing thinking content\n", + "try:\n", + " # rindex finding 151668 ()\n", + " index = len(output_ids) - output_ids[::-1].index(151668)\n", + "except ValueError:\n", + " index = 0\n", + "\n", + "thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\n", + "content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n", + "\n", + "print(\"thinking content:\", thinking_content)\n", + "print(\"content:\", content)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model.reset()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aws_neuronx_venv_pytorch_2_5_nxd_inference", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 4d683ea33c7fb92b1cac2163239aabc0cad7104d Mon Sep 17 00:00:00 2001 From: Josh Longenecker Date: Wed, 30 Apr 2025 14:18:20 -0400 Subject: [PATCH 6/7] Remove .DS_Store files and add to gitignore --- .DS_Store | Bin 6148 -> 0 bytes contributed/.DS_Store | Bin 6148 -> 0 bytes contributed/models/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 contributed/.DS_Store delete mode 100644 contributed/models/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 45c0ee21b183e0a5b2f8bc84b377fd3232abf0b8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKPiqrF6o1pM&Bh=GQRqcj@EU4ri&}b#`J)sxg>KY?O5Dwcx^%l!vKw;X_ z_!YeRN&GIJ^!H|_kZcnW4dAnQzq#WXC2$Q$3 z<1~_!rW~hHrhI)p;FjHTzq&o0?jAg@d3%rcW;Jhm@StAv_8;!gW@Y!zz59nJ{j=me zm7f%IB(Og!xnb}UK0{fGucy&b&Dz)ntU0gS%+uB1Z>{)-vFY$MD zCdmp*^{A-Kz~m*9nQZd1WD)=n?NPW1Pyql3m9XSu@rh7B>5P=DrHm-_7%`;BvAPLj z70h_E;Wsisd$#}&pdf@E6!(vQihb1eH!XVl1>azt_U&H31pPQlhxPiaD9x4U7Z#mG zXUVzp@6^~Ic88<1-5p#|>r|;Y$m`wUEb5Irm6ZdP47*X%8|Z*2?4ir$d6a}|+*YF` zOm(bpIvm$=JC*8Wvf0|L$;M7&T9cF3R=p;h+s*0Jbyn9l_D(wY@k64X^sGVP|EXo! zVh+z(DXb|>Z{h&wNCAb8M}LsSD!D~_K3F~oBQwAZFav*#0eg%&OMh(MJY8mh8Tft% z=zNfIqz>hNU4j))v A5dZ)H diff --git a/contributed/models/.DS_Store b/contributed/models/.DS_Store deleted file mode 100644 index 30c8c00efe4f3cd398a5f78e56684116754e9494..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~O>P1)427QyNR`;IWSLD5Ku-`MH~|+x{OnLsRaM>Q=sY{#FtDyh=vi`J?8Ni) zEhb|Cw*Eb>fCYeex+^|>7??4h;TwNA;|rgspX23ly Date: Wed, 14 May 2025 10:24:33 -0400 Subject: [PATCH 7/7] logit val / cleanup --- contributed/models/qwen3/qwen-3-test.ipynb | 1113 ++++++++++++++++++-- 1 file changed, 1014 insertions(+), 99 deletions(-) diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb index cd6cc4e..bbb60dd 100644 --- a/contributed/models/qwen3/qwen-3-test.ipynb +++ b/contributed/models/qwen3/qwen-3-test.ipynb @@ -43,66 +43,6 @@ "traced_model_path = \"/home/ubuntu/traced_model_qwen/qwen/\"" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_path,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Qwen3ForCausalLM(\n", - " (model): Qwen3Model(\n", - " (embed_tokens): Embedding(151936, 4096)\n", - " (layers): ModuleList(\n", - " (0-35): 36 x Qwen3DecoderLayer(\n", - " (self_attn): Qwen3Attention(\n", - " (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", - " (q_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", - " (k_norm): Qwen3RMSNorm((128,), eps=1e-06)\n", - " )\n", - " (mlp): Qwen3MLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=12288, bias=False)\n", - " (down_proj): Linear(in_features=12288, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " (post_attention_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " )\n", - " )\n", - " (norm): Qwen3RMSNorm((4096,), eps=1e-06)\n", - " (rotary_emb): Qwen3RotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=151936, bias=False)\n", - ")" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model" - ] - }, { "cell_type": "code", "execution_count": null, @@ -195,60 +135,27 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "32" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.num_attention_heads" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.num_key_value_heads" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4096" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "model.config.hidden_size" ] @@ -332,6 +239,1014 @@ "source": [ "model.reset()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run Benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dir = '/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n", + "!cp modeling_qwen.py {dir}" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost\n", + "WARNING:root:Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed.modules.moe.blockwise import (\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/attention/utils.py:14: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + " return fn(*args, **kwargs)\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:22: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:24: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from neuronx_distributed_inference.models.mixtral.modeling_mixtral import NeuronMixtralForCausalLM\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/mllama/modeling_mllama.py:72: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n", + " from .modeling_mllama_vision import NeuronMllamaVisionModel # noqa: E402\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:29: UserWarning: Intel extension for pytorch not found. For faster CPU references install `intel-extension-for-pytorch`.\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n", + " return fn(*args, **kwargs)\n", + "Loading configs...\n", + "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'prompts': ['To be, or not to be'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'hlo_debug': False}\n", + "\n", + "Compiling and saving model...\n", + "INFO:Neuron:Generating HLOs for the following models: ['context_encoding_model', 'token_generation_model']\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", + "[2025-05-14 14:09:05.944: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", + "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", + "[2025-05-14 14:09:05.945: I neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:630] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:09:05.946: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "INFO:Neuron:Generating 1 hlos for key: context_encoding_model\n", + "INFO:Neuron:Started loading module context_encoding_model\n", + "INFO:Neuron:Finished loading module context_encoding_model in 0.07737994194030762 seconds\n", + "INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 16])\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:476: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n", + " with torch.cuda.amp.autocast(enabled=False):\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 16]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=4, shape=torch.Size([1, 3]), dtype=torch.float32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=5, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:158: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=6, shape=torch.Size([1]), dtype=torch.int32)\n", + " warnings.warn(\n", + "INFO:Neuron:Generating 1 hlos for key: token_generation_model\n", + "INFO:Neuron:Started loading module token_generation_model\n", + "INFO:Neuron:Finished loading module token_generation_model in 0.06693840026855469 seconds\n", + "INFO:Neuron:generating HLO: token_generation_model, input example shape = torch.Size([1, 1])\n", + "INFO:Neuron:Started compilation for all HLOs\n", + "....Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Done compilation for the priority HLO\n", + "INFO:Neuron:Updating the hlo module with optimized layout\n", + "INFO:Neuron:Done optimizing weight layout for all HLOs\n", + "..........Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Finished Compilation for all HLOs\n", + "..Completed run_backend_driver.\n", + "\n", + "Compiler status PASS\n", + "INFO:Neuron:Done preparing weight layout transformation\n", + "INFO:Neuron:Sharding Weights for ranks: 0...7\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:588] > initializing tensor model parallel with size 8\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:589] > initializing pipeline model parallel with size 1\n", + "[2025-05-14 14:14:12.537: I neuronx_distributed/parallel_layers/parallel_state.py:590] > initializing context model parallel with size 1\n", + "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:591] > initializing data parallel with size 1\n", + "[2025-05-14 14:14:12.538: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing world size to 8\n", + "[2025-05-14 14:14:12.540: I neuronx_distributed/parallel_layers/parallel_state.py:339] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:628] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:629] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:630] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:631] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "[2025-05-14 14:14:12.541: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: lm_head.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.10.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.11.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.12.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.13.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.14.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.15.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.16.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.8.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.9.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.17.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.18.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.19.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.20.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.21.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.22.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.23.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.24.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.25.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.26.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: embed_tokens.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.0.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.1.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.2.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.3.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.4.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.5.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.6.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.7.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.27.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.28.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.29.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.30.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.31.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.32.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.33.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.34.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.input_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.down_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.gate_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.mlp.up_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.post_attention_layernorm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.k_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.o_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.q_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: layers.35.self_attn.v_proj.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/application_base.py:347: UserWarning: Found float32 weights in checkpoint: norm.weight. Will convert to torch.bfloat16\n", + " warnings.warn(\n", + "INFO:Neuron:Done Sharding weights in 252.63744661300007\n", + "Compiling and tracing time: 559.4677159970001 seconds\n", + "\n", + "Loading model to Neuron...\n", + "INFO:Neuron:Warming up the model.\n", + "2025-May-14 14:18:35.0232 5872:7328 [2] nccl_net_ofi_rdma_init:7837 CCOM WARN NET/OFI OFI fi_getinfo() call failed: No data available\n", + "2025-May-14 14:18:35.0236 5872:7328 [2] nccl_net_ofi_create_plugin:261 CCOM WARN NET/OFI Unable to find a protocol that worked. Failing initialization.\n", + "2025-May-14 14:18:35.0239 5872:7328 [2] nccl_net_ofi_create_plugin:341 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n", + "2025-May-14 14:18:35.0242 5872:7328 [2] nccl_net_ofi_init:139 CCOM WARN NET/OFI Initializing plugin failed\n", + "2025-May-14 14:18:35.0245 5872:7328 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n", + "INFO:Neuron:Warmup completed in 0.2721595764160156 seconds.\n", + "Total model loading time: 10.090576054999929 seconds\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n", + "\n", + "Checking accuracy by logit matching\n", + "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.57it/s]\n", + "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/opt/aws_neuronx_venv_pytorch_2_5_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", + " warnings.warn(\n", + "Expected Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", + " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", + " 323, 36957, 315, 54488, 32315]])\n", + "Expected Logits Shape: torch.Size([25, 1, 151936])\n", + "HuggingFaceGenerationAdapter has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n", + " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n", + " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n", + " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n", + "Actual Output: [\", that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\"] tensor([[ 11, 429, 374, 279, 3405, 13, 13139, 364, 83, 285,\n", + " 13049, 1536, 304, 279, 3971, 311, 7676, 279, 1739, 819,\n", + " 323, 36957, 315, 54488, 32315]])\n", + "Actual Logits Shape: torch.Size([25, 1, 151936])\n", + "Passed logits validation!\n", + "\n", + "Generating outputs...\n", + "Prompts: ['To be, or not to be']\n", + "Generated outputs:\n", + "Output 0: To be, or not to be, that is the question. Whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune\n", + "Benchmark completed and its result is as following\n", + "{\n", + " \"e2e_model\": {\n", + " \"latency_ms_p50\": 156.56781196594238,\n", + " \"latency_ms_p90\": 158.08086395263672,\n", + " \"latency_ms_p95\": 158.1140637397766,\n", + " \"latency_ms_p99\": 158.28602075576782,\n", + " \"latency_ms_p100\": 158.32901000976562,\n", + " \"latency_ms_avg\": 156.99772834777832,\n", + " \"throughput\": 203.82460521412273\n", + " },\n", + " \"context_encoding_model\": {\n", + " \"latency_ms_p50\": 10.202646255493164,\n", + " \"latency_ms_p90\": 10.224390029907227,\n", + " \"latency_ms_p95\": 10.22493839263916,\n", + " \"latency_ms_p99\": 10.226750373840332,\n", + " \"latency_ms_p100\": 10.227203369140625,\n", + " \"latency_ms_avg\": 10.201811790466309,\n", + " \"throughput\": 1568.348870634151\n", + " },\n", + " \"token_generation_model\": {\n", + " \"latency_ms_p50\": 8.858323097229004,\n", + " \"latency_ms_p90\": 8.903312683105469,\n", + " \"latency_ms_p95\": 9.238588809967041,\n", + " \"latency_ms_p99\": 9.264287948608398,\n", + " \"latency_ms_p100\": 9.28950309753418,\n", + " \"latency_ms_avg\": 8.88296922047933,\n", + " \"throughput\": 120.07996877975322\n", + " }\n", + "}\n", + "Completed saving result to benchmark_report.json\n" + ] + } + ], + "source": [ + "!inference_demo \\\n", + " --model-type qwen3 \\\n", + " --task-type causal-lm \\\n", + " run \\\n", + " --model-path /home/ubuntu/model_hf_qwen/qwen/ \\\n", + " --compiled-model-path /home/ubuntu/traced_model_qwen/qwen/logit \\\n", + " --torch-dtype bfloat16 \\\n", + " --tp-degree 8 \\\n", + " --batch-size 1 \\\n", + " --max-context-length 16 \\\n", + " --seq-len 32 \\\n", + " --enable-bucketing \\\n", + " --pad-token-id 151645 \\\n", + " --prompt \"To be, or not to be\" \\\n", + " --check-accuracy-mode logit-matching \\\n", + " --benchmark" + ] } ], "metadata": {