 from bergson.hessians.kfac import CovarianceCollector
 from bergson.hessians.tkfac import TraceCovarianceCollector
 from bergson.utils.utils import (
+    convert_precision_to_torch,
     setup_reproducibility,
     validate_batch_size,
 )
@@ -81,7 +82,7 @@ def hessian_worker(
     rank: int,
     local_rank: int,
     world_size: int,
-    cfg: IndexConfig,
+    index_cfg: IndexConfig,
     hessian_cfg: HessianConfig,
     ds: Dataset,
 ):
@@ -135,35 +136,37 @@ def hessian_worker(
         world_size=world_size,
     )
 
-    model, target_modules = setup_model_and_peft(cfg)
+    model, target_modules = setup_model_and_peft(index_cfg)
 
-    attention_cfgs = {module: cfg.attention for module in cfg.split_attention_modules}
+    attention_cfgs = {
+        module: index_cfg.attention for module in index_cfg.split_attention_modules
+    }
 
     kwargs = {
         "model": model,
         "data": ds,
-        "cfg": cfg,
+        "cfg": index_cfg,
         "hessian_cfg": hessian_cfg,
         "target_modules": target_modules,
         "attention_cfgs": attention_cfgs,
     }
 
-    batches = allocate_batches(ds["length"], cfg.token_batch_size)
+    batches = allocate_batches(ds["length"], index_cfg.token_batch_size)
     kwargs["batches"] = batches
     collect_hessians(**kwargs)
 
     total_processed = torch.load(
-        f"{cfg.partial_run_path}/total_processed.pt",
+        f"{index_cfg.partial_run_path}/total_processed.pt",
        map_location="cpu",
         weights_only=False,
     )
 
     compute_eigendecomposition(
-        os.path.join(cfg.partial_run_path, "activation_sharded"),
+        os.path.join(index_cfg.partial_run_path, "activation_sharded"),
         total_processed=total_processed,
     )
     compute_eigendecomposition(
-        os.path.join(cfg.partial_run_path, "gradient_sharded"),
+        os.path.join(index_cfg.partial_run_path, "gradient_sharded"),
         total_processed=total_processed,
     )
 
@@ -174,7 +177,7 @@ def hessian_worker(
 def collect_hessians(
     model: PreTrainedModel,
     data: Dataset,
-    cfg: IndexConfig,
+    index_cfg: IndexConfig,
     *,
     batches: list[list[int]] | None = None,
     target_modules: set[str] | None = None,
@@ -190,14 +193,14 @@ def collect_hessians(
     hessian_dtype = (
         model.dtype
         if hessian_cfg.hessian_dtype == "auto"
-        else hessian_cfg.hessian_dtype
+        else convert_precision_to_torch(hessian_cfg.hessian_dtype)
     )
 
     collector_args = {
         "model": model.base_model,  # type: ignore
         "target_modules": target_modules,
         "attention_cfgs": attention_cfgs or {},
-        "path": str(cfg.partial_run_path),
+        "path": str(index_cfg.partial_run_path),
     }
     desc = f"Approximating Hessians with {hessian_cfg.method}"
     if ev_correction:
@@ -207,16 +210,16 @@ def collect_hessians(
     collector_args["dtype"] = hessian_dtype
     collector = HESSIAN_APPROXIMATIONS[hessian_cfg.method](**collector_args)
 
-    validate_batch_size(model, cfg.token_batch_size, collector)
+    validate_batch_size(model, index_cfg.token_batch_size, collector)
 
     computer = CollectorComputer(
         model=model,  # type: ignore
         data=data,
         collector=collector,
         batches=batches,
-        cfg=cfg,
+        cfg=index_cfg,
     )
 
-    computer.forward_backward = fwd_bwd_hessian_factory(hessian_cfg)
+    computer.forward_backward = fwd_bwd_hessian_factory(index_cfg, hessian_cfg)
 
     computer.run_with_collector_hooks(desc=desc)
0 commit comments