-
Notifications
You must be signed in to change notification settings - Fork 309
Expand file tree
/
Copy pathconfig.py
More file actions
38 lines (29 loc) · 889 Bytes
/
config.py
File metadata and controls
38 lines (29 loc) · 889 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from lorax_server.adapters.weights import AdapterWeights
# `Model` is imported only for static type checkers; at runtime TYPE_CHECKING
# is False, so the import never executes and `Model` is referenced below only
# as the string annotation "Model". Presumably this avoids an import cycle
# with lorax_server.models.model -- NOTE(review): confirm.
if TYPE_CHECKING:
    from lorax_server.models.model import Model

# Nested mapping: str -> str -> (torch.Tensor, str).
# NOTE(review): presumably the outer key is a module/layer name, the inner key
# a weight role within that module, and the tuple pairs the tensor with its
# originating weight name -- confirm against map_weights_for_model
# implementations.
ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
@dataclass
class AdapterConfig(ABC):
    """Abstract base configuration for an adapter type.

    This is a dataclass carrying ``base_model_name_or_path``. It declares two
    abstract hooks that every concrete adapter config must implement. Because
    the hooks are ``@abstractmethod``s on an ``ABC``, instantiating
    ``AdapterConfig`` itself (or a subclass missing either hook) raises
    ``TypeError``.
    """

    # Name or path of the base model this adapter config refers to.
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict,
        # Variadic tuple of weight names; ``Tuple[str]`` would mean a tuple of
        # exactly one string.
        weight_names: Tuple[str, ...],
        embedding_weight_name: str,
    ) -> Tuple[ModuleMap, Set[str]]:
        """Map raw adapter weights onto the model's weight names.

        Args:
            adapter_weights: Raw adapter weights. Keys/values are not
                constrained here -- NOTE(review): presumably name -> tensor;
                confirm with implementations.
            weight_names: Names of the model weights to map onto.
            embedding_weight_name: Name of the model's embedding weight.

        Returns:
            A pair ``(module_map, names)``. ``module_map`` is a ``ModuleMap``.
            ``names`` is a set of strings. NOTE(review): this is presumably
            the set later passed to ``load_batched_adapter_weights`` as
            ``unused_weight_names``; confirm.
        """

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        """Build the batched adapter weights for one layer type of ``model``.

        Args:
            model: The model the adapter is being loaded into.
            module_map: Mapping of module names to per-module dicts, as
                produced by ``map_weights_for_model`` (presumably -- confirm).
            layer_type: Which layer type to load weights for.
            unused_weight_names: Set of weight names. NOTE(review):
                implementations may consume/mutate this set; confirm.
            dynamic: Flag forwarded to implementations; its meaning is defined
                by each subclass.

        Returns:
            An ``AdapterWeights`` instance, or ``None``. NOTE(review):
            presumably ``None`` when nothing applies to ``layer_type``;
            confirm.
        """