Skip to content

Commit 81f0ded

Browse files
authored
[None][feat] Add GPT OSS support for AutoDeploy (NVIDIA#6641)
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
1 parent a060e12 commit 81f0ded

File tree

3 files changed

+99
-23
lines changed

3 files changed

+99
-23
lines changed

examples/auto_deploy/build_and_run_ad.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,6 @@ class PromptConfig(BaseModel):
4141
"In simple words and in a single sentence, explain the concept of gravity: ",
4242
"How to fix slicing in golf? ",
4343
"Where is the capital of Iceland? ",
44-
"How big is the universe? ",
45-
"In simple words and in a single sentence, explain the concept of gravity: ",
46-
"How to fix slicing in golf? ",
47-
"Where is the capital of Iceland? ",
4844
]
4945
)
5046
sp_kwargs: Dict[str, Any] = Field(

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,8 @@ class AutoModelForCausalLMFactory(ModelFactory):
7373

7474
_model_defaults = {
7575
"use_cache": False,
76-
"max_position_embeddings": 1024,
7776
}
7877

79-
def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
80-
"""Get the max position embeddings config for the model."""
81-
return {
82-
"max_position_embeddings": self.max_seq_len,
83-
}
84-
8578
def __init__(self, *args, **kwargs):
8679
super().__init__(*args, **kwargs)
8780
self._quant_config_reader: QuantConfigReader | None = None
@@ -90,7 +83,6 @@ def __init__(self, *args, **kwargs):
9083
self.model_kwargs = deep_merge_dicts(
9184
self._model_defaults,
9285
self.model_kwargs,
93-
self._get_max_position_embeddings_config(),
9486
)
9587

9688
# special handling for torch_dtype in model_kwargs since HF does not correctly update
@@ -342,22 +334,11 @@ def _load_quantization_config(self, fetched_dir: str):
342334
class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
343335
_model_defaults = {
344336
"use_cache": False,
345-
"max_position_embeddings": 1024,
346337
"text_config": {
347-
"max_position_embeddings": 1024,
348338
"use_cache": False,
349339
},
350340
}
351341

352-
def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
353-
"""Get the max position embeddings config for the model."""
354-
return {
355-
"max_position_embeddings": self.max_seq_len,
356-
"text_config": {
357-
"max_position_embeddings": self.max_seq_len,
358-
},
359-
}
360-
361342
@property
362343
def automodel_from_config(self):
363344
return AutoModelForImageTextToText.from_config
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import types
2+
from typing import Callable, Dict, Optional
3+
4+
import torch
5+
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
6+
7+
8+
def gpt_oss_attention(
    self,
    hidden_states: torch.Tensor,
    # Fixed annotation: the body unpacks this as a (cos, sin) pair, not a single tensor.
    position_embeddings: tuple,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """GPT OSS Attention forward function rewritten to wrap attention as a custom op.

    Replaces ``GptOssAttention.forward`` so the core attention math goes through
    ``torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa``, which AutoDeploy
    can capture during graph export.

    Args:
        hidden_states: Input activations of shape ``(*batch_dims, hidden_size)``.
        position_embeddings: ``(cos, sin)`` rotary embedding tensors produced by
            the model's rotary embedding module.
        attention_mask: Optional attention mask forwarded to the custom op.
        past_key_value: Optional KV cache — either a modern cache object exposing
            ``update()`` or a legacy ``(past_key, past_value)`` tuple.
        cache_position: Optional cache positions passed through to ``update()``.
        **kwargs: Ignored; accepted for signature compatibility with HF callers.

    Returns:
        Tuple of ``(attn_output, past_key_value)`` where ``attn_output`` has the
        same leading shape as ``hidden_states``.
    """
    # Imported lazily so this module can be imported without the gpt_oss model files.
    from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb

    # Sliding-window size for the custom op; -1 means "no sliding window".
    sliding_window = getattr(self, "sliding_window", -1)

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Apply Q, K, V projections (same as original), laid out as [b, n_heads, s, head_dim].
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    # Use the original HF rope implementation.
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Handle KV cache properly.
    if past_key_value is not None:
        if hasattr(past_key_value, "update"):
            # Modern cache objects expose an update() method.
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        elif isinstance(past_key_value, tuple) and len(past_key_value) == 2:
            # Legacy tuple-based cache: concatenate along the sequence dimension.
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=2)

    # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim].
    query_states = query_states.transpose(1, 2).contiguous()
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()

    # Attention-sink logits are a model parameter on gpt_oss attention modules, if present.
    sinks = getattr(self, "sinks", None)

    # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim).
    attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=True,
        scale=self.scaling,
        sinks=sinks,
        sliding_window=sliding_window,
    )

    # Reshape back to the original input shape and apply the output projection.
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    return attn_output, past_key_value
79+
80+
81+
# Keep a handle on the unpatched factory so the patched version can delegate to it.
_from_config_original = AutoModelForCausalLM.from_config

# Maps an HF module class name to the replacement forward function applied after build.
CUSTOM_MODULE_PATCHES: Dict[str, Callable] = {
    "GptOssAttention": gpt_oss_attention,
}


def get_model_from_config_patched(config, **kwargs):
    """Build a model via the original ``from_config`` and patch known modules.

    Walks every submodule of the freshly built model and, for any module whose
    class name appears in ``CUSTOM_MODULE_PATCHES``, rebinds its ``forward``
    method to the AutoDeploy-friendly replacement.

    Args:
        config: HF model config forwarded to the original ``from_config``.
        **kwargs: Extra keyword arguments forwarded to the original ``from_config``.

    Returns:
        The constructed model with patched ``forward`` methods where applicable.
    """
    model = _from_config_original(config, **kwargs)
    # Single dict lookup per module (instead of a `.keys()` membership test
    # followed by a second indexing of the same key).
    for _, module in model.named_modules():
        patched_forward = CUSTOM_MODULE_PATCHES.get(type(module).__name__)
        if patched_forward is not None:
            # Bind the replacement as an instance method so `self` resolves to the module.
            module.forward = types.MethodType(patched_forward, module)

    return model


# Monkey-patch so every AutoModelForCausalLM.from_config call yields a patched model.
AutoModelForCausalLM.from_config = get_model_from_config_patched

0 commit comments

Comments
 (0)