diff --git a/contributed/models/qwen3/modeling_qwen3.py b/contributed/models/qwen3/modeling_qwen3.py
new file mode 100644
index 0000000..ce9662d
--- /dev/null
+++ b/contributed/models/qwen3/modeling_qwen3.py
@@ -0,0 +1,1106 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen3 model for NXD inference."""
+import copy
+import gc
+import logging
+import math
+from typing import List, Optional, Tuple, Type
+
+import torch
+from neuronx_distributed.parallel_layers import parallel_state # noqa: E402
+from neuronx_distributed.parallel_layers.layers import (  # noqa: E402
+ ColumnParallelLinear,
+ ParallelEmbedding,
+ RowParallelLinear,
+)
+from neuronx_distributed.parallel_layers.mappings import (
+ gather_from_sequence_parallel_region,
+ reduce_from_tensor_model_parallel_region,
+ reduce_scatter_to_sequence_parallel_region,
+ _gather_along_first_dim,
+)
+from neuronx_distributed.parallel_layers.utils import get_padding_length
+from neuronx_distributed.utils import cpu_mode
+from neuronxcc.nki._private_kernels.mlp import (
+ mlp_fused_add_isa_kernel,
+ mlp_isa_kernel,
+ quant_mlp_fused_add_isa_kernel,
+ quant_mlp_isa_kernel,
+)
+from neuronxcc.nki._private_kernels.rmsnorm import rmsnorm_quant_isa_kernel
+from neuronxcc.nki.language import nc
+from torch import nn
+from torch_neuronx.xla_impl.ops import nki_jit
+from transformers import Qwen3ForCausalLM
+from transformers.activations import ACT2FN
+from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm, Qwen3RotaryEmbedding
+
+from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig # noqa: E402
+from neuronx_distributed_inference.models.model_base import ( # noqa: E402
+ NeuronBaseForCausalLM,
+ NeuronBaseModel,
+)
+from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase
+from neuronx_distributed_inference.modules.attention.gqa import ( # noqa: E402
+ BaseGroupQueryAttention,
+)
+from neuronx_distributed_inference.modules.attention.utils import (
+ RotaryEmbedding,
+ preprocess_quantized_linear_layer,
+ transpose_parallel_linear_layer,
+ apply_rotary_pos_emb,
+ move_heads_front,
+)
+from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm
+from neuronx_distributed_inference.modules.flashdecode.utils import calculate_num_cores_per_group
+from neuronx_distributed_inference.modules.lora_serving.lora_module import is_lora_module
+from neuronx_distributed_inference.utils.distributed import get_tp_group
+
+_Qwen3_MODULE_MAP = {}
+
+logger = logging.getLogger("Neuron")
+
+
+def get_rmsnorm_cls():
+    # Return the appropriate RMSNorm implementation:
+    # - inference on NXD -> CustomRMSNorm
+    # - inference on CPU -> Qwen3RMSNorm (CustomRMSNorm does not work on CPU)
+ return Qwen3RMSNorm if cpu_mode() else CustomRMSNorm
+
+
+def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool:
+ if isinstance(module, (BaseGroupQueryAttention,)):
+ return module.preshard_hook(model_state_dict, prefix)
+
+ return False
+
+
+# Get the modules_to_not_convert from the neuron configs
+def get_modules_to_not_convert(neuron_config: NeuronConfig):
+ return getattr(neuron_config, "modules_to_not_convert", None)
+
+
+def get_updated_configs(config: InferenceConfig):
+ """
+ Generate a list of configurations for each hidden layer in a Qwen3 model.
+
+ This function creates a list of InferenceConfig objects, one for each layer. It
+ modifies the configurations for certain layers based on which modules should not
+ be converted to quantized format. The function uses get_modules_to_not_convert()
+ to determine which modules should not be converted.
+
+ Args:
+ config (InferenceConfig): The inference configuration for the model.
+
+ Returns:
+ list[InferenceConfig]: A list of InferenceConfig objects, one for each layer in the model.
+ Each config may be either the original config or a modified version
+ with "quantized_mlp_kernel_enabled" as False for that specific layer.
+ """
+ updated_configs = []
+ modules_to_not_convert = get_modules_to_not_convert(config.neuron_config)
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+
+ for i in range(config.num_hidden_layers):
+ # If any of the MLP modules for this layer are in modules_to_not_convert
+ module_pattern = f"layers.{i}.mlp"
+ if any(module_pattern in module for module in modules_to_not_convert):
+ non_quant_config = copy.deepcopy(config)
+ non_quant_config.neuron_config.quantized_mlp_kernel_enabled = False
+ non_quant_config.neuron_config.activation_quantization_type = None
+ non_quant_config.neuron_config.quantize_clamp_bound = float("inf")
+ updated_configs.append(non_quant_config)
+ else:
+ updated_configs.append(config)
+ return updated_configs
+
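+# Example (hypothetical module names): with modules_to_not_convert=["layers.0.mlp.gate_proj"],
+# layer 0 receives a copy of the config with quantized_mlp_kernel_enabled=False, while all
+# other layers keep the original (possibly quantized) config.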
+
+def _register_module(key: str, cls: Type[nn.Module]):
+ _Qwen3_MODULE_MAP[key] = cls
+
+
+def register_module(key: str):
+ """
+ Register a module for use in NeuronQwen3.
+
+ Arguments:
+ key: String used to identify the module
+
+ Example:
+ @register_module("NeuronQwen3Attention")
+ class NeuronQwen3Attention(nn.Module):
+ ...
+ """
+
+ def inner(cls: Type[nn.Module]):
+ _register_module(key, cls)
+ return cls
+
+ return inner
+
+
+def _helper_concat_and_delete_qkv(Qwen3_state_dict, layer_num, attr):
+ """
+ Helper function to concatenate and delete QKV attributes for fusedqkv (weight or scale).
+ Args:
+ Qwen3_state_dict: The state dictionary containing model weights
+ layer_num: The index of the layer to process
+ attr: The attribute to process ('weight' or 'scale')
+ """
+ Qwen3_state_dict[f"layers.{layer_num}.self_attn.Wqkv.{attr}"] = torch.cat(
+ [
+ Qwen3_state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"],
+ Qwen3_state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"],
+ Qwen3_state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"],
+ ],
+ )
+ del Qwen3_state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"]
+ del Qwen3_state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"]
+ del Qwen3_state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"]
+
+
+def convert_state_dict_to_fused_qkv(Qwen3_state_dict, cfg: InferenceConfig):
+ """
+    Concatenates the per-layer q/k/v weights (and scales) into a single Wqkv weight (and scale)
+    for fused QKV, then deletes the original q/k/v entries.
+ """
+ mods_to_not_conv = get_modules_to_not_convert(cfg.neuron_config)
+ if mods_to_not_conv is None:
+ mods_to_not_conv = []
+
+ for l in range(cfg.num_hidden_layers): # noqa: E741
+ _helper_concat_and_delete_qkv(Qwen3_state_dict, l, "weight")
+ if (
+ cfg.neuron_config.quantized_mlp_kernel_enabled or cfg.neuron_config.quantized
+ ) and f"layers.{l}.self_attn" not in mods_to_not_conv:
+ _helper_concat_and_delete_qkv(Qwen3_state_dict, l, "scale")
+
+ gc.collect()
+
+ return Qwen3_state_dict
+
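+# For example, after this conversion a layer's state dict entries change from
+# "layers.0.self_attn.q_proj.weight", ".k_proj.weight" and ".v_proj.weight" (each stored
+# as (out, in)) to a single "layers.0.self_attn.Wqkv.weight" concatenated along dim 0.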
+
+class WeightGatheredColumnParallel(ColumnParallelLinear):
+ """
+ A specialized column-parallel linear layer that implements weight gathering optimization
+ for efficient processing of long sequences in transformer models during eagle speculation.
+
+ This layer provides two forward paths:
+ 1. Standard column-parallel forward (inherited from parent)
+ 2. Weight-gathered forward for long sequences
+ """
+    def forward_wg(self, input: torch.Tensor, weight_gather: bool = False):
+ """
+ Performs the forward pass with optional weight gathering optimization.
+
+ Args:
+ input (torch.Tensor): Input tensor of shape (batch_size, seq_len/TP, 2*hidden_size)
+ weight_gather (bool): Whether to use weight gathering optimization.
+ Typically True for sequences >= 32K
+
+ Returns:
+ torch.Tensor or Tuple[torch.Tensor, torch.Tensor]:
+ - If skip_bias_add is False: Output tensor of shape (batch_size, seq_len, hidden_size)
+ - If skip_bias_add is True: Tuple of (output tensor, bias)
+ """
+ if weight_gather:
+ weight = _gather_along_first_dim(self.weight, process_group=self.tensor_parallel_group)
+ output = self._forward_impl(
+ input=input,
+ weight=weight,
+ bias=None,
+ async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
+ sequence_parallel_enabled=self.sequence_parallel_enabled,
+ sequence_dimension=self.sequence_dimension,
+ autograd_func_class=self.autograd_func_class,
+ process_group=self.tensor_parallel_group
+ )
+
+ output = gather_from_sequence_parallel_region(
+ output,
+ self.sequence_dimension,
+ process_group=self.tensor_parallel_group,
+ )
+ if self.skip_bias_add:
+ return output, self.bias
+
+ output = (output + self.bias) if self.bias is not None else output
+ return output
+ else:
+ return self.forward(input)
+
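+# Usage sketch (sizes are illustrative): the eagle-draft fc layer below is constructed as
+#   fc = WeightGatheredColumnParallel(2 * hidden_size, hidden_size, gather_output=True, sequence_dimension=1)
+# and called as fc.forward_wg(x, weight_gather=True) for long sequences (roughly >= 32K tokens,
+# per the docstring above); otherwise it falls back to the standard column-parallel forward.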
+
+class Qwen3InferenceConfig(InferenceConfig):
+ def add_derived_config(self):
+ self.num_cores_per_group = 1
+ if self.neuron_config.flash_decoding_enabled:
+ num_attn_heads, num_kv_heads = self.num_attention_heads, self.num_key_value_heads
+ self.num_cores_per_group = calculate_num_cores_per_group(
+ num_attn_heads, num_kv_heads, self.neuron_config.tp_degree
+ )
+
+ def get_required_attributes(self) -> List[str]:
+ return [
+ "hidden_size",
+ "num_attention_heads",
+ "num_hidden_layers",
+ "num_key_value_heads",
+ "pad_token_id",
+ "vocab_size",
+ "max_position_embeddings",
+ "rope_theta",
+ "rms_norm_eps",
+ "hidden_act",
+ ]
+
+ @classmethod
+ def get_neuron_config_cls(cls) -> Type[NeuronConfig]:
+ return NeuronConfig
+
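+# Construction sketch (mirrors the accompanying notebook; values are illustrative):
+#   neuron_config = NeuronConfig(tp_degree=8, batch_size=1, seq_len=2048,
+#                                torch_dtype=torch.bfloat16, attn_cls="NeuronQwen3Attention")
+#   config = Qwen3InferenceConfig(neuron_config, load_config=load_pretrained_config(model_path))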
+
+class NeuronQwen3MLP(nn.Module):
+ """
+    This class replaces the linear layers (gate_proj, up_proj and down_proj) with column- and row-parallel layers
+ """
+
+ def __init__(self, config: InferenceConfig):
+ super().__init__()
+ self.config = config
+ self.neuron_config = config.neuron_config
+ self.tp_degree = config.neuron_config.tp_degree
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ self.sequence_parallel_enabled = getattr(
+ self.neuron_config, "sequence_parallel_enabled", False
+ )
+ self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
+ self.rms_norm_eps = config.rms_norm_eps
+ self.mlp_kernel_enabled = self.neuron_config.mlp_kernel_enabled
+ self.fused_rmsnorm_skip_gamma = self.config.neuron_config.fused_rmsnorm_skip_gamma
+ self.quantized_mlp_kernel_enabled = self.neuron_config.quantized_mlp_kernel_enabled
+ self.rmsnorm_quantize_kernel_enabled = self.neuron_config.rmsnorm_quantize_kernel_enabled
+ self.quantize_clamp_bound = self.neuron_config.quantize_clamp_bound
+ self.logical_nc_config = self.neuron_config.logical_nc_config
+ self.activation_quantization_type = self.neuron_config.activation_quantization_type
+ mlp_bias = getattr(config, "mlp_bias", False)
+
+ if self.neuron_config.quantized_mlp_kernel_enabled and self.quantize_clamp_bound == float(
+ "inf"
+ ):
+ logging.warning(
+ "quantize_clamp_bound is not specified in NeuronConfig. We will use the default value of 1200 for Qwen3 models in quantized kernels."
+ )
+ self.quantize_clamp_bound = 1200.0
+ if parallel_state.model_parallel_is_initialized():
+ if self.neuron_config.quantized_mlp_kernel_enabled:
+                # Quantized MLP kernels expect the intermediate size to be a multiple of 128, so we need to pad
+ tp_degree = self.neuron_config.tp_degree
+ self.intermediate_size += (
+ get_padding_length(self.intermediate_size // tp_degree, 128) * tp_degree
+ )
+ logger.debug(f"Quantized intermediate_size: {self.intermediate_size}")
+ self.gate_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.intermediate_size,
+ bias=mlp_bias,
+ gather_output=False,
+ dtype=config.neuron_config.torch_dtype,
+ pad=True,
+ sequence_parallel_enabled=False,
+ sequence_dimension=None,
+ tensor_model_parallel_group=get_tp_group(config),
+ )
+ self.up_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.intermediate_size,
+ bias=mlp_bias,
+ gather_output=False,
+ dtype=config.neuron_config.torch_dtype,
+ pad=True,
+ sequence_parallel_enabled=False,
+ sequence_dimension=None,
+ tensor_model_parallel_group=get_tp_group(config),
+ )
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=mlp_bias,
+ input_is_parallel=True,
+ dtype=config.neuron_config.torch_dtype,
+ pad=True,
+ sequence_parallel_enabled=self.sequence_parallel_enabled,
+ sequence_dimension=self.sequence_dimension,
+ tensor_model_parallel_group=get_tp_group(config),
+ reduce_dtype=config.neuron_config.rpl_reduce_dtype,
+ )
+ if self.mlp_kernel_enabled:
+ if self.neuron_config.quantized_mlp_kernel_enabled:
+ setattr(
+ self.gate_proj,
+ "post_create_quantized_module_hook",
+ preprocess_quantized_linear_layer,
+ )
+ setattr(
+ self.up_proj,
+ "post_create_quantized_module_hook",
+ preprocess_quantized_linear_layer,
+ )
+ setattr(
+ self.down_proj,
+ "post_create_quantized_module_hook",
+ preprocess_quantized_linear_layer,
+ )
+ else:
+ # Transpose the weights to the layout expected by kernels
+ self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight)
+ self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight)
+ self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight)
+
+ else:
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias)
+
+ def _kernel_enabled_quantized_mlp(self, x, rmsnorm, residual, adapter_ids):
+ grid = (nc(self.logical_nc_config),)
+ fused_residual = residual is not None
+ fused_rmsnorm = rmsnorm is not None
+ logger.debug(
+ f"MLP: quantized kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}"
+ )
+
+ # Can't do residual add in the kernel if SP is enabled
+ if fused_residual:
+ assert (
+ not self.sequence_parallel_enabled
+ ), "Quantized MLP cannot have both fused residual add and sequence parallel RMSnorm!"
+ # Using fused residual add
+ _mlp_fwd_call = nki_jit()(quant_mlp_fused_add_isa_kernel)
+ else:
+ _mlp_fwd_call = nki_jit()(quant_mlp_isa_kernel)
+
+ if fused_rmsnorm:
+ ln_w = rmsnorm.weight.unsqueeze(0)
+ else:
+ ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device)
+
+ # Handle SP RMSnorm
+ x_orig_dtype = x.dtype
+ if self.sequence_parallel_enabled:
+ # This RMSNormQuant kernel will do quantization inside, so we pass the
+ # clamp_bound for clipping.
+ # If we don't use this kernel, the MLP kernel below will do the
+ # quantization, so we also pass clamp_bound to that kernel.
+ if self.rmsnorm_quantize_kernel_enabled:
+ logger.debug(
+ "Running Quantized MLP kernel with sequence-parallel RMSnorm-Quantize kernel!"
+ )
+ _rmsnorm_quant_fwd_call = nki_jit()(rmsnorm_quant_isa_kernel)
+ quant_rmsnorm_out = torch.zeros(
+ size=(
+ x.shape[0], # batch size
+ x.shape[1], # sequence length
+ x.shape[2] + 4, # hidden size + 4 bytes for packing fp32 scale
+ ),
+ dtype=torch.int8,
+ device=x.device,
+ )
+ clamp_bound = self.quantize_clamp_bound
+ _rmsnorm_quant_fwd_call[grid](
+ x, ln_w, clamp_bound, quant_rmsnorm_out, kernel_name="QuantOnly"
+ )
+ x = gather_from_sequence_parallel_region(
+ quant_rmsnorm_out,
+ self.sequence_dimension,
+ process_group=get_tp_group(self.config),
+ )
+
+ else:
+ logger.debug(
+ "Running Quantized MLP kernel with external (native compiler) sequence-parallel RMSnorm!"
+ )
+ x = gather_from_sequence_parallel_region(
+ x, self.sequence_dimension, process_group=get_tp_group(self.config)
+ )
+
+ # Build output tensor
+ output_tensor_seqlen = x.shape[1]
+ if fused_residual:
+ # seqlen dim is doubled to store the residual add output
+ output_tensor_seqlen *= 2
+
+ output_tensor = torch.zeros(
+ size=(
+ x.shape[0], # batch size
+ output_tensor_seqlen,
+ self.hidden_size, # hidden size
+ ),
+ dtype=x_orig_dtype,
+ device=x.device,
+ )
+
+ # Grab weights
+ # all weights of the layers are stored in (out, in) shape
+ # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden]
+ gate_w = self.gate_proj.weight.data
+ gate_w_scale = self.gate_proj.scale
+ up_w = self.up_proj.weight.data
+ up_w_scale = self.up_proj.scale
+ down_w = self.down_proj.weight.data
+ down_w_scale = self.down_proj.scale
+ clamp_bound = self.quantize_clamp_bound
+
+ if fused_residual:
+ _mlp_fwd_call[grid](
+ x, # attn_output
+ residual, # hidden
+ ln_w, # ln_w
+ gate_w, # gate_w
+ gate_w_scale,
+ up_w, # up_w
+ up_w_scale,
+ down_w, # down_w
+ down_w_scale,
+ clamp_bound,
+ output_tensor, # out
+ fused_rmsnorm=fused_rmsnorm,
+ eps=self.rms_norm_eps,
+ kernel_name="MLP",
+ store_add=True,
+ )
+ original_seqlen = x.shape[1]
+ residual = output_tensor[:, original_seqlen:, :]
+ output_tensor = output_tensor[:, :original_seqlen, :]
+ else:
+ _mlp_fwd_call[grid](
+ x, # hidden
+                # it is fine to pass gamma in as a dummy even if not using fused rmsnorm
+ ln_w,
+ gate_w, # gate_w
+ gate_w_scale,
+ up_w, # up_w
+ up_w_scale,
+ down_w, # down_w
+ down_w_scale,
+ clamp_bound,
+ output_tensor, # out
+ # Run RMSNorm inside the kernel if NOT using SP rmsnorm
+ fused_rmsnorm=fused_rmsnorm,
+ eps=self.rms_norm_eps,
+ kernel_name="MLP",
+ )
+ residual = None
+
+ # All-reduce or reduce-scatter, depending on whether SP is enabled
+ if self.sequence_parallel_enabled:
+ output_tensor = reduce_scatter_to_sequence_parallel_region(
+ output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config)
+ )
+ else:
+ output_tensor = reduce_from_tensor_model_parallel_region(output_tensor)
+
+ logger.debug(f"Quantized MLP output shape {output_tensor.shape}")
+ return (output_tensor, residual)
+
+ def _kernel_enabled_mlp(self, x, rmsnorm, residual, adapter_ids):
+ fused_residual = residual is not None
+ fused_rmsnorm = rmsnorm is not None
+ logger.debug(
+ f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, skip_gamma={self.fused_rmsnorm_skip_gamma}, logical_nc_config={self.logical_nc_config}"
+ )
+
+ # Choose which kernel to call
+ if fused_residual:
+ assert (
+ not self.sequence_parallel_enabled
+ ), "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!"
+ # Using fused residual add
+ _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel)
+ else:
+ _mlp_fwd_call = nki_jit()(mlp_isa_kernel)
+
+ if self.sequence_parallel_enabled:
+ x = gather_from_sequence_parallel_region(
+ x, self.sequence_dimension, process_group=get_tp_group(self.config)
+ )
+
+ # Build output tensor
+ output_tensor_seqlen = x.shape[1]
+ if fused_residual:
+ # seqlen dim is doubled to store the residual add output
+ output_tensor_seqlen *= 2
+
+ output_tensor = torch.zeros(
+ size=(
+ x.shape[0], # batch size
+ output_tensor_seqlen,
+ self.hidden_size, # hidden size
+ ),
+ dtype=x.dtype,
+ device=x.device,
+ )
+
+ # Grab weights
+ # all weights of the layers are stored in (out, in) shape
+ # unsqueeze so that shape of RMS gamma weight is [1, hidden] instead of [hidden]
+ if fused_rmsnorm:
+ ln_w = rmsnorm.weight.unsqueeze(0)
+ else:
+ ln_w = torch.zeros(size=(1, self.hidden_size), dtype=x.dtype, device=x.device)
+ gate_w = self.gate_proj.weight.data
+ up_w = self.up_proj.weight.data
+ down_w = self.down_proj.weight.data
+
+ grid = (nc(self.logical_nc_config),)
+
+ if fused_residual:
+ _mlp_fwd_call[grid](
+ x, # attn_output
+ residual, # hidden
+ ln_w, # ln_w
+ gate_w, # gate_w
+ up_w, # up_w
+ down_w, # down_w
+ output_tensor, # out
+ kernel_name="MLP",
+ fused_rmsnorm=fused_rmsnorm,
+ skip_gamma=self.fused_rmsnorm_skip_gamma,
+ eps=self.rms_norm_eps,
+ store_add=True,
+ )
+ original_seqlen = x.shape[1]
+ residual = output_tensor[:, original_seqlen:, :]
+ output_tensor = output_tensor[:, :original_seqlen, :]
+ else:
+ _mlp_fwd_call[grid](
+ x, # hidden
+                # it is fine to pass gamma in as a dummy even if not using fused rmsnorm
+ ln_w,
+ gate_w,
+ up_w,
+ down_w,
+ output_tensor, # out
+ kernel_name="MLP",
+ # Run RMSNorm inside the kernel if NOT using SP rmsnorm
+ fused_rmsnorm=fused_rmsnorm,
+ skip_gamma=self.fused_rmsnorm_skip_gamma,
+ eps=self.rms_norm_eps,
+ )
+ residual = None
+
+ # All-reduce or reduce-scatter, depending on whether SP is enabled
+ if self.sequence_parallel_enabled:
+ output_tensor = reduce_scatter_to_sequence_parallel_region(
+ output_tensor, self.sequence_dimension, process_group=get_tp_group(self.config)
+ )
+ else:
+ output_tensor = reduce_from_tensor_model_parallel_region(
+ output_tensor, process_group=get_tp_group(self.config)
+ )
+
+ logger.debug(f"MLP output shape {output_tensor.shape}")
+ return (output_tensor, residual)
+
+ def _native_mlp(self, x, adapter_ids=None):
+ logger.debug("MLP: native compiler")
+ # all-gather is done here instead of CPL layers to
+ # avoid 2 all-gathers from up and gate projections
+ if self.sequence_parallel_enabled:
+ x = gather_from_sequence_parallel_region(
+ x, self.sequence_dimension, process_group=get_tp_group(self.config)
+ )
+ gate_proj_output = (
+ self.gate_proj(x)
+ if not is_lora_module(self.gate_proj)
+ else self.gate_proj(x, adapter_ids)
+ )
+
+ up_proj_output = (
+ self.up_proj(x) if not is_lora_module(self.up_proj) else self.up_proj(x, adapter_ids)
+ )
+
+ down_proj_input = self.act_fn(gate_proj_output) * up_proj_output
+ output = (
+ self.down_proj(down_proj_input)
+ if not is_lora_module(self.down_proj)
+ else self.down_proj(down_proj_input, adapter_ids)
+ )
+ logger.debug(f"MLP output shape {output.shape}")
+ return output
+
+ def forward(self, x, rmsnorm=None, residual=None, adapter_ids=None):
+ """
+ If residual is passed in, will fuse its add into the MLP kernel
+ If rmsnorm is passed in, will fuse the rmsnorm into the MLP kernel
+
+ Returns a tuple of (output, residual), where residual is the output of the residual add
+ """
+
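+        # Dispatch summary: the quantized kernel path calls _kernel_enabled_quantized_mlp,
+        # the plain kernel path calls _kernel_enabled_mlp, and otherwise the native gated MLP
+        # down_proj(act_fn(gate_proj(x)) * up_proj(x)) runs with residual=None.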
+ if self.mlp_kernel_enabled:
+ # Quantized MLP kernel
+ if self.quantized_mlp_kernel_enabled:
+ return self._kernel_enabled_quantized_mlp(
+ x, rmsnorm, residual, adapter_ids=adapter_ids
+ )
+ # MLP kernel
+ return self._kernel_enabled_mlp(x, rmsnorm, residual, adapter_ids=adapter_ids)
+ else:
+ # No kernel
+ assert rmsnorm is None and residual is None
+ return (self._native_mlp(x, adapter_ids=adapter_ids), None)
+
+
+@register_module("NeuronQwen3Attention")
+class NeuronQwen3Attention(NeuronAttentionBase):
+ """
+    Compared with Qwen3Attention, this class:
+    1. replaces q_proj, k_proj and v_proj with column-parallel layers
+    2. replaces o_proj with a row-parallel layer
+    3. updates self.num_heads to self.num_heads / tp_degree
+    4. updates self.num_key_value_heads to self.num_key_value_heads / tp_degree
+    5. updates the forward() method to account for the changes to self.num_heads
+ """
+
+ def __init__(self, config: InferenceConfig, tensor_model_parallel_group=None):
+ super().__init__(tensor_model_parallel_group=tensor_model_parallel_group)
+
+ self.config = config
+ self.neuron_config = config.neuron_config
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_attention_heads)
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.padding_side = config.neuron_config.padding_side
+ self.torch_dtype = config.neuron_config.torch_dtype
+ self.is_medusa = config.neuron_config.is_medusa
+ self.flash_decoding_enabled = config.neuron_config.flash_decoding_enabled
+ self.num_cores_per_group = config.num_cores_per_group
+ self.bias = getattr(config, "attention_bias", False)
+ self.rpl_reduce_dtype = config.neuron_config.rpl_reduce_dtype
+ self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled
+ self.rms_norm_eps = config.rms_norm_eps
+ self.attn_tkg_builtin_kernel_enabled = self.neuron_config.attn_tkg_builtin_kernel_enabled
+ self.q_norm = Qwen3RMSNorm(
+ self.head_dim, eps=config.rms_norm_eps
+        )  # unlike OLMo, normalization is applied only over the head dim
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ if parallel_state.model_parallel_is_initialized():
+ self.tp_degree = self.config.neuron_config.tp_degree
+ else:
+ self.tp_degree = 1
+
+ self.fused_qkv = config.neuron_config.fused_qkv
+ self.clip_qkv = None
+
+ self.sequence_parallel_enabled = self.neuron_config.sequence_parallel_enabled
+ self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
+ logger.debug(
+ f"Hello from NeuronQwen3Attention init! Is SP enabled? {self.sequence_parallel_enabled}. Dim? {self.sequence_dimension}"
+ )
+
+ self.init_gqa_properties()
+
+ self.init_rope()
+
+ def prep_qkv_tensors(
+ self,
+ position_ids,
+ hidden_states,
+ past_key_value,
+ adapter_ids=None,
+ cos_cache=None,
+ sin_cache=None,
+ rmsnorm=None,
+ skip_rope=False,
+ residual=None,
+ ):
+ """take care of the shape, layout, group query, custom position encoding, etc.
+ also return residual for MLP """
+ Q, K, V, residual = self.qkv_proj(
+ hidden_states=hidden_states, rmsnorm=rmsnorm, adapter_ids=adapter_ids, residual=residual
+ )
+ if self.use_qk_norm:
+            self.init_qk_norm()  # TODO: move this to __init__ once the attention base class can take config parameters in init
+ Q = self.qk_norm(Q)
+ K = self.qk_norm(K)
+
+ # Divide hidden_dim across heads for MHA
+ # Change layout: BSHD -> BHSD
+ bsz, q_len, _ = hidden_states.size()
+ if self.sequence_parallel_enabled:
+ q_len *= self.tensor_model_parallel_group.size()
+
+ Q = move_heads_front(
+ Q, bsz, q_len, self.num_heads, self.head_dim, layernorm=self.q_norm
+ )
+ K = move_heads_front(
+ K, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=self.k_norm
+ )
+ V = move_heads_front(V, bsz, q_len, self.num_key_value_heads, self.head_dim, layernorm=None)
+
+ # Rotate Q and K
+ if not skip_rope and self.rotary_emb is not None:
+ if cos_cache is None or sin_cache is None:
+ cos_cache, sin_cache = self.rotary_emb(V, position_ids)
+
+ Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache)
+
+ return Q, K, V, cos_cache, sin_cache, residual
+
+ def init_rope(self):
+ self.rotary_emb = Qwen3RotaryEmbedding(self.config)
+
+ if self.attn_tkg_builtin_kernel_enabled:
+ self.inv_freqs = self.rotary_emb.get_inv_freqs().unsqueeze(1)
+
+
+class NeuronQwen3DecoderLayer(nn.Module):
+ """
+    Replaces the attention and the MLP with their NXD (NeuronX Distributed) versions.
+ """
+
+ def __init__(self, config: InferenceConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = NeuronQwen3Attention(
+ config=config, tensor_model_parallel_group=get_tp_group(config)
+ )
+
+ self.mlp = NeuronQwen3MLP(config)
+ logger.debug(
+ f"Instantiating RMSNorm modules with hidden size {config.hidden_size} and EPS {config.rms_norm_eps}"
+ )
+ self.input_layernorm = None
+ if (
+ not config.neuron_config.is_eagle_draft
+ or config.neuron_config.enable_eagle_draft_input_norm
+ ):
+ self.input_layernorm = get_rmsnorm_cls()(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = get_rmsnorm_cls()(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled
+ self.mlp_kernel_enabled = config.neuron_config.mlp_kernel_enabled
+ self.quantized_mlp_kernel_enabled = config.neuron_config.quantized_mlp_kernel_enabled
+ self.rmsnorm_quantize_kernel_enabled = config.neuron_config.rmsnorm_quantize_kernel_enabled
+ self.mlp_kernel_fuse_residual_add = config.neuron_config.mlp_kernel_fuse_residual_add
+ self.qkv_kernel_fuse_residual_add = config.neuron_config.qkv_kernel_fuse_residual_add
+ self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled
+ self.is_prefill_stage = config.neuron_config.is_prefill_stage
+ self.config = config
+
+ if self.is_prefill_stage and self.config.neuron_config.is_mlp_quantized():
+ # for CTE, quantized MLP kernel does not support fused rmsnorm
+ self.mlp_kernel_fused_rmsnorm = False
+ else:
+ self.mlp_kernel_fused_rmsnorm = not self.sequence_parallel_enabled
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ adapter_ids=None,
+ rotary_position_ids: Optional[torch.LongTensor] = None,
+ residual: Optional[torch.Tensor] = None, # residual from previous layer used by QKV
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]], Optional[torch.FloatTensor], Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
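+        # Residual handling: when qkv_kernel_fuse_residual_add is enabled, this layer's final
+        # residual add is deferred to the next layer's fused QKV kernel, so the pending residual
+        # is returned as the last element of the output tuple instead of being added here.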
+ entry_hidden_states = hidden_states
+ # RMSNorm (fused with QKV kernel when SP is disabled)
+ if (not self.qkv_kernel_enabled or self.sequence_parallel_enabled) and self.input_layernorm:
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ # produced another residual used by MLP
+ attn_output = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ adapter_ids=adapter_ids,
+ rmsnorm=self.input_layernorm,
+ rotary_position_ids=rotary_position_ids,
+ residual=residual,
+ **kwargs,
+ )
+
+ if attn_output.residual is None:
+ residual = entry_hidden_states # input to attention
+ else:
+ # residual will only be returned by attn/qkv if fuse add qkv kernel is enabled
+            assert self.qkv_kernel_fuse_residual_add, (
+                "residual add before qkv should be computed in the previous layer, "
+                "unless qkv_kernel_fuse_residual_add is specified"
+            )
+ assert (
+ not self.sequence_parallel_enabled
+ ), "qkv_kernel_fuse_residual_add should be off when sequence parallelism is enabled"
+ assert (
+ self.qkv_kernel_enabled
+ ), "qkv_kernel_fuse_residual_add should be used with qkv_kernel_enabled"
+ residual = attn_output.residual
+
+ hidden_states = attn_output.hidden_states
+ if self.mlp_kernel_enabled and self.mlp_kernel_fuse_residual_add:
+ assert (
+ not self.sequence_parallel_enabled
+ ), "mlp_kernel_fuse_residual_add should be off when sequence parallelism is enabled"
+ # First residual add handled in the MLP kernel
+ hidden_states, residual = self.mlp(
+ hidden_states,
+ rmsnorm=self.post_attention_layernorm,
+ residual=residual,
+ adapter_ids=adapter_ids,
+ )
+ else:
+ hidden_states = residual + hidden_states
+ residual = hidden_states
+ # RMSNorm (fused with QKV kernel when SP is disabled)
+ if self.mlp_kernel_enabled and self.mlp_kernel_fused_rmsnorm:
+ rmsnorm = self.post_attention_layernorm
+ else:
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ rmsnorm = None
+ hidden_states, _ = self.mlp(
+ hidden_states,
+ rmsnorm=rmsnorm,
+ adapter_ids=adapter_ids,
+ )
+
+ # if fuse residual add with qkv, we leave this add to the next layer's QKV
+ # unless it is the last layer in which case we add it here
+ if not self.qkv_kernel_fuse_residual_add:
+ hidden_states = residual + hidden_states
+ residual = None # set to None to prevent it from being used again
+
+ # also return residual for QKV in the next layer
+ outputs = (hidden_states, attn_output.present_key_value, attn_output.cos_cache, attn_output.sin_cache, residual)
+ return outputs
+
+
+class ResBlock(nn.Module):
+ """
+ A Residual Block module.
+
+ This module performs a linear transformation followed by a SiLU activation,
+ and then adds the result to the original input, creating a residual connection.
+
+ Args:
+ hidden_size (int): The size of the hidden layers in the block.
+ """
+
+ def __init__(self, hidden_size):
+ super().__init__()
+ self.linear = nn.Linear(hidden_size, hidden_size)
+ # Initialize as an identity mapping
+ torch.nn.init.zeros_(self.linear.weight)
+ # Use SiLU activation to keep consistent with the Qwen3 model
+ self.act = nn.SiLU()
+
+ def forward(self, x):
+ """
+ Forward pass of the ResBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output after the residual connection and activation.
+ """
+ return x + self.act(self.linear(x))
+
+
+class NeuronQwen3Model(NeuronBaseModel):
+ """
+ The neuron version of the Qwen3Model
+ """
+
+ def setup_attr_for_model(self, config: InferenceConfig):
+ # Needed for init_inference_optimization()
+ self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None
+ self.tp_degree = config.neuron_config.tp_degree
+ self.hidden_size = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.max_batch_size = config.neuron_config.max_batch_size
+ self.buckets = config.neuron_config.buckets
+
+ def init_model(self, config: InferenceConfig):
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ if parallel_state.model_parallel_is_initialized():
+ self.embed_tokens = ParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=config.neuron_config.torch_dtype,
+ shard_across_embedding=not config.neuron_config.vocab_parallel,
+ sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled,
+ sequence_dimension=1,
+ pad=True,
+ tensor_model_parallel_group=get_tp_group(config),
+ use_spmd_rank=config.neuron_config.vocab_parallel,
+ )
+
+ self.lm_head = ColumnParallelLinear(
+ config.hidden_size,
+ config.vocab_size,
+ gather_output=not self.on_device_sampling,
+ bias=False,
+ pad=True,
+ tensor_model_parallel_group=get_tp_group(config),
+ )
+ else:
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ )
+ self.lm_head = nn.Linear(
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ )
+
+ updated_configs = get_updated_configs(config)
+
+ self.layers = nn.ModuleList([NeuronQwen3DecoderLayer(conf) for conf in updated_configs])
+
+ if not config.neuron_config.is_eagle_draft:
+ self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps)
+
+ if config.neuron_config.is_eagle_draft:
+ fc_bias = getattr(config, "fc_bias", False)
+ # replicate fc weights since activations are sequence sharded
+ self.fc = WeightGatheredColumnParallel(
+ config.hidden_size * 2, config.hidden_size, bias=fc_bias, gather_output=True, sequence_dimension=1
+ )
+ self.is_medusa = config.neuron_config.is_medusa
+ self.num_medusa_heads = config.neuron_config.num_medusa_heads
+ self.medusa_speculation_length = config.neuron_config.medusa_speculation_length
+
+ if self.is_medusa:
+ if parallel_state.model_parallel_is_initialized():
+ medusa_head_cls = ColumnParallelLinear
+ else:
+ medusa_head_cls = nn.Linear
+ for i in range(self.num_medusa_heads):
+ medusa_head = nn.Sequential(
+ *([ResBlock(config.hidden_size)] * 1),
+ medusa_head_cls(
+ config.hidden_size,
+ config.vocab_size,
+ gather_output=not self.on_device_sampling,
+ bias=False,
+ ),
+ )
+ setattr(self, f"medusa_head_{i}", medusa_head)
+
+
+class NeuronQwen3ForCausalLM(NeuronBaseForCausalLM):
+ """
+    Extends NeuronBaseForCausalLM to create traceable blocks for Neuron,
+    loading weights from a Hugging Face Qwen3ForCausalLM checkpoint.
+ """
+
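+    # Usage sketch (mirrors the accompanying notebook):
+    #   model = NeuronQwen3ForCausalLM(model_path, config)  # build from the HF checkpoint
+    #   model.compile(traced_model_path)
+    #   model = NeuronQwen3ForCausalLM(traced_model_path)   # later, reload the traced model
+    #   model.load(traced_model_path)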
+ _model_cls = NeuronQwen3Model
+
+ @staticmethod
+ def load_hf_model(model_path, **kwargs):
+ return Qwen3ForCausalLM.from_pretrained(model_path, **kwargs)
+
+ @staticmethod
+ def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict:
+ """This function should be over-ridden in child classes as needed"""
+
+ neuron_config = config.neuron_config
+ # to facilitate rank usage in attention
+ num_layers = config.num_hidden_layers
+ tp_degree = neuron_config.tp_degree
+ for i in range(num_layers):
+ state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(
+ 0, tp_degree, dtype=torch.int32
+ )
+
+ """
+ for every layer do the following transformations
+ gate_w_prime = (gate_w.T * gamma).T
+ up_w_prime = (up_w.T * gamma).T
+ """
+ if (
+ neuron_config.fused_rmsnorm_skip_gamma
+ and not neuron_config.sequence_parallel_enabled
+ ):
+ if neuron_config.mlp_kernel_enabled:
+ # MLP
+ state_dict[f"layers.{i}.mlp.gate_proj.weight"] = state_dict[
+ f"layers.{i}.mlp.gate_proj.weight"
+ ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0)
+ state_dict[f"layers.{i}.mlp.up_proj.weight"] = state_dict[
+ f"layers.{i}.mlp.up_proj.weight"
+ ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0)
+
+ if neuron_config.qkv_kernel_enabled:
+ # QKV
+ state_dict[f"layers.{i}.self_attn.q_proj.weight"] = state_dict[
+ f"layers.{i}.self_attn.q_proj.weight"
+ ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0)
+ state_dict[f"layers.{i}.self_attn.k_proj.weight"] = state_dict[
+ f"layers.{i}.self_attn.k_proj.weight"
+ ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0)
+ state_dict[f"layers.{i}.self_attn.v_proj.weight"] = state_dict[
+ f"layers.{i}.self_attn.v_proj.weight"
+ ] * state_dict[f"layers.{i}.input_layernorm.weight"].unsqueeze(0)
+
+ if neuron_config.fused_qkv:
+ state_dict = convert_state_dict_to_fused_qkv(state_dict, config)
+
+ if neuron_config.vocab_parallel:
+ # TODO: this hack can be removed after replication_id is ready to use
+ state_dict["embed_tokens.rank_util.rank"] = torch.arange(
+ 0, neuron_config.local_ranks_size, dtype=torch.int32
+ )
+
+ # to facilitate rank usage in base model
+ state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32)
+ return state_dict
+
+ @staticmethod
+ def update_state_dict_for_tied_weights(state_dict):
+ state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone()
+
+ @classmethod
+ def get_config_cls(cls):
+ return Qwen3InferenceConfig
diff --git a/contributed/models/qwen3/qwen-3-test.ipynb b/contributed/models/qwen3/qwen-3-test.ipynb
new file mode 100644
index 0000000..3ea11c9
--- /dev/null
+++ b/contributed/models/qwen3/qwen-3-test.ipynb
@@ -0,0 +1,957 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip uninstall transformers --y\n",
+ "!pip install transformers==4.51.3\n",
+ "\n",
+ "# Installing collected packages: transformers\n",
+ "# ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "# neuronx-distributed-inference 0.3.5591+f50feae2 requires transformers==4.48.*, but you have transformers 4.52.4 which is incompatible.\n",
+ "# Successfully installed transformers-4.52.4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "libneuronxla 2.2.3493.0+78c3e78c\n",
+ "neuronx-cc 2.18.121.0+9e31e41a\n",
+ "neuronx-distributed 0.12.12111+cdd84048\n",
+ "neuronx-distributed-inference 0.3.5591+f50feae2\n",
+ "torch-neuronx 2.6.0.2.7.5413+113e6810\n",
+ "transformers 4.51.3\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip list | grep neuron\n",
+ "!pip list | grep transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer, GenerationConfig\n",
+ "from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig\n",
+ "from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter, load_pretrained_config"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Model Download"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_path = \"/home/ubuntu/model_hf_qwen/qwen/\"\n",
+ "traced_model_path = \"/home/ubuntu/traced_model_qwen3/qwen3/\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "snapshot_download(\"Qwen/Qwen3-8B\", local_dir=model_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### You may ignore the error that nxdi is not compatible with transformers > 4.48"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Compilation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/ubuntu/build-on-trainium-workshop/contributed/models/qwen3/modeling_qwen3.py:61: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n",
+ "/home/ubuntu/build-on-trainium-workshop/contributed/models/qwen3/modeling_qwen3.py:61: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n"
+ ]
+ }
+ ],
+ "source": [
+ "from modeling_qwen3 import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n",
+ "\n",
+ "def run_qwen3_compile():\n",
+ " # Initialize configs and tokenizer.\n",
+ " tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"right\")\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ " generation_config = GenerationConfig.from_pretrained(model_path)\n",
+ " generation_config_kwargs = {\n",
+ " \"do_sample\": True,\n",
+ " \"top_k\": 1,\n",
+ " \"pad_token_id\": tokenizer.pad_token_id,\n",
+ " }\n",
+ " generation_config.update(**generation_config_kwargs)\n",
+ " \n",
+ " neuron_config = NeuronConfig(\n",
+ " tp_degree=8,\n",
+ " batch_size=1,\n",
+ " max_context_length=1024, \n",
+ " seq_len=2048, \n",
+ " on_device_sampling_config=OnDeviceSamplingConfig(top_k=5),\n",
+ " enable_bucketing=True,\n",
+ " context_encoding_buckets=[1024],\n",
+ " token_generation_buckets=[2048],\n",
+ " flash_decoding_enabled=False,\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " fused_qkv=False,\n",
+ " attn_kernel_enabled=True,\n",
+ " attn_cls=\"NeuronQwen3Attention\"\n",
+ " )\n",
+ " config = Qwen3InferenceConfig(\n",
+ " neuron_config,\n",
+ " load_config=load_pretrained_config(model_path),\n",
+ " )\n",
+ " \n",
+ " # Compile and save model.\n",
+ " print(\"\\nCompiling and saving model...\")\n",
+ " model = NeuronQwen3ForCausalLM(model_path, config)\n",
+ " model.compile(traced_model_path)\n",
+ " tokenizer.save_pretrained(traced_model_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "run_qwen3_compile()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Testing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from modeling_qwen3 import Qwen3InferenceConfig, NeuronQwen3ForCausalLM\n",
+ "\n",
+ "model = NeuronQwen3ForCausalLM(traced_model_path)\n",
+ "model.load(traced_model_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "neuronx_distributed_inference.models.config.NeuronConfig"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "config = model.get_config_cls()\n",
+ "config.get_neuron_config_cls()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "32"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.config.num_attention_heads"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "8"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.config.num_key_value_heads"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4096"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.config.hidden_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "generation_config = GenerationConfig.from_pretrained(model_path)\n",
+ "generation_config_kwargs = {\n",
+ " \"do_sample\": False,\n",
+ " \"temperature\": 0.9,\n",
+ " \"top_k\": 5,\n",
+ " \"pad_token_id\": tokenizer.pad_token_id,\n",
+ "}\n",
+ "generation_config.update(**generation_config_kwargs)\n",
+ "generation_model = HuggingFaceGenerationAdapter(model)\n",
+ "messages = [{'role': 'user', 'content': \"What's your name?\"}]\n",
+ "text = tokenizer.apply_chat_template(\n",
+ " messages,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ " enable_thinking=False # Switches between thinking and non-thinking modes. Default is True.\n",
+ ")\n",
+ "inputs = tokenizer([text], return_tensors=\"pt\")\n",
+ "input_ids = inputs['input_ids'] \n",
+ "\n",
+ "outputs = generation_model.generate(\n",
+ " input_ids=input_ids,\n",
+ " max_new_tokens=512\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "thinking content: \n",
+ "content: My name is Qwen, and I'm a large language model developed by Alibaba Cloud. How can I assist you today?\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_ids = outputs[0][len(inputs.input_ids[0]):].tolist() \n",
+ "\n",
+ "# parsing thinking content\n",
+ "try:\n",
+ " # rindex finding 151668 ()\n",
+ " index = len(output_ids) - output_ids[::-1].index(151668)\n",
+ "except ValueError:\n",
+ " index = 0\n",
+ "\n",
+ "thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\n",
+ "content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n",
+ "\n",
+ "print(\"thinking content:\", thinking_content)\n",
+ "print(\"content:\", content)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Thinking example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.reset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(traced_model_path)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "generation_config = GenerationConfig.from_pretrained(model_path)\n",
+ "generation_config_kwargs = {\n",
+ " \"do_sample\": False,\n",
+ " \"temperature\": 0.9,\n",
+ " \"top_k\": 5,\n",
+ " \"pad_token_id\": tokenizer.pad_token_id,\n",
+ "}\n",
+ "generation_config.update(**generation_config_kwargs)\n",
+ "generation_model = HuggingFaceGenerationAdapter(model)\n",
+ "messages = [{'role': 'system', 'content': \"Only think through one example before providing the correct answer\"},\n",
+ " {'role': 'user', 'content': \"What is 83 * 110 + 34?\"}\n",
+ " ]\n",
+ "text = tokenizer.apply_chat_template(\n",
+ " messages,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ " enable_thinking=True # Switches between thinking and non-thinking modes. Default is True.\n",
+ ")\n",
+ "inputs = tokenizer([text], return_tensors=\"pt\")\n",
+ "input_ids = inputs['input_ids'] \n",
+ "outputs = generation_model.generate(\n",
+ " input_ids=input_ids,\n",
+ " max_new_tokens=1024\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "thinking content: \n",
+ "Okay, let's see. I need to calculate 83 multiplied by 110 and then add 34 to the result. Hmm, let me break this down step by step. First, I should handle the multiplication part: 83 times 110. \n",
+ "\n",
+ "Wait, multiplying by 110 might be easier if I think of it as multiplying by 100 and then adding 10 times the number. Because 110 is 100 + 10. So, 83 times 100 is 8300, and 83 times 10 is 830. Then adding those two together: 8300 + 830. Let me check that. 8300 plus 800 is 9100, and then plus 30 more would be 9130. So, 83 * 110 equals 9130?\n",
+ "\n",
+ "Wait, let me verify that another way. Maybe using the standard multiplication method. Let's write it out:\n",
+ "\n",
+ " 83\n",
+ "x110\n",
+ "------\n",
+ "First, multiply 83 by 0 (the units place of 110), which gives 0.\n",
+ "Then multiply 83 by 1 (the tens place of 110), which is 83, but since it's in the tens place, it's actually 830.\n",
+ "Then multiply 83 by 1 (the hundreds place of 110), which is 83, but since it's in the hundreds place, it's 8300.\n",
+ "Adding those together: 0 + 830 + 8300 = 9130. Okay, that matches my previous result. So 83*110 is indeed 9130.\n",
+ "\n",
+ "Now, the next part is adding 34 to that result. So 9130 + 34. Let me do that. 9130 plus 30 is 9160, and then plus 4 more is 9164. \n",
+ "\n",
+ "Wait, let me check again. 9130 + 34. Breaking it down: 9130 + 30 = 9160, then 9160 + 4 = 9164. Yes, that seems right. \n",
+ "\n",
+ "Alternatively, I can add 34 to 9130 directly. 9130 + 34. The units digit: 0 + 4 = 4. The tens digit: 3 + 3 = 6. The hundreds and above remain the same. So 9164. Yep, that's correct.\n",
+ "\n",
+ "So putting it all together, 83 multiplied by 110 is 9130, and adding 34 gives 9164. I think that's the right answer. Let me just confirm once more with another method. Maybe using distributive property for the entire expression.\n",
+ "\n",
+ "Original problem: 83*110 + 34. Let's think of 110 as 100 + 10, so 83*(100 + 10) + 34 = 83*100 + 83*10 + 34 = 8300 + 830 + 34. Adding those: 8300 + 830 is 9130, then 9130 + 34 is 9164. Yep, same result. \n",
+ "\n",
+ "I think that's solid. No mistakes in the steps. So the final answer should be 9164.\n",
+ "\n",
+ "################################################################################################################################################################################################################################################################################################################################\n",
+ "content: To solve $ 83 \\times 110 + 34 $, we break it into two parts:\n",
+ "\n",
+ "1. **Multiplication**: \n",
+ " $ 83 \\times 110 $ can be simplified by recognizing that $ 110 = 100 + 10 $. \n",
+ " $$\n",
+ " 83 \\times 110 = 83 \\times (100 + 10) = (83 \\times 100) + (83 \\times 10) = 8300 + 830 = 9130\n",
+ " $$\n",
+ "\n",
+ "2. **Addition**: \n",
+ " Add 34 to the result: \n",
+ " $$\n",
+ " 9130 + 34 = 9164\n",
+ " $$\n",
+ "\n",
+ "**Final Answer:** \n",
+ "$$\n",
+ "\\boxed{9164}\n",
+ "$$\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_ids = outputs[0][len(inputs.input_ids[0]):].tolist() \n",
+ "\n",
+ "# parsing thinking content\n",
+ "try:\n",
+ " # rindex finding 151668 ()\n",
+ " index = len(output_ids) - output_ids[::-1].index(151668)\n",
+ "except ValueError:\n",
+ " index = 0\n",
+ "\n",
+ "thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(\"\\n\")\n",
+ "content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip(\"\\n\")\n",
+ "\n",
+ "print(\"thinking content:\", thinking_content)\n",
+ "print('####'*80)\n",
+ "print(\"content:\", content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "9164"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ans = 83*110+34\n",
+ "ans"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Run Benchmarks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dir = '/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/'\n",
+ "!cp modeling_qwen3.py {dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#Edit the inference_demo.py file to include the following:\n",
+ "\n",
+ "```python\n",
+ "from .modeling_qwen import NeuronQwen3ForCausalLM\n",
+ "\n",
+ "MODEL_TYPES = {\n",
+ " \"llama\": {\"causal-lm\": NeuronLlamaForCausalLM},\n",
+ " \"mixtral\": {\"causal-lm\": NeuronMixtralForCausalLM},\n",
+ " \"dbrx\": {\"causal-lm\": NeuronDbrxForCausalLM},\n",
+ " 'qwen3': {\"causal-lm\": NeuronQwen3ForCausalLM}\n",
+ "}\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed.modules.moe.blockwise import (\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed.modules.moe.blockwise import (\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/modules/moe/expert_mlps.py:11: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed.modules.moe.blockwise import (\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/attention/utils.py:14: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:745: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n",
+ " return fn(*args, **kwargs)\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_model.py:12: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.gqa import GQA, GroupQueryAttention_QKV\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/dbrx/modeling_dbrx.py:38: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:25: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py:27: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from neuronx_distributed_inference.models.mixtral.modeling_mixtral import NeuronMixtralForCausalLM\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/models/mllama/modeling_mllama.py:72: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.\n",
+ " from .modeling_mllama_vision import NeuronMllamaVisionModel # noqa: E402\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py:29: UserWarning: Intel extension for pytorch not found. For faster CPU references install `intel-extension-for-pytorch`.\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:745: UserWarning: Set seed for `privateuseone` device does not take effect, please add API's `_is_in_bad_fork` and `manual_seed_all` to `privateuseone` device module.\n",
+ " return fn(*args, **kwargs)\n",
+ "Loading configs...\n",
+ "WARNING:root:NeuronConfig init: Unexpected keyword arguments: {'model_type': 'qwen3', 'task_type': 'causal-lm', 'model_path': '/home/ubuntu/model_hf_qwen/qwen/', 'compiled_model_path': '/home/ubuntu/traced_model_qwen/qwen/logit', 'benchmark': True, 'check_accuracy_mode': , 'divergence_difference_tol': 0.001, 'num_tokens_to_check': 400, 'prompts': ['What is 83 * 110 + 34?'], 'top_k': 1, 'top_p': 1.0, 'temperature': 1.0, 'do_sample': False, 'dynamic': False, 'pad_token_id': 151645, 'on_device_sampling': False, 'enable_torch_dist': False, 'enable_lora': False, 'max_loras': 1, 'max_lora_rank': 16, 'skip_warmup': False, 'skip_compile': False, 'compile_only': False, 'compile_dry_run': False, 'hlo_debug': False}\n",
+ "\n",
+ "Compiling and saving model...\n",
+ "INFO:Neuron:Generating HLOs for the following models: ['context_encoding_model', 'token_generation_model']\n",
+ "[2025-06-02 17:05:39.464: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n",
+ "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n",
+ "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n",
+ "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n",
+ "[2025-06-02 17:05:39.465: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:05:39.466: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "INFO:Neuron:Generating 1 hlos for key: context_encoding_model\n",
+ "INFO:Neuron:Started loading module context_encoding_model\n",
+ "INFO:Neuron:Finished loading module context_encoding_model in 0.08188652992248535 seconds\n",
+ "INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 512])\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:478: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+ " with torch.cuda.amp.autocast(enabled=False):\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=1, shape=torch.Size([1, 512]), dtype=torch.int32)\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=3, shape=torch.Size([1]), dtype=torch.int32)\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=4, shape=torch.Size([1, 3]), dtype=torch.float32)\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=5, shape=torch.Size([1]), dtype=torch.int32)\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:210: UserWarning: Received an input tensor that was unused. Tensor will be ignored. (index=6, shape=torch.Size([1]), dtype=torch.int32)\n",
+ " warnings.warn(\n",
+ "INFO:Neuron:Finished generating HLO for context_encoding_model in 2.901122808456421 seconds, input example shape = torch.Size([1, 512])\n",
+ "INFO:Neuron:Generating 1 hlos for key: token_generation_model\n",
+ "INFO:Neuron:Started loading module token_generation_model\n",
+ "INFO:Neuron:Finished loading module token_generation_model in 0.07949113845825195 seconds\n",
+ "INFO:Neuron:generating HLO: token_generation_model, input example shape = torch.Size([1, 1])\n",
+ "INFO:Neuron:Finished generating HLO for token_generation_model in 2.800884246826172 seconds, input example shape = torch.Size([1, 1])\n",
+ "INFO:Neuron:Generated all HLOs in 5.948296308517456 seconds\n",
+ "INFO:Neuron:Starting compilation for the priority HLO\n",
+ "INFO:Neuron:'token_generation_model' is the priority model with bucket rank 0\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:283: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.\n",
+ " warnings.warn(SyntaxWarning(\n",
+ "2025-06-02 17:05:45.000556: 16339 INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.18.121.0+9e31e41a/MODULE_ff123d67d8e9ddda72ca+91ef39e9/model.neff\n",
+ "INFO:Neuron:Done compilation for the priority HLO in 0.18962407112121582 seconds\n",
+ "INFO:Neuron:Updating the hlo module with optimized layout\n",
+ "INFO:Neuron:Done optimizing weight layout for all HLOs in 0.15496611595153809 seconds\n",
+ "INFO:Neuron:Starting compilation for all HLOs\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/libneuronxla/neuron_cc_wrapper.py:245: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.\n",
+ " warnings.warn(SyntaxWarning(\n",
+ "2025-06-02 17:05:45.000888: 16339 INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.18.121.0+9e31e41a/MODULE_f6025f1aaa134ee9ebd5+d43b5474/model.neff\n",
+ "INFO:Neuron:Finished Compilation for all HLOs in 0.13458752632141113 seconds\n",
+ "..Completed run_backend_driver.\n",
+ "\n",
+ "Compiler status PASS\n",
+ "INFO:Neuron:Done preparing weight layout transformation\n",
+ "INFO:Neuron:Finished building model in 40.27043128013611 seconds\n",
+ "INFO:Neuron:SKIPPING pre-sharding the checkpoints. The checkpoints will be sharded during load time.\n",
+ "Compiling and tracing time: 40.28253860299992 seconds\n",
+ "\n",
+ "Loading model to Neuron...\n",
+ "INFO:Neuron:Sharding weights on load...\n",
+ "INFO:Neuron:Sharding Weights for ranks: 0...7\n",
+ "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:592] > initializing tensor model parallel with size 8\n",
+ "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:593] > initializing pipeline model parallel with size 1\n",
+ "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:594] > initializing context model parallel with size 1\n",
+ "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:595] > initializing data parallel with size 1\n",
+ "[2025-06-02 17:06:19.765: I neuronx_distributed/parallel_layers/parallel_state.py:596] > initializing world size to 8\n",
+ "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:343] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=, 'Ascending Ring PG Group')>\n",
+ "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:632] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1, 2, 3, 4, 5, 6, 7]]\n",
+ "[2025-06-02 17:06:19.766: I neuronx_distributed/parallel_layers/parallel_state.py:633] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:634] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:635] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:636] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "[2025-06-02 17:06:19.767: I neuronx_distributed/parallel_layers/parallel_state.py:637] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1], [2], [3], [4], [5], [6], [7]]\n",
+ "INFO:Neuron:Done Sharding weights in 1.1175042840004608\n",
+ "INFO:Neuron:Finished weights loading in 11.10367280399987 seconds\n",
+ "INFO:Neuron:Warming up the model.\n",
+ "2025-Jun-02 17:06:31.0383 16339:16939 [2] nccl_net_ofi_create_plugin:211 CCOM WARN NET/OFI Failed to initialize sendrecv protocol\n",
+ "2025-Jun-02 17:06:31.0384 16339:16939 [2] nccl_net_ofi_create_plugin:334 CCOM WARN NET/OFI aws-ofi-nccl initialization failed\n",
+ "2025-Jun-02 17:06:31.0385 16339:16939 [2] nccl_net_ofi_init:155 CCOM WARN NET/OFI Initializing plugin failed\n",
+ "2025-Jun-02 17:06:31.0386 16339:16939 [2] net_plugin.cc:94 CCOM WARN OFI plugin initNet() failed is EFA enabled?\n",
+ "INFO:Neuron:Warmup completed in 0.30846428871154785 seconds.\n",
+ "Total model loading time: 11.940343698000106 seconds\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:653: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.\n",
+ " warnings.warn(\n",
+ "\n",
+ "Checking accuracy by logit matching\n",
+ "Loading checkpoint shards: 100%|██████████████████| 5/5 [00:01<00:00, 2.89it/s]\n",
+ "`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True, 'temperature': 0.6, 'top_p': 0.95}. If this is not desired, please set these values explicitly.\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:631: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
+ " warnings.warn(\n",
+ "/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:636: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
+ " warnings.warn(\n",
+ "Expected Output: [' Also, can you explain the steps involved in solving this?\\n\\nTo solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \\n\\nFirst, calculate the multiplication part: 83 * 110. \\nTo make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\\n83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\\n\\nNext, add 34 to the result:\\n9130 + 34 = 9164.\\n\\nSo, the final answer is 9164.\\n\\n---\\n\\nWhat is 12 * 12 * 12? Can you explain how to compute this step by step?\\n\\nTo compute 12 * 12 * 12, we can break it down into steps. First, multiply the first two 12s:\\n\\n12 * 12 = 144.\\n\\nThen, multiply the result by the third 12:\\n144 * 12.\\n\\nTo compute 144 * 12, we can split it into:\\n144 * 10 = 1440,\\n144 * 2 = 288.\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? Can you explain the process?\\n\\nTo compute '] tensor([[ 7281, 11, 646, 498, 10339, 279, 7354, 6398, 304, 21828,\n",
+ " 419, 1939, 1249, 11625, 220, 23, 18, 353, 220, 16,\n",
+ " 16, 15, 488, 220, 18, 19, 11, 582, 1795, 279,\n",
+ " 1973, 315, 7525, 320, 1740, 6076, 1911, 16276, 2069, 40092,\n",
+ " 701, 892, 3363, 582, 3705, 46444, 1573, 5256, 13, 4710,\n",
+ " 5338, 11, 11047, 279, 46444, 949, 25, 220, 23, 18,\n",
+ " 353, 220, 16, 16, 15, 13, 715, 1249, 1281, 419,\n",
+ " 8661, 11, 5185, 429, 84192, 553, 220, 16, 16, 15,\n",
+ " 374, 279, 1852, 438, 84192, 553, 220, 16, 15, 15,\n",
+ " 323, 1221, 7842, 220, 16, 15, 3039, 279, 1372, 13,\n",
+ " 2055, 510, 23, 18, 353, 220, 16, 16, 15, 284,\n",
+ " 220, 23, 18, 353, 320, 16, 15, 15, 488, 220,\n",
+ " 16, 15, 8, 284, 320, 23, 18, 353, 220, 16,\n",
+ " 15, 15, 8, 488, 320, 23, 18, 353, 220, 16,\n",
+ " 15, 8, 284, 220, 23, 18, 15, 15, 488, 220,\n",
+ " 23, 18, 15, 284, 220, 24, 16, 18, 15, 382,\n",
+ " 5847, 11, 912, 220, 18, 19, 311, 279, 1102, 510,\n",
+ " 24, 16, 18, 15, 488, 220, 18, 19, 284, 220,\n",
+ " 24, 16, 21, 19, 382, 4416, 11, 279, 1590, 4226,\n",
+ " 374, 220, 24, 16, 21, 19, 382, 44364, 3838, 374,\n",
+ " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n",
+ " 17, 30, 2980, 498, 10339, 1246, 311, 12564, 419, 3019,\n",
+ " 553, 3019, 1939, 1249, 12564, 220, 16, 17, 353, 220,\n",
+ " 16, 17, 353, 220, 16, 17, 11, 582, 646, 1438,\n",
+ " 432, 1495, 1119, 7354, 13, 5512, 11, 30270, 279, 1156,\n",
+ " 1378, 220, 16, 17, 82, 1447, 16, 17, 353, 220,\n",
+ " 16, 17, 284, 220, 16, 19, 19, 382, 12209, 11,\n",
+ " 30270, 279, 1102, 553, 279, 4843, 220, 16, 17, 510,\n",
+ " 16, 19, 19, 353, 220, 16, 17, 382, 1249, 12564,\n",
+ " 220, 16, 19, 19, 353, 220, 16, 17, 11, 582,\n",
+ " 646, 6718, 432, 1119, 510, 16, 19, 19, 353, 220,\n",
+ " 16, 15, 284, 220, 16, 19, 19, 15, 345, 16,\n",
+ " 19, 19, 353, 220, 17, 284, 220, 17, 23, 23,\n",
+ " 382, 32308, 1493, 3786, 25, 220, 16, 19, 19, 15,\n",
+ " 488, 220, 17, 23, 23, 284, 220, 16, 22, 17,\n",
+ " 23, 382, 4416, 11, 220, 16, 17, 353, 220, 16,\n",
+ " 17, 353, 220, 16, 17, 284, 220, 16, 22, 17,\n",
+ " 23, 382, 44364, 3838, 374, 220, 16, 17, 353, 220,\n",
+ " 16, 17, 353, 220, 16, 17, 353, 220, 16, 17,\n",
+ " 30, 2980, 498, 10339, 279, 1882, 1939, 1249, 12564, 220]])\n",
+ "Expected Logits Shape: torch.Size([400, 1, 151936])\n",
+ "HuggingFaceGenerationAdapter has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
+ " - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
+ " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
+ " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
+ "Actual Output: [' Also, can you explain the steps involved in solving this?\\n\\nTo solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \\n\\nFirst, calculate the multiplication part: 83 * 110. \\nTo make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\\n83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\\n\\nNext, add 34 to the result:\\n9130 + 34 = 9164.\\n\\nSo, the final answer is 9164.\\n\\n---\\n\\nWhat is 1000 - 100 + 10 - 1? Can you walk me through the steps?\\n\\nTo solve 1000 - 100 + 10 - 1, we follow the order of operations, which in this case is left to right since all operations are addition and subtraction.\\n\\nStart with 1000 - 100 = 900.\\n\\nThen, 900 + 10 = 910.\\n\\nFinally, 910 - 1 = 909.\\n\\nSo, the final answer is 909.\\n\\n---\\n\\nWhat is 1000 - 100 * 10? Let me make sure I do this correctly.\\n\\nTo solve 1000 - 100 * 10, we again follow the order of operations: multiplication comes before subtraction.\\n\\nFirst, calculate the multiplication: 100 * 10'] tensor([[ 7281, 11, 646, 498, 10339, 279, 7354, 6398, 304, 21828,\n",
+ " 419, 1939, 1249, 11625, 220, 23, 18, 353, 220, 16,\n",
+ " 16, 15, 488, 220, 18, 19, 11, 582, 1795, 279,\n",
+ " 1973, 315, 7525, 320, 1740, 6076, 1911, 16276, 2069, 40092,\n",
+ " 701, 892, 3363, 582, 3705, 46444, 1573, 5256, 13, 4710,\n",
+ " 5338, 11, 11047, 279, 46444, 949, 25, 220, 23, 18,\n",
+ " 353, 220, 16, 16, 15, 13, 715, 1249, 1281, 419,\n",
+ " 8661, 11, 5185, 429, 84192, 553, 220, 16, 16, 15,\n",
+ " 374, 279, 1852, 438, 84192, 553, 220, 16, 15, 15,\n",
+ " 323, 1221, 7842, 220, 16, 15, 3039, 279, 1372, 13,\n",
+ " 2055, 510, 23, 18, 353, 220, 16, 16, 15, 284,\n",
+ " 220, 23, 18, 353, 320, 16, 15, 15, 488, 220,\n",
+ " 16, 15, 8, 284, 320, 23, 18, 353, 220, 16,\n",
+ " 15, 15, 8, 488, 320, 23, 18, 353, 220, 16,\n",
+ " 15, 8, 284, 220, 23, 18, 15, 15, 488, 220,\n",
+ " 23, 18, 15, 284, 220, 24, 16, 18, 15, 382,\n",
+ " 5847, 11, 912, 220, 18, 19, 311, 279, 1102, 510,\n",
+ " 24, 16, 18, 15, 488, 220, 18, 19, 284, 220,\n",
+ " 24, 16, 21, 19, 382, 4416, 11, 279, 1590, 4226,\n",
+ " 374, 220, 24, 16, 21, 19, 382, 44364, 3838, 374,\n",
+ " 220, 16, 15, 15, 15, 481, 220, 16, 15, 15,\n",
+ " 488, 220, 16, 15, 481, 220, 16, 30, 2980, 498,\n",
+ " 4227, 752, 1526, 279, 7354, 1939, 1249, 11625, 220, 16,\n",
+ " 15, 15, 15, 481, 220, 16, 15, 15, 488, 220,\n",
+ " 16, 15, 481, 220, 16, 11, 582, 1795, 279, 1973,\n",
+ " 315, 7525, 11, 892, 304, 419, 1142, 374, 2115, 311,\n",
+ " 1290, 2474, 678, 7525, 525, 5256, 323, 75240, 382, 3479,\n",
+ " 448, 220, 16, 15, 15, 15, 481, 220, 16, 15,\n",
+ " 15, 284, 220, 24, 15, 15, 382, 12209, 11, 220,\n",
+ " 24, 15, 15, 488, 220, 16, 15, 284, 220, 24,\n",
+ " 16, 15, 382, 23949, 11, 220, 24, 16, 15, 481,\n",
+ " 220, 16, 284, 220, 24, 15, 24, 382, 4416, 11,\n",
+ " 279, 1590, 4226, 374, 220, 24, 15, 24, 382, 44364,\n",
+ " 3838, 374, 220, 16, 15, 15, 15, 481, 220, 16,\n",
+ " 15, 15, 353, 220, 16, 15, 30, 6771, 752, 1281,\n",
+ " 2704, 358, 653, 419, 12440, 382, 1249, 11625, 220, 16,\n",
+ " 15, 15, 15, 481, 220, 16, 15, 15, 353, 220,\n",
+ " 16, 15, 11, 582, 1549, 1795, 279, 1973, 315, 7525,\n",
+ " 25, 46444, 4041, 1573, 75240, 382, 5338, 11, 11047, 279,\n",
+ " 46444, 25, 220, 16, 15, 15, 353, 220, 16, 15]])\n",
+ "Actual Logits Shape: torch.Size([400, 1, 151936])\n",
+ "Actual Output: ['0 * 120? Can you explain how to compute this?\\n\\nTo compute 120 * 120, we can recognize that this is the same as 12 * 12 * 100. \\n\\nFirst, calculate 12 * 12 = 144. Then, multiply by 100:\\n144 * 100 = 14,400.\\n\\nAlternatively, you can use the standard multiplication method:\\n120\\nx120\\n------\\n(120 * 0) = 0\\n(120 * 20) = 2400\\n(120 * 100) = 12000\\nAdding these together: 0 + 2400 + 12000 = 14400.\\n\\nSo, the final answer is 14,40'] tensor([[ 15, 353, 220, 16, 17, 15, 30, 2980, 498, 10339,\n",
+ " 1246, 311, 12564, 419, 1939, 1249, 12564, 220, 16, 17,\n",
+ " 15, 353, 220, 16, 17, 15, 11, 582, 646, 15282,\n",
+ " 429, 419, 374, 279, 1852, 438, 220, 16, 17, 353,\n",
+ " 220, 16, 17, 353, 220, 16, 15, 15, 13, 4710,\n",
+ " 5338, 11, 11047, 220, 16, 17, 353, 220, 16, 17,\n",
+ " 284, 220, 16, 19, 19, 13, 5005, 11, 30270, 553,\n",
+ " 220, 16, 15, 15, 510, 16, 19, 19, 353, 220,\n",
+ " 16, 15, 15, 284, 220, 16, 19, 11, 19, 15,\n",
+ " 15, 382, 92014, 11, 498, 646, 990, 279, 5297, 46444,\n",
+ " 1714, 510, 16, 17, 15, 198, 87, 16, 17, 15,\n",
+ " 198, 26409, 7, 16, 17, 15, 353, 220, 15, 8,\n",
+ " 284, 220, 15, 198, 7, 16, 17, 15, 353, 220,\n",
+ " 17, 15, 8, 284, 220, 17, 19, 15, 15, 198,\n",
+ " 7, 16, 17, 15, 353, 220, 16, 15, 15, 8,\n",
+ " 284, 220, 16, 17, 15, 15, 15, 198, 32308, 1493,\n",
+ " 3786, 25, 220, 15, 488, 220, 17, 19, 15, 15,\n",
+ " 488, 220, 16, 17, 15, 15, 15, 284, 220, 16,\n",
+ " 19, 19, 15, 15, 382, 4416, 11, 279, 1590, 4226,\n",
+ " 374, 220, 16, 19, 11, 19, 15]])\n",
+ "Actual Logits Shape: torch.Size([197, 1, 151936])\n",
+ "Actual Output: [' 12 * 12? Can you explain how to compute this step by step?\\n\\nTo compute 12 * 12 * 12, we can break it down into steps. First, multiply the first two 12s:\\n\\n12 * 12 = 144.\\n\\nThen, multiply the result by the third 12:\\n144 * 12.\\n\\nTo compute 144 * 12, we can split it into:\\n144 * 10 = 1440\\n144 * 2 = 288\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? Can you explain the process?\\n\\nTo compute '] tensor([[ 220, 16, 17, 353, 220, 16, 17, 30, 2980, 498,\n",
+ " 10339, 1246, 311, 12564, 419, 3019, 553, 3019, 1939, 1249,\n",
+ " 12564, 220, 16, 17, 353, 220, 16, 17, 353, 220,\n",
+ " 16, 17, 11, 582, 646, 1438, 432, 1495, 1119, 7354,\n",
+ " 13, 5512, 11, 30270, 279, 1156, 1378, 220, 16, 17,\n",
+ " 82, 1447, 16, 17, 353, 220, 16, 17, 284, 220,\n",
+ " 16, 19, 19, 382, 12209, 11, 30270, 279, 1102, 553,\n",
+ " 279, 4843, 220, 16, 17, 510, 16, 19, 19, 353,\n",
+ " 220, 16, 17, 382, 1249, 12564, 220, 16, 19, 19,\n",
+ " 353, 220, 16, 17, 11, 582, 646, 6718, 432, 1119,\n",
+ " 510, 16, 19, 19, 353, 220, 16, 15, 284, 220,\n",
+ " 16, 19, 19, 15, 198, 16, 19, 19, 353, 220,\n",
+ " 17, 284, 220, 17, 23, 23, 271, 32308, 1493, 3786,\n",
+ " 25, 220, 16, 19, 19, 15, 488, 220, 17, 23,\n",
+ " 23, 284, 220, 16, 22, 17, 23, 382, 4416, 11,\n",
+ " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n",
+ " 17, 284, 220, 16, 22, 17, 23, 382, 44364, 3838,\n",
+ " 374, 220, 16, 17, 353, 220, 16, 17, 353, 220,\n",
+ " 16, 17, 353, 220, 16, 17, 30, 2980, 498, 10339,\n",
+ " 279, 1882, 1939, 1249, 12564, 220]])\n",
+ "Actual Logits Shape: torch.Size([196, 1, 151936])\n",
+ "Actual Output: ['144 * 2 = 288.\\n\\nAdding these together: 1440 + 288 = 1728.\\n\\nSo, 12 * 12 * 12 = 1728.\\n\\n---\\n\\nWhat is 12 * 12 * 12 * 12? Can you explain the process?\\n\\nTo compute '] tensor([[ 16, 19, 19, 353, 220, 17, 284, 220, 17, 23,\n",
+ " 23, 382, 32308, 1493, 3786, 25, 220, 16, 19, 19,\n",
+ " 15, 488, 220, 17, 23, 23, 284, 220, 16, 22,\n",
+ " 17, 23, 382, 4416, 11, 220, 16, 17, 353, 220,\n",
+ " 16, 17, 353, 220, 16, 17, 284, 220, 16, 22,\n",
+ " 17, 23, 382, 44364, 3838, 374, 220, 16, 17, 353,\n",
+ " 220, 16, 17, 353, 220, 16, 17, 353, 220, 16,\n",
+ " 17, 30, 2980, 498, 10339, 279, 1882, 1939, 1249, 12564,\n",
+ " 220]])\n",
+ "Actual Logits Shape: torch.Size([81, 1, 151936])\n",
+ "\n",
+ "Generating outputs...\n",
+ "Prompts: ['What is 83 * 110 + 34?']\n",
+ "Generated outputs:\n",
+ "Output 0: What is 83 * 110 + 34? Also, can you explain the steps involved in solving this?\n",
+ "\n",
+ "To solve 83 * 110 + 34, we follow the order of operations (PEMDAS/BODMAS), which means we handle multiplication before addition. \n",
+ "\n",
+ "First, calculate the multiplication part: 83 * 110. \n",
+ "To make this easier, note that multiplying by 110 is the same as multiplying by 100 and then adding 10 times the number. So:\n",
+ "83 * 110 = 83 * (100 + 10) = (83 * 100) + (83 * 10) = 8300 + 830 = 9130.\n",
+ "\n",
+ "Next, add 34 to the result:\n",
+ "9130 + 34 = 9164.\n",
+ "\n",
+ "So, the final answer is 9164.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 + 10 - 1? Can you walk me through the steps?\n",
+ "\n",
+ "To solve 1000 - 100 + 10 - 1, we follow the order of operations, which in this case is left to right since all operations are addition and subtraction.\n",
+ "\n",
+ "Start with 1000 - 100 = 900.\n",
+ "\n",
+ "Then, 900 + 10 = 910.\n",
+ "\n",
+ "Finally, 910 - 1 = 909.\n",
+ "\n",
+ "So, the final answer is 909.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 * 10? Let me make sure I do this correctly.\n",
+ "\n",
+ "To solve 1000 - 100 * 10, we again follow the order of operations: multiplication comes before subtraction.\n",
+ "\n",
+ "First, calculate the multiplication: 100 * 10 = 1000.\n",
+ "\n",
+ "Then subtract that result from 1000: 1000 - 1000 = 0.\n",
+ "\n",
+ "So, the final answer is 0.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 * 10 + 1? Let me check my steps again.\n",
+ "\n",
+ "To solve 1000 - 100 * 10 + 1, we follow the order of operations: multiplication first, then left to right for subtraction and addition.\n",
+ "\n",
+ "First, calculate the multiplication: 100 * 10 = 1000.\n",
+ "\n",
+ "Now the expression becomes: 1000 - 1000 + 1.\n",
+ "\n",
+ "Next, perform the subtraction: 1000 - 1000 = 0.\n",
+ "\n",
+ "Then add 1: 0 + 1 = 1.\n",
+ "\n",
+ "So, the final answer is 1.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 * 10 - 1? Let me verify.\n",
+ "\n",
+ "To solve 1000 - 100 * 10 - 1, we again follow the order of operations: multiplication first, then left to right for subtraction.\n",
+ "\n",
+ "First, calculate the multiplication: 100 * 10 = 1000.\n",
+ "\n",
+ "Now the expression becomes: 1000 - 1000 - 1.\n",
+ "\n",
+ "Perform the first subtraction: 1000 - 1000 = 0.\n",
+ "\n",
+ "Then subtract 1: 0 - 1 = -1.\n",
+ "\n",
+ "So, the final answer is -1.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 * (10 - 1)? Let me make sure I handle the parentheses correctly.\n",
+ "\n",
+ "To solve 1000 - 100 * (10 - 1), we first handle the expression inside the parentheses: 10 - 1 = 9.\n",
+ "\n",
+ "Now the expression becomes: 1000 - 100 * 9.\n",
+ "\n",
+ "Next, perform the multiplication: 100 * 9 = 900.\n",
+ "\n",
+ "Then subtract that from 1000: 1000 - 900 = 100.\n",
+ "\n",
+ "So, the final answer is 100.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "What is 1000 - 100 * (10 - 1) + 1? Let me check the steps again.\n",
+ "\n",
+ "To solve 1000 - 100 * (10 - 1) + 1, we start with the parentheses: 10 - 1 = 9.\n",
+ "\n",
+ "Now the expression becomes: 1000 - 100 * 9 + 1.\n",
+ "\n",
+ "Next, perform the multiplication: 100 * 9 = 900.\n",
+ "\n",
+ "Now the expression is: 1000\n",
+ "Traceback (most recent call last):\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/inference_demo\", line 8, in \n",
+ " sys.exit(main())\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 662, in main\n",
+ " run_inference(model_cls, args)\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 540, in run_inference\n",
+ " raise logit_error\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 500, in run_inference\n",
+ " run_accuracy_check(\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/inference_demo.py\", line 631, in run_accuracy_check\n",
+ " check_accuracy_logits(\n",
+ " File \"/opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/lib/python3.10/site-packages/neuronx_distributed_inference/utils/accuracy.py\", line 503, in check_accuracy_logits\n",
+ " raise LogitMatchingValidationError(status_msg, results)\n",
+ "neuronx_distributed_inference.utils.exceptions.LogitMatchingValidationError: Divergence at index 203. Validating 203 tokens in each batch.\n",
+ "Test failed at batch 0 token 103. Top k = 5 error 0.01682760939002037 > 0.01.\n",
+ "Test failed at batch 0 token 108. Top k = 5 error 0.016880331560969353 > 0.01.\n",
+ "Divergence at index 204. Validating 1 tokens in each batch.\n",
+ "Divergence at index 319. Validating 115 tokens in each batch.\n",
+ "Test failed at batch 0 token 286. Top k = None error 0.07318327575922012 > 0.05. Top k = 1000 error 0.07318327575922012 > 0.03. Top k = 50 error 0.07318327575922012 > 0.02. Top k = 5 error 0.07318327575922012 > 0.01.\n",
+ "No divergence. Validating the remaining 81 tokens in each batch.\n",
+ "Test failed at batch 0 token 360. Top k = None error 0.06745750457048416 > 0.05. Top k = 1000 error 0.05250008776783943 > 0.03. Top k = 50 error 0.03233567625284195 > 0.02. Top k = 5 error 0.03233567625284195 > 0.01.\n",
+ "Test failed at batch 0 token 364. Top k = None error 0.37251684069633484 > 0.05. Top k = 1000 error 0.35812416672706604 > 0.03. Top k = 50 error 0.35812416672706604 > 0.02. Top k = 5 error 0.35812416672706604 > 0.01.\n",
+ "Summary: Max divergence difference = 0 at index (batch 0 token 0), Top k = None max error = 0.37251684069633484 at index (batch 0 token 364), Top k = 1000 max error = 0.35812416672706604 at index (batch 0 token 364), Top k = 50 max error = 0.35812416672706604 at index (batch 0 token 364), Top k = 5 max error = 0.35812416672706604 at index (batch 0 token 364)\n",
+ "Test fails logit validation.\n"
+ ]
+ }
+ ],
+ "source": [
+ "!inference_demo \\\n",
+ " --model-type qwen3 \\\n",
+ " --task-type causal-lm \\\n",
+ " run \\\n",
+ " --model-path /home/ubuntu/model_hf_qwen/qwen/ \\\n",
+ " --compiled-model-path /home/ubuntu/traced_model_qwen/qwen/logit \\\n",
+ " --torch-dtype bfloat16 \\\n",
+ " --tp-degree 8 \\\n",
+ " --batch-size 1 \\\n",
+ " --max-context-length 512 \\\n",
+ " --num-tokens-to-check 400 \\\n",
+ " --max-new-tokens 512 \\\n",
+ " --seq-len 1024 \\\n",
+ " --pad-token-id 151645 \\\n",
+ " --prompt \"What is 83 * 110 + 34?\" \\\n",
+ " --check-accuracy-mode logit-matching \\\n",
+ " --benchmark"
+ ]
+ },
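+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The logit-matching run above compares, token by token, the logits produced on Neuron against a CPU reference and reports a relative error over the reference's top-k tokens for several values of k (None, 1000, 50, 5), each with its own tolerance. The sketch below is a minimal illustration of that kind of comparison in plain PyTorch; the helper name `topk_logit_error` and the exact error formula are assumptions for illustration and may differ from what `check_accuracy_logits` computes internally.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def topk_logit_error(actual: torch.Tensor, expected: torch.Tensor, k: int = 5) -> float:\n",
+ "    # Illustrative sketch only: max relative error of `actual` vs `expected` logits,\n",
+ "    # restricted to the reference top-k token ids at each position.\n",
+ "    topk_idx = expected.topk(k, dim=-1).indices\n",
+ "    a = torch.gather(actual, -1, topk_idx)\n",
+ "    e = torch.gather(expected, -1, topk_idx)\n",
+ "    # Relative error, guarding against division by near-zero reference logits.\n",
+ "    return ((a - e).abs() / e.abs().clamp_min(1e-6)).max().item()\n",
+ "\n",
+ "# Example with random logits shaped (tokens, batch, vocab), as in the dumps above.\n",
+ "expected = torch.randn(8, 1, 151936)\n",
+ "actual = expected + 1e-3 * torch.randn_like(expected)\n",
+ "print(topk_logit_error(actual, expected, k=5))\n",
+ "```"
+ ]
+ },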
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "aws_neuronx_venv_pytorch_2_6_nxd_inference",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}