Skip to content

Commit b982de0

Browse files
whx-sjtucwazai
authored and committed
[PluggableLayer][1/N] Define PluggableLayer (Fix ci) (vllm-project#32744)
Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: 陈建华 <1647430658@qq.com>
1 parent 20f38fc commit b982de0

7 files changed

Lines changed: 108 additions & 48 deletions

File tree

benchmarks/kernels/benchmark_activation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88

99
import vllm.model_executor.layers.activation # noqa F401
10-
from vllm.model_executor.custom_op import CustomOp
10+
from vllm.model_executor.custom_op import op_registry
1111
from vllm.triton_utils import triton
1212
from vllm.utils.argparse_utils import FlexibleArgumentParser
1313
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
@@ -33,14 +33,14 @@ def benchmark_activation(
3333
torch.set_default_device(device)
3434

3535
if func_name == "gelu_and_mul":
36-
layer = CustomOp.op_registry[func_name](approximate="none")
36+
layer = op_registry[func_name](approximate="none")
3737
elif func_name == "gelu_and_mul_tanh":
38-
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
38+
layer = op_registry["gelu_and_mul"](approximate="tanh")
3939
elif func_name == "fatrelu_and_mul":
4040
threshold = 0.5
41-
layer = CustomOp.op_registry[func_name](threshold)
41+
layer = op_registry[func_name](threshold)
4242
else:
43-
layer = CustomOp.op_registry[func_name]()
43+
layer = op_registry[func_name]()
4444

4545
x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
4646
compiled_layer = torch.compile(layer.forward_native)

docs/design/custom_op.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n
88

99
`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.
1010

11-
??? code
12-
13-
```python
14-
class CustomOp(nn.Module):
15-
16-
op_registry: dict[str, type["CustomOp"]] = {}
17-
op_registry_oot: dict[str, type["CustomOp"]] = {}
18-
```
19-
2011
We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.
2112

2213
When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method.

tests/kernels/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch._prims_common import TensorLikeType
1414

1515
from tests.kernels.quant_utils import native_w8a8_block_matmul
16-
from vllm.model_executor.custom_op import CustomOp
16+
from vllm.model_executor.custom_op import op_registry
1717
from vllm.model_executor.layers.activation import SiluAndMul
1818
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
1919
from vllm.utils.torch_utils import make_tensor_with_pad
@@ -883,7 +883,7 @@ def torch_experts(
883883

884884
f32 = torch.float32
885885

886-
act = CustomOp.op_registry[activation]
886+
act = op_registry[activation]
887887

888888
for i in range(num_experts):
889889
mask = topk_ids == i

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_cached_compilation_config,
1212
set_current_vllm_config,
1313
)
14-
from vllm.model_executor.custom_op import CustomOp
14+
from vllm.model_executor.custom_op import CustomOp, op_registry
1515
from vllm.model_executor.layers.activation import (
1616
GeluAndMul,
1717
ReLUSquaredActivation,
@@ -98,17 +98,17 @@ def test_enabled_ops(
9898
ops_enabled = [bool(x) for x in ops_enabled]
9999

100100
assert RMSNorm(1024).enabled() == ops_enabled[0]
101-
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
101+
assert op_registry["rms_norm"].enabled() == ops_enabled[0]
102102

103103
assert SiluAndMul().enabled() == ops_enabled[1]
104-
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
104+
assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]
105105

106106
assert GeluAndMul().enabled() == ops_enabled[2]
107-
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
107+
assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
108108

109109
# If registered, subclasses should follow their own name
110110
assert Relu3().enabled() == ops_enabled[3]
111-
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
111+
assert op_registry["relu3"].enabled() == ops_enabled[3]
112112

113113
# Unregistered subclass
114114
class SiluAndMul2(SiluAndMul):

vllm/config/compilation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,13 +1033,13 @@ def custom_op_log_check(self):
10331033
# check if op name exists in model
10341034
op_name = op[1:]
10351035
if op_name not in all_ops_in_model:
1036-
from vllm.model_executor.custom_op import CustomOp
1036+
from vllm.model_executor.custom_op import op_registry
10371037

10381038
# Does op exist at all or is it just not present in this model?
10391039
# Note: Only imported op classes appear in the registry.
10401040
missing_str = (
10411041
"doesn't exist (or wasn't imported/registered)"
1042-
if op_name not in CustomOp.op_registry
1042+
if op_name not in op_registry
10431043
else "not present in model"
10441044
)
10451045

vllm/model_executor/custom_op.py

Lines changed: 86 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,86 @@
1111
logger = init_logger(__name__)
1212

1313

14+
# Dictionary of all custom ops (classes, indexed by registered name).
15+
# To check if an op with a name is enabled, call .enabled() on the class.
16+
# Examples:
17+
# - MyOp.enabled()
18+
# - op_registry["my_op"].enabled()
19+
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
20+
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
21+
22+
23+
class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate
    other ``torch.nn.Module`` objects as sub-layers, and its functionality
    depends on these sub-layers following a generalized invocation sequence.
    It is also stateful and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        # The OOT registry is keyed by the in-tree *class name* (see
        # register_oot below), so the lookup uses ``cls.__name__`` directly.
        # NOTE: ``cls.__name__`` always exists for a class, so no
        # AttributeError guard is needed here -- unlike CustomOp.__new__,
        # which reads the decorator-set ``cls.name`` attribute.
        layer_class_name = cls.__name__

        if layer_class_name in op_registry_oot:
            # An OOT plugin registered a replacement for this layer:
            # instantiate the replacement class instead of the in-tree one.
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
        else:
            layer_cls_to_instantiate = cls
        logger.debug(
            "Instantiating pluggable layer: %s using %s",
            layer_class_name,
            str(layer_cls_to_instantiate),
        )
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        """Register a pluggable layer class in ``op_registry`` under ``name``."""

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) pluggable layers.
    # For OOT pluggable layers:
    # if an in-tree layer class is registered with an OOT custom layer,
    # the OOT custom layer will be used instead (see ``__new__``).
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        """Register an OOT replacement layer in ``op_registry_oot``.

        Supports both call styles:
        ``@SomeLayer.register_oot`` (no parentheses) and
        ``@SomeLayer.register_oot()`` / ``@SomeLayer.register_oot(name=...)``.
        When ``name`` is omitted, the registration key defaults to the name of
        the class this classmethod is invoked on (the in-tree layer class).
        """

        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")
1494
class CustomOp(nn.Module):
1595
"""
1696
Base class for custom ops.
@@ -27,10 +107,10 @@ def __new__(cls, *args, **kwargs):
27107
f"@CustomOp.register, or it's the CustomOp base class itself."
28108
) from None
29109

30-
if op_name not in cls.op_registry_oot:
110+
if op_name not in op_registry_oot:
31111
op_cls_to_instantiate = cls
32112
else:
33-
op_cls_to_instantiate = cls.op_registry_oot[op_name]
113+
op_cls_to_instantiate = op_registry_oot[op_name]
34114
logger.debug(
35115
"Instantiating custom op: %s using %s",
36116
op_name,
@@ -150,21 +230,13 @@ def default_on() -> bool:
150230

151231
return not count_none > 0 or count_all > 0
152232

153-
# Dictionary of all custom ops (classes, indexed by registered name).
154-
# To check if an op with a name is enabled, call .enabled() on the class.
155-
# Examples:
156-
# - MyOp.enabled()
157-
# - op_registry["my_op"].enabled()
158-
op_registry: dict[str, type["CustomOp"]] = {}
159-
op_registry_oot: dict[str, type["CustomOp"]] = {}
160-
161233
# Decorator to register custom ops.
162234
@classmethod
163235
def register(cls, name: str):
164236
def decorator(op_cls):
165-
assert name not in cls.op_registry, f"Duplicate op name: {name}"
237+
assert name not in op_registry, f"Duplicate op name: {name}"
166238
op_cls.name = name
167-
cls.op_registry[name] = op_cls
239+
op_registry[name] = op_cls
168240
return op_cls
169241

170242
return decorator
@@ -182,9 +254,9 @@ def decorator(op_cls):
182254
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
183255
def decorator(op_cls):
184256
reg_name = name if name is not None else cls.__name__
185-
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
257+
assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
186258
op_cls.name = reg_name
187-
cls.op_registry_oot[reg_name] = op_cls
259+
op_registry_oot[reg_name] = op_cls
188260
return op_cls
189261

190262
if _decorated_op_cls is None:

vllm/model_executor/layers/mla.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from vllm.attention.layer import MLAAttention
88
from vllm.config import CacheConfig
9-
from vllm.model_executor.custom_op import CustomOp
9+
from vllm.model_executor.custom_op import PluggableLayer
1010
from vllm.model_executor.layers.quantization import QuantizationConfig
1111

1212

@@ -30,13 +30,13 @@ class MLAModules:
3030

3131

3232
# --8<-- [start:multi_head_latent_attention]
33-
@CustomOp.register("multi_head_latent_attention")
34-
class MultiHeadLatentAttentionWrapper(CustomOp):
35-
"""MLA layer registered as CustomOp to allow OOT backends to add
33+
@PluggableLayer.register("multi_head_latent_attention")
34+
class MultiHeadLatentAttentionWrapper(PluggableLayer):
35+
"""Pluggable MLA layer which allows OOT backends to add
3636
custom implementations of the outer MLA layer (including rope & o_proj).
37-
Note that currently MLA ignores the enable/disable mechanism of CustomOp
38-
because there is only one in-tree implementation in forward_native.
39-
TODO: implement this with a new PluggableLayer mechanism.
37+
Note that currently oot platforms can still use CustomOp.register_oot to
38+
replace the MLA layer entirely, although we use PluggableLayer to register
39+
this layer now.
4040
4141
This class takes positions and hidden_states as input.
4242
The input tensors can either contain prefill tokens or decode tokens.
@@ -110,7 +110,7 @@ def __init__(
110110

111111
self.prefix = prefix
112112

113-
def forward_native(
113+
def forward(
114114
self,
115115
positions: torch.Tensor,
116116
hidden_states: torch.Tensor,
@@ -174,6 +174,3 @@ def forward_native(
174174
)
175175

176176
return self.o_proj(attn_out)[0]
177-
178-
def forward_cuda(self, *args, **kwargs):
179-
return self.forward_native(*args, **kwargs)

0 commit comments

Comments
 (0)