
Commit 1d35db7

ZonePG and mrwyattii authored
Refactor the Qwen positional embedding config code (#4955)
Follows PR #4920 for the Qwen inference code. Co-authored-by: Michael Wyatt <[email protected]>
1 parent 5a8bf3f commit 1d35db7
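
In practice, the refactor deletes the Qwen-specific make_attn_layer override and instead exposes the rotary settings through a positional_embedding_config property that the shared model base class can read when it constructs the attention layer. The sketch below illustrates that consuming side; it is reconstructed from the removed Qwen code, and the base-class name and surrounding attribute plumbing are assumptions for illustration, not the actual DeepSpeed implementation.

# Sketch only: how a shared base class could build the attention layer from
# per-model properties, so concrete models no longer override make_attn_layer.
# The class name and attributes are assumed here for illustration.
class DSTransformerModelBase:

    def make_attn_layer(self) -> None:
        """Builds the attention layer once for every model implementation."""
        softmax_scale = 1.0 / (self.head_size**0.5)

        attn_config = DSSelfAttentionConfig(
            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
            n_heads_q=self.n_heads_q_local,
            n_heads_kv=self.n_heads_kv_local,
            head_size=self.head_size,
            max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
            scale_factor=softmax_scale,
            input_dtype=self.activation_dtype,
            output_dtype=self.activation_dtype,
            # Both values now come from properties defined by the concrete model,
            # e.g. the positional_embedding_config property Qwen adds in this commit.
            positional_embedding_type=self.positional_embedding_type,
            positional_embedding_config=self.positional_embedding_config)

        self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)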

File tree: 2 files changed, +9 -23 lines


deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu

+1 line

@@ -3,6 +3,7 @@
 
 // DeepSpeed Team
 
+#include <cassert>
 #include "blocked_kv_rotary.cuh"
 #include "conversion_utils.h"
 #include "ds_kernel_utils.h"

deepspeed/inference/v2/model_implementations/qwen/model.py

+8 -23 lines

@@ -100,6 +100,10 @@ def norm_type(self) -> NormTypeEnum:
     def positional_embedding_type(self) -> PositionalEmbeddingType:
         return PositionalEmbeddingType.rotate_half
 
+    @property
+    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
+        return RotateHalfConfig(theta_base=self._config.rotary_emb_base)
+
     def make_norm_layer(self) -> None:
         """
         Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
@@ -119,27 +123,6 @@ def make_norm_layer(self) -> None:
 
         self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
 
-    def make_attn_layer(self) -> None:
-        """
-        Builds the attention layer for the model. This sets the `self.attn` attribute.
-        """
-        softmax_scale = 1.0 / (self.head_size**0.5)
-
-        rotary_config = RotateHalfConfig(theta_base=self._config.rotary_emb_base)
-
-        attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
-                                            n_heads_q=self.n_heads_q_local,
-                                            n_heads_kv=self.n_heads_kv_local,
-                                            head_size=self.head_size,
-                                            max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
-                                            scale_factor=softmax_scale,
-                                            input_dtype=self.activation_dtype,
-                                            output_dtype=self.activation_dtype,
-                                            positional_embedding_type=self.positional_embedding_type,
-                                            positional_embedding_config=rotary_config)
-
-        self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)
-
     """
     Forward implementations
     """
@@ -210,8 +193,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
         Performs unembedding of the hidden states to logits. This will only sample the final
         token of each sequence.
         """
-        logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
-                              self._non_transformer.final_norm)
+        logits = self.unembed(hidden_states,
+                              self._non_transformer.word_unembed,
+                              ragged_batch_info,
+                              gamma=self._non_transformer.final_norm)
 
         if self.tp_size > 1:
             comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
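
The last hunk also switches the final-norm weights to an explicit keyword argument. A minimal sketch of the unembed signature this call site assumes (not the actual DeepSpeed module interface) shows why gamma=... reads more clearly than a trailing positional argument:

from typing import Optional

import torch

# Sketch only: an assumed signature for the ragged unembedding callable, used to
# illustrate the call-site change above; the real DeepSpeed interface may differ.
def unembed(hidden_states: torch.Tensor,
            vocab_embedding: torch.Tensor,
            ragged_batch_info,  # ragged batch metadata for the final tokens
            gamma: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Applies an optional final norm (gamma), then projects the last token of
    each sequence onto the vocabulary to produce logits."""
    ...

Passing gamma by keyword makes it obvious at the call site that the tensor is the optional normalization weight rather than another required input.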
