11from pathlib import Path
2- from typing import Callable
32
43import torch
54
87
98
109class Scorer :
11- scorer_callback : Callable
12-
13- num_scores : int
14-
1510 writer : ScoreWriter
1611
17- device : torch .device
18-
1912 def __init__ (
2013 self ,
2114 path : Path ,
@@ -29,10 +22,14 @@ def __init__(
2922 self .dtype = dtype
3023 self .num_items = num_items
3124
32- self .scorer_callback = self .build_scorer_callback (
33- query_grads ,
34- score_cfg ,
25+ self .query_tensor = torch .cat (
26+ [
27+ query_grads [m ].to (device = self .device , dtype = self .dtype )
28+ for m in score_cfg .modules
29+ ],
30+ dim = 1 ,
3531 )
32+ self .score_cfg = score_cfg
3633
3734 num_scores = len (query_grads [score_cfg .modules [0 ]])
3835
@@ -47,37 +44,22 @@ def __call__(
4744 indices : list [int ],
4845 mod_grads : dict [str , torch .Tensor ],
4946 ):
50- first_grad = next ( iter ( mod_grads . values ()))
51- if first_grad .dtype != self .dtype :
47+ # Convert the gradients to the scoring dtype
48+ if next ( iter ( mod_grads . values ())) .dtype != self .dtype :
5249 mod_grads = {name : grad .to (self .dtype ) for name , grad in mod_grads .items ()}
5350
54- scores = self .scorer_callback (mod_grads )
55- self .writer (indices , scores )
56-
57- def build_scorer_callback (
58- self ,
59- query_grads : dict [str , torch .Tensor ],
60- score_cfg : ScoreConfig ,
61- ) -> Callable :
62- """Unified scorer builder for all scorer types."""
63- query_tensor = torch .cat (
64- [
65- query_grads [m ].to (device = self .device , dtype = self .dtype )
66- for m in score_cfg .modules
67- ],
68- dim = 1 ,
69- )
51+ scores = self .score (mod_grads )
7052
71- @torch .inference_mode ()
72- def callback (mod_grads : dict [str , torch .Tensor ]):
73- grads = torch .cat ([mod_grads [m ] for m in score_cfg .modules ], dim = 1 )
74- if score_cfg .unit_normalize :
75- grads /= grads .norm (dim = 1 , keepdim = True )
53+ self .writer (indices , scores )
7654
77- if score_cfg .score == "nearest" :
78- all_scores = grads @ query_tensor .T
79- return all_scores .max (dim = - 1 ).values
55+ @torch .inference_mode ()
56+ def score (self , mod_grads : dict [str , torch .Tensor ]):
57+ grads = torch .cat ([mod_grads [m ] for m in self .score_cfg .modules ], dim = 1 )
58+ if self .score_cfg .unit_normalize :
59+ grads /= grads .norm (dim = 1 , keepdim = True )
8060
81- return grads @ query_tensor .T
61+ if self .score_cfg .score == "nearest" :
62+ all_scores = grads @ self .query_tensor .T
63+ return all_scores .max (dim = - 1 ).values
8264
83- return callback
65+ return grads @ self . query_tensor . T
0 commit comments