-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathconfig.py
More file actions
133 lines (102 loc) · 4.95 KB
/
Copy pathconfig.py
File metadata and controls
133 lines (102 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import logging
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from arctic_inference.patching import ArcticPatch
logger = logging.getLogger(__name__)
@dataclass
class ArcticParallelConfig(ParallelConfig):
ulysses_sequence_parallel_size: int = 1
enable_shift_parallel: bool = False
shift_parallel_threshold: int = 512
def __post_init__(self, *args, **kwargs):
if (self.enable_shift_parallel
and self.ulysses_sequence_parallel_size == 1):
raise ValueError("ulysses_sequence_parallel_size must be > 1 "
"when enable_shift_parallel is True.")
super().__post_init__(*args, **kwargs)
@property
def world_size(self) -> int:
return (self.pipeline_parallel_size * self.tensor_parallel_size *
self.ulysses_sequence_parallel_size)
@world_size.setter
def world_size(self, value: int) -> None:
# ParallelConfig.__post_init__ will assign world_size to PP * TP, while
# we want PP * TP * SP to be the world size. So we define world_size as
# a property with a no-op setter to ignore the value later assigned by
# ParallelConfig.__post_init__.
pass
@dataclass
class ArcticSpeculativeConfig(SpeculativeConfig):
enable_suffix_decoding: bool = False
suffix_cache_max_depth: int = 64
suffix_speculative_tokens: int = 0
suffix_cache_max_requests: int = 100000
suffix_max_spec_factor: float = 1.0
suffix_max_spec_offset: float = 0.0
suffix_min_token_prob: float = 0.1
class ParallelConfigPatch(ArcticPatch[ParallelConfig]):
def __new__(cls, *args, **kwargs):
# Override __new__ to return an ArcticParallelConfig instead of a
# ParallelConfig when creating a new instance of the class.
if cls is ParallelConfig:
return ArcticParallelConfig.__new__(ArcticParallelConfig, *args,
**kwargs)
return super(ParallelConfig, cls).__new__(cls)
class SpeculativeConfigPatch(ArcticPatch[SpeculativeConfig]):
_orig_post_init = SpeculativeConfig.__post_init__
def __new__(cls, *args, **kwargs):
# Override __new__ to return an ArcticSpeculativeConfig instead of a
# SpeculativeConfig when creating a new instance of the class.
if cls is SpeculativeConfig:
return ArcticSpeculativeConfig.__new__(ArcticSpeculativeConfig, *args, **kwargs)
return super(SpeculativeConfig, cls).__new__(cls)
def __post_init__(self):
use_suffix = (self.method == "suffix") or (self.method is None and self.enable_suffix_decoding)
if (use_suffix or self.method == "arctic") and self.disable_by_batch_size is None:
logger.info("Defaulting disable_by_batch_size to 64")
self.disable_by_batch_size = 64
if use_suffix:
self.method = "suffix"
self.enable_suffix_decoding = True
self.num_speculative_tokens = self.suffix_cache_max_depth
self._verify_args()
else:
SpeculativeConfigPatch._orig_post_init(self)
@classmethod
def from_dict(cls, dict_value: dict) -> SpeculativeConfig:
"""Parse the CLI value for the speculative config."""
# Generic dataclass from_dict implementation
# Accepts both ArcticSpeculativeConfig and SpeculativeConfig
if cls is SpeculativeConfig:
target_cls = ArcticSpeculativeConfig
else:
target_cls = cls
return target_cls(**dict_value)
class VllmConfigPatch(ArcticPatch[VllmConfig]):
_orig_str = VllmConfig.__str__
def __str__(self, *args, **kwargs):
string = self._orig_str(*args, **kwargs)
string += f", ulysses_sequence_parallel_size={self.parallel_config.ulysses_sequence_parallel_size}"
string += f", enable_shift_parallel={self.parallel_config.enable_shift_parallel}"
string += f", shift_parallel_threshold={self.parallel_config.shift_parallel_threshold}"
return string
class MLPSpeculatorConfigPatch(ArcticPatch[MLPSpeculatorConfig]):
_orig_init = MLPSpeculatorConfig.__init__
def __init__(self, *args, **kwargs):
self.base_model_arch = kwargs.pop("base_model_arch", "")
self._orig_init(*args, **kwargs)