@@ -54,49 +54,79 @@ def load_sparsify_sparse_coders(
         model (Any): The model to load autoencoders for.
         name (str): The name of the sparse model to load. If the model is on-disk
             this is the path to the directory containing the sparse model weights.
-        hookpoints (list[str]): list of hookpoints to load autoencoders for.
+        hookpoints (list[str]): list of hookpoints to identify the sparse models.
         device (str | torch.device | None, optional): The device to load the
             sparse models on. If not specified the sparse models will be loaded
             on the same device as the base model.
 
     Returns:
-        tuple[dict[str, Any], Any]: A tuple containing the submodules dictionary
-            and the edited model.
+        dict[str, Any]: A dictionary mapping hookpoints to sparse models.
     """
     if device is None:
         device = model.device or "cpu"
 
     # Load the sparse models
-    hookpoint_to_sparse = {}
+    sparse_model_dict = {}
     name_path = Path(name)
     if name_path.exists():
         for hookpoint in hookpoints:
-            hookpoint_to_sparse[hookpoint] = Sae.load_from_disk(
+            sparse_model_dict[hookpoint] = Sae.load_from_disk(
                 name_path / hookpoint, device=device
             )
             if compile:
-                hookpoint_to_sparse[hookpoint] = torch.compile(
-                    hookpoint_to_sparse[hookpoint]
+                sparse_model_dict[hookpoint] = torch.compile(
+                    sparse_model_dict[hookpoint]
                 )
     else:
         sparse_models = Sae.load_many(name, device=device)
         for hookpoint in hookpoints:
-            hookpoint_to_sparse[hookpoint] = sparse_models[hookpoint]
+            sparse_model_dict[hookpoint] = sparse_models[hookpoint]
             if compile:
-                hookpoint_to_sparse[hookpoint] = torch.compile(
-                    hookpoint_to_sparse[hookpoint]
+                sparse_model_dict[hookpoint] = torch.compile(
+                    sparse_model_dict[hookpoint]
                 )
 
         del sparse_models
+    return sparse_model_dict
 
-    submodules = {}
-    for hookpoint, sparse_model in hookpoint_to_sparse.items():
+
+def load_sparsify_hooks(
+    model: PreTrainedModel,
+    name: str,
+    hookpoints: list[str],
+    device: str | torch.device | None = None,
+    compile: bool = False,
+) -> dict[str, Callable]:
+    """
+    Load the encode functions for sparsify sparse coders on specified hookpoints.
+
+    Args:
+        model (Any): The model to load autoencoders for.
+        name (str): The name of the sparse model to load. If the model is on-disk
+            this is the path to the directory containing the sparse model weights.
+        hookpoints (list[str]): list of hookpoints to identify the sparse models.
+        device (str | torch.device | None, optional): The device to load the
+            sparse models on. If not specified the sparse models will be loaded
+            on the same device as the base model.
+
+    Returns:
+        dict[str, Callable]: A dictionary mapping hookpoints to encode functions.
+    """
+    sparse_model_dict = load_sparsify_sparse_coders(
+        model,
+        name,
+        hookpoints,
+        device,
+        compile,
+    )
+    hookpoint_to_sparse_encode = {}
+    for hookpoint, sparse_model in sparse_model_dict.items():
         path_segments = resolve_path(model, hookpoint.split("."))
         if path_segments is None:
             raise ValueError(f"Could not find valid path for hookpoint: {hookpoint}")
 
-        submodules[".".join(path_segments)] = partial(
+        hookpoint_to_sparse_encode[".".join(path_segments)] = partial(
             sae_dense_latents, sae=sparse_model
         )
 
-    return submodules
+    return hookpoint_to_sparse_encode
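
For reviewers, a minimal usage sketch of the new `load_sparsify_hooks` helper is below. The import path, base model, SAE directory, and hookpoint string are placeholders, and it assumes each returned encode function (a `partial` of `sae_dense_latents`) maps a hidden-state tensor to dense SAE latents.

```python
import torch
from transformers import AutoModel

# Import path is assumed; adjust to wherever this diff's module lives.
from delphi.sparse_coders.load_sparsify import load_sparsify_hooks

model = AutoModel.from_pretrained("EleutherAI/pythia-160m")  # placeholder base model

# Placeholder SAE directory and hookpoint; these must match how the SAEs were saved.
hookpoint_to_encode = load_sparsify_hooks(
    model,
    name="checkpoints/saes",
    hookpoints=["layers.5.mlp"],
    device="cpu",
)

# Each key is the resolved module path; each value encodes activations captured there.
hidden = torch.randn(1, 8, model.config.hidden_size)
for module_path, encode in hookpoint_to_encode.items():
    latents = encode(hidden)  # assumed: hidden states in, dense SAE latents out
    print(module_path, tuple(latents.shape))
```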