|
21 | 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
22 | 22 | # See the License for the specific language governing permissions and |
23 | 23 | # limitations under the License. |
24 | | -"""Inference-only LLaMA model compatible with HuggingFace weights.""" |
| 24 | +"""Inference-only SwissAI model compatible with HuggingFace weights.""" |
25 | 25 | from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union |
26 | 26 |
|
27 | 27 | import torch |
28 | 28 | from torch import nn |
29 | | -from transformers import LlamaConfig |
| 29 | +from transformers import SwissAIConfig |
30 | 30 |
|
31 | 31 | from vllm.attention import Attention |
32 | 32 | from vllm.compilation.decorators import support_torch_compile |
33 | 33 | from vllm.config import CacheConfig, VllmConfig |
34 | 34 | from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size |
35 | 35 | from vllm.model_executor.layers.activation import XIELU |
36 | 36 | from vllm.model_executor.layers.layernorm import RMSNorm |
37 | | -from vllm.model_executor.layers.linear import (QKVParallelLinear, |
| 37 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 38 | + QKVParallelLinear, |
38 | 39 | RowParallelLinear) |
39 | 40 | from vllm.model_executor.layers.logits_processor import LogitsProcessor |
40 | 41 | from vllm.model_executor.layers.quantization import QuantizationConfig |
@@ -66,7 +67,7 @@ def __init__( |
66 | 67 | prefix: str = "", |
67 | 68 | ) -> None: |
68 | 69 | super().__init__() |
69 | | - self.up_proj = RowParallelLinear( |
| 70 | + self.up_proj = ColumnParallelLinear( |
70 | 71 | input_size=hidden_size, |
71 | 72 | output_size=intermediate_size, |
72 | 73 | bias=bias, |
@@ -95,7 +96,7 @@ def forward(self, x): |
95 | 96 | class SwissAIAttention(nn.Module): |
96 | 97 |
|
97 | 98 | def __init__(self, |
98 | | - config: LlamaConfig, |
| 99 | + config: SwissAIConfig, |
99 | 100 | hidden_size: int, |
100 | 101 | num_heads: int, |
101 | 102 | num_kv_heads: int, |
@@ -216,7 +217,7 @@ class SwissAIDecoderLayer(nn.Module): |
216 | 217 |
|
217 | 218 | def __init__( |
218 | 219 | self, |
219 | | - config: LlamaConfig, |
| 220 | + config: SwissAIConfig, |
220 | 221 | cache_config: Optional[CacheConfig] = None, |
221 | 222 | quant_config: Optional[QuantizationConfig] = None, |
222 | 223 | prefix: str = "", |
|
0 commit comments