@@ -16,50 +16,58 @@ class Metric:
1616 requires_properties : list [str ] = field (default_factory = list )
1717 _parameters : list = field (init = False )
1818
19- def compute_lazy (self , df : pl .LazyFrame , ** kwargs ) -> tuple [pl .LazyFrame , dict [str ,pl .LazyFrame ]]:
19+ def compute_lazy (self , df : pl .LazyFrame , ** kwargs ) -> tuple [pl .LazyFrame , dict [str , pl .LazyFrame ]]:
2020 try :
2121 df , properties = self .compute_func (df , ** kwargs )
2222 assert isinstance (df , pl .LazyFrame )
2323 assert all (p in properties for p in self .computes_properties )
2424 return df , properties
2525
2626 except TypeError as e :
27- raise TypeError (f'Missing paramter for Metric with compute_func { self .compute_func .__name__ } : { repr (e )} ' ) from e
27+ raise TypeError (
28+ f"Missing paramter for Metric with compute_func { self .compute_func .__name__ } : { repr (e )} "
29+ ) from e
2830
2931 def __post_init__ (self ):
3032 sig = inspect .signature (self .compute_func )
3133 parameters = sig .parameters
32- assert 'df' in parameters
34+ assert "df" in parameters
3335 assert all (p in parameters for p in self .requires_properties )
3436
35- self ._parameters = [v for k , v in parameters .items () if k not in ['df' , 'args' , 'kwargs' ]+ self .requires_properties ]
37+ self ._parameters = [
38+ v for k , v in parameters .items () if k not in ["df" , "args" , "kwargs" ] + self .requires_properties
39+ ]
3640
3741 def __call__ (self , df : pl .DataFrame , ** kwargs ):
3842 try :
3943 if not isinstance (df , pl .LazyFrame ):
4044 df = pl .LazyFrame (df )
4145 return self .compute_lazy (df , ** kwargs )
4246 except TypeError as e :
43- raise TypeError (f'Missing paramter for Metric with compute_func { self .compute_func .__name__ } : { repr (e )} ' ) from e
47+ raise TypeError (
48+ f"Missing paramter for Metric with compute_func { self .compute_func .__name__ } : { repr (e )} "
49+ ) from e
4450
4551
46- def metric (computes_columns : list [str ]| None = None ,
47- computes_properties : list [str ]| None = None ,
48- requires_columns : list [str ]| None = None ,
49- requires_properties : list [str ]| None = None ):
52+ def metric (
53+ computes_columns : list [str ] | None = None ,
54+ computes_properties : list [str ] | None = None ,
55+ requires_columns : list [str ] | None = None ,
56+ requires_properties : list [str ] | None = None ,
57+ ):
5058 def decorator (func ):
5159 return Metric (
52- compute_func = func ,
53- computes_columns = computes_columns or [],
54- computes_properties = computes_properties or [],
55- requires_columns = requires_columns or [],
56- requires_properties = requires_properties or []
60+ compute_func = func ,
61+ computes_columns = computes_columns or [],
62+ computes_properties = computes_properties or [],
63+ requires_columns = requires_columns or [],
64+ requires_properties = requires_properties or [],
5765 )
58- return decorator
5966
67+ return decorator
6068
6169
62- @metric (computes_columns = [ ' distance_traveled' , ' vel' ])
70+ @metric (computes_columns = [ " distance_traveled" , " vel" ])
6371def driven_distance_and_vel (df ) -> tuple [pl .DataFrame , dict [str , pl .DataFrame ]]:
6472 return df .with_columns (
6573 (pl .col ("x" ).diff () ** 2 + pl .col ("y" ).diff () ** 2 )
@@ -72,10 +80,11 @@ def driven_distance_and_vel(df) -> tuple[pl.DataFrame, dict[str, pl.DataFrame]]:
7280 ), {}
7381
7482
75- @metric (requires_columns = ['distance_traveled' , 'vel' ],
76- computes_properties = ['timegaps' , 'min_timegaps' , 'p_timegaps' , 'min_p_timegaps' ])
83+ @metric (
84+ requires_columns = ["distance_traveled" , "vel" ],
85+ computes_properties = ["timegaps" , "min_timegaps" , "p_timegaps" , "min_p_timegaps" ],
86+ )
7787def timegaps_and_p_timegaps (df , / , ego_id , time_buffer = 2e9 ):
78-
7988 ego_df = df .filter (idx = ego_id )
8089
8190 crossed = df .join (ego_df , how = "cross" , suffix = "_ego" )
@@ -150,32 +159,36 @@ def timegaps_and_p_timegaps(df, /, ego_id, time_buffer=2e9):
150159 }
151160
152161
153-
154-
155- metrics = [
156- timegaps_and_p_timegaps ,
157- driven_distance_and_vel
158- ]
162+ metrics = [timegaps_and_p_timegaps , driven_distance_and_vel ]
159163
160164
161165@dataclass
162166class MetricManager :
163167 metrics : list [Metric ] = field (default_factory = lambda : metrics )
164168 exclude_columns : list [str ] = field (default_factory = list )
165169 exclude_properties : list [str ] = field (default_factory = list )
166- _dependencies : dict [int | str , list [int | str ]] = field (init = False )
170+ _dependencies : dict [int | str , list [int | str ]] = field (init = False )
167171 _ordered_metrics : list [Metric ] = field (init = False )
168172 _parameters : list = field (init = False )
169173
170174 def __post_init__ (self ):
171- self ._dependencies = {val : [i ] for i ,m in enumerate (self .metrics ) for val in [f'column_{ n } ' for n in m .computes_columns ]+ [f'property_{ n } ' for n in m .computes_properties ]}| {
172- i : [f'column_{ n } ' for n in m .requires_columns ]+ [f'property_{ n } ' for n in m .requires_properties ] for i ,m in enumerate (self .metrics )
175+ self ._dependencies = {
176+ val : [i ]
177+ for i , m in enumerate (self .metrics )
178+ for val in [f"column_{ n } " for n in m .computes_columns ] + [f"property_{ n } " for n in m .computes_properties ]
179+ } | {
180+ i : [f"column_{ n } " for n in m .requires_columns ] + [f"property_{ n } " for n in m .requires_properties ]
181+ for i , m in enumerate (self .metrics )
173182 }
174183
175- unresovled_dependencies = {k : v for k ,vv in self ._dependencies .items () for v in vv if v not in self ._dependencies }
184+ unresovled_dependencies = {
185+ k : v for k , vv in self ._dependencies .items () for v in vv if v not in self ._dependencies
186+ }
176187 if len (unresovled_dependencies ) > 0 :
177- error_dict = {f'self.metrics[{ k } ]' : v for k ,v in unresovled_dependencies .items ()}
178- raise RuntimeError (f"There are columns and properties required by metrics, that are never computed: { error_dict } " )
188+ error_dict = {f"self.metrics[{ k } ]" : v for k , v in unresovled_dependencies .items ()}
189+ raise RuntimeError (
190+ f"There are columns and properties required by metrics, that are never computed: { error_dict } "
191+ )
179192
180193 self ._parameters = [v for m in self .metrics for v in m ._parameters ]
181194
@@ -185,7 +198,7 @@ def __post_init__(self):
185198 def __repr__ (self ):
186199 return f"computes columns: { [c for m in self ._ordered_metrics for c in m .computes_columns ]} - computes properties { [p for m in self ._ordered_metrics for p in m .computes_properties ]} - args { self ._parameters } "
187200
188- def compute (self , r : Recording , ** kwargs ) -> tuple [pl .DataFrame , dict [str ,pl .DataFrame ]]:
201+ def compute (self , r : Recording , ** kwargs ) -> tuple [pl .DataFrame , dict [str , pl .DataFrame ]]:
189202 if "polygon" not in r ._df .columns :
190203 r ._df = r ._add_polygons (r ._df )
191204 if "geometry" not in r ._df .columns :
@@ -197,13 +210,13 @@ def compute(self, r: Recording, **kwargs) -> tuple[pl.DataFrame, dict[str,pl.Dat
197210 df , new_p = m .compute_lazy (
198211 df = df ,
199212 ** {k : properties [k ] for k in m .requires_properties },
200- ** {k : v for k ,v in kwargs .items () if k in [p .name for p in m ._parameters ]}
213+ ** {k : v for k , v in kwargs .items () if k in [p .name for p in m ._parameters ]},
201214 )
202215 properties |= new_p
203216 for k in self .exclude_properties :
204217 del properties [k ]
205218 df = df .drop (self .exclude_columns )
206- res = pl .collect_all ([df ]+ list (properties .values ()))
207- df , computed_props = res [0 ], res [1 :]
219+ res = pl .collect_all ([df ] + list (properties .values ()))
220+ df , computed_props = res [0 ], res [1 :]
208221 assert all (c in df .columns or c in self .exclude_columns for m in self .metrics for c in m .computes_columns )
209- return df , {k :v for k ,v in zip (properties .keys (), computed_props )}
222+ return df , {k : v for k , v in zip (properties .keys (), computed_props )}
0 commit comments