4646from dataclasses import dataclass
4747from itertools import zip_longest
4848from pathlib import Path
49+ from typing import Any
4950from unittest .mock import patch
5051
5152import gspread
@@ -313,7 +314,7 @@ def _compute_metrics_for_model(df: pl.DataFrame, model_name: str) -> dict:
313314 }
314315
315316
316- def _compute_metrics_avg_over_projects (df : pl .DataFrame , model_name : str ) -> dict :
317+ def _compute_metrics_avg_over_projects (df : pl .DataFrame , model_name : str ) -> dict [ str , float ] :
317318 """Compute metrics averaged over projects so large projects don't dominate."""
318319 metrics_per_project = []
319320 for _ , df_project in df .group_by ("project_id" ):
@@ -350,14 +351,16 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
350351 metrics_to_plot = None
351352 for (platform ,), platform_df in df .group_by ("platform" ):
352353 for model_name in model_names :
353- project_metrics_list = []
354+ project_metrics_list : list [ dict [ str , Any ]] = []
354355 for _ , proj_df in platform_df .group_by ("project_id" ):
355356 project_metrics_list .append (_compute_metrics_for_model (proj_df , model_name ))
356357 if metrics_to_plot is None :
357358 metrics_to_plot = list (project_metrics_list [0 ].keys ())
358359 avg_metrics = {
359- k : sum (m [k ] for m in project_metrics_list if m [k ] == m [k ])
360- / sum (1 for m in project_metrics_list if m [k ] == m [k ])
360+ k : (
361+ sum (m [k ] for m in project_metrics_list if m [k ] == m [k ])
362+ / sum (1 for m in project_metrics_list if m [k ] == m [k ])
363+ )
361364 for k in project_metrics_list [0 ]
362365 }
363366 avg_metrics ["platform" ] = platform
@@ -367,9 +370,10 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
367370 metrics_df = pl .DataFrame (metrics_rows )
368371
369372 # Convert to pandas and pivot for plotting
373+ assert metrics_to_plot is not None , "No platforms in df"
370374 metrics_pd = metrics_df .to_pandas ()
371- fig , axes = plt .subplots (1 , len (metrics_to_plot ), figsize = (4 * len (metrics_to_plot ), 5 ))
372- axes : list [plt .Axes ] = list (axes )
375+ fig , axes_arr = plt .subplots (1 , len (metrics_to_plot ), figsize = (4 * len (metrics_to_plot ), 5 ))
376+ axes : list [plt .Axes ] = list (axes_arr )
373377
374378 for ax , metric in zip (axes , metrics_to_plot , strict = True ):
375379 pivot_df = metrics_pd .pivot (index = "platform" , columns = "model" , values = metric )
@@ -382,7 +386,7 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
382386 # Single legend for the whole figure (top center)
383387 handles , labels = axes [0 ].get_legend_handles_labels ()
384388 fig .legend (handles , labels , loc = "upper center" , ncol = len (model_names ), bbox_to_anchor = (0.5 , 1.02 ))
385- plt .tight_layout (rect = [ 0 , 0 , 1 , 0.95 ] ) # make room for legend on top
389+ plt .tight_layout (rect = ( 0 , 0 , 1 , 0.95 ) ) # make room for legend on top
386390 return fig
387391
388392
@@ -443,10 +447,8 @@ def plot_dumbbell_by_project(
443447 metrics = [c .replace (f"{ model1 } _" , "" ) for c in project_metrics_df .columns if c .startswith (f"{ model1 } _" )]
444448
445449 n_metrics = len (metrics )
446- fig , axes = plt .subplots (1 , n_metrics , figsize = (5 * n_metrics , max (8 , len (project_metrics_df ) * 0.15 )))
447- if n_metrics == 1 :
448- axes = [axes ]
449- axes : list [plt .Axes ] = list (axes )
450+ fig , axes_arr = plt .subplots (1 , n_metrics , figsize = (5 * n_metrics , max (8 , len (project_metrics_df ) * 0.15 )))
451+ axes : list [plt .Axes ] = [axes_arr ] if n_metrics == 1 else list (axes_arr )
450452
451453 # Sort once by pred_GROUP_rate delta, use same order for all subplots
452454 group_rate_col1 = f"{ model1 } _pred_GROUP_rate"
@@ -483,7 +485,7 @@ def plot_dumbbell_by_project(
483485 handles , labels = axes [0 ].get_legend_handles_labels ()
484486 fig .legend (handles , labels , loc = "upper center" , ncol = len (model_names ), bbox_to_anchor = (0.5 , 1.02 ))
485487 fig .suptitle ("Metrics by Project (org_id|project_id)" , fontsize = 14 , y = 1.05 )
486- plt .tight_layout (rect = [ 0 , 0 , 1 , 0.98 ] )
488+ plt .tight_layout (rect = ( 0 , 0 , 1 , 0.98 ) )
487489 return fig
488490
489491
@@ -557,11 +559,13 @@ def compare_models(
557559 # Compute conditional probabilities (reported later)
558560 prod_group = df .filter (pl .col (pred1_col ) == "GROUP" )
559561 prod_separate = df .filter (pl .col (pred1_col ) == "SEPARATE" )
560- p_group_given_group = (prod_group [pred2_col ] == "GROUP" ).mean () if len (prod_group ) > 0 else float ("nan" )
561- p_group_given_separate = (prod_separate [pred2_col ] == "GROUP" ).mean () if len (prod_separate ) > 0 else float ("nan" )
562+ p_group_given_group = float ((prod_group [pred2_col ] == "GROUP" ).mean ()) if len (prod_group ) > 0 else float ("nan" ) # type: ignore[arg-type]
563+ p_group_given_separate = (
564+ float ((prod_separate [pred2_col ] == "GROUP" ).mean ()) if len (prod_separate ) > 0 else float ("nan" ) # type: ignore[arg-type]
565+ )
562566 df_close = df .filter (pl .col ("distance" ) < 0.005 )
563567 close_group = df_close .filter (pl .col (pred1_col ) == "GROUP" )
564- p_close = ( close_group [pred2_col ] == "GROUP" ).mean () if len (close_group ) > 0 else float ("nan" )
568+ p_close = float (( close_group [pred2_col ] == "GROUP" ).mean ()) if len (close_group ) > 0 else float ("nan" ) # type: ignore[arg-type]
565569
566570 # Columns to keep in output
567571 output_cols = [
@@ -589,6 +593,7 @@ def compare_models(
589593 df_sorted = df .sort (["org_id" , "project_id" ])
590594 for (org_id , project_id ), group_df in df_sorted .group_by (["org_id" , "project_id" ], maintain_order = True ):
591595 total_projects += 1
596+ assert output_dir is not None
592597 proj_dir = output_dir / f"org_{ org_id } " / f"project_{ project_id } "
593598
594599 # Compute metrics for each model on this project
@@ -675,7 +680,7 @@ def compare_models(
675680
676681 report ("\n ### Distance distribution\n " )
677682 report (df ["distance" ].describe ())
678- report (f"\n GROUP rate: { ( df ['label' ] == 'GROUP' ).mean ():.2%} " )
683+ report (f"\n GROUP rate: { float (( df ['label' ] == 'GROUP' ).mean ()) :.2%} " ) # type: ignore[arg-type]
679684
680685 platform_stats = (
681686 df .group_by ("platform" )
@@ -754,7 +759,7 @@ def compute_stacktrace_token_percentiles(df: pl.DataFrame) -> pl.DataFrame:
754759
755760 rows = []
756761 for col in token_cols :
757- row = {"metric" : col }
762+ row : dict [ str , Any ] = {"metric" : col }
758763 row ["min" ] = df [col ].min ()
759764 row ["mean" ] = df [col ].mean ()
760765 for p in percentiles :
@@ -896,6 +901,7 @@ def _compute_project_precisions_per_platform(model: str, thresholds_platform: di
896901 )
897902 else :
898903 baseline_key = f"{ baseline_model } @{ baseline_threshold } "
904+ assert isinstance (baseline_threshold , float )
899905 project_precisions [baseline_key ] = _compute_project_precisions (baseline_model , baseline_threshold )
900906 else :
901907 baseline_key = str (thresholds_sorted [0 ])
@@ -916,7 +922,7 @@ def _compute_project_precisions_per_platform(model: str, thresholds_platform: di
916922 {
917923 "platform" : platform ,
918924 "n_projects" : len (prec ),
919- "median_pairs" : int (platform_df ["n_pairs" ].median ()),
925+ "median_pairs" : int (platform_df ["n_pairs" ].median ()), # type: ignore[arg-type]
920926 "mean" : prec .mean (),
921927 "p5" : prec .quantile (0.05 ),
922928 "p10" : prec .quantile (0.10 ),
@@ -971,7 +977,8 @@ def metrics_by_platform(
971977 )
972978
973979 rows = []
974- for (platform ,), platform_df in df_t .group_by ("platform" ):
980+ for (platform_obj ,), platform_df in df_t .group_by ("platform" ):
981+ platform = str (platform_obj )
975982 avg_metrics = _compute_metrics_avg_over_projects (platform_df , model_name )
976983 platform_threshold = threshold .get (platform , threshold ["default" ]) if isinstance (threshold , dict ) else threshold
977984 rows .append (
@@ -1025,12 +1032,15 @@ def find_threshold_by_platform(
10251032 precision_by_platform = min_precision if isinstance (min_precision , dict ) else None
10261033
10271034 rows = []
1028- for (platform ,), platform_df in df .group_by ("platform" ):
1035+ for (platform_obj ,), platform_df in df .group_by ("platform" ):
1036+ platform = str (platform_obj )
10291037 n_pairs = len (platform_df )
10301038 n_projects = platform_df ["project_id" ].n_unique ()
10311039 label_group_rate = (platform_df ["label" ] == "GROUP" ).mean ()
10321040 threshold_found = None
1033- target_precision = precision_by_platform [platform ] if precision_by_platform else min_precision
1041+ target_precision : float = (
1042+ precision_by_platform [platform ] if precision_by_platform else min_precision # type: ignore[assignment]
1043+ )
10341044
10351045 # Walk thresholds from low to high; first one meeting precision is the minimum
10361046 # Precision is averaged over projects to avoid large projects dominating
@@ -1039,11 +1049,11 @@ def find_threshold_by_platform(
10391049 pl .when (pl .col (sim_col ) > thresh ).then (pl .lit ("GROUP" )).otherwise (pl .lit ("SEPARATE" )).alias (pred_col )
10401050 )
10411051 # Compute per-project precision, then average
1042- project_precisions = []
1052+ project_precisions : list [ float ] = []
10431053 for _ , proj_df in df_t .group_by ("project_id" ):
10441054 pred_group = proj_df .filter (pl .col (pred_col ) == "GROUP" )
10451055 if len (pred_group ) > 0 :
1046- project_precisions .append (( pred_group ["label" ] == "GROUP" ).mean ())
1056+ project_precisions .append (float (( pred_group ["label" ] == "GROUP" ).mean ())) # type: ignore[arg-type]
10471057 if not project_precisions :
10481058 continue
10491059 precision = sum (project_precisions ) / len (project_precisions )
@@ -1115,11 +1125,11 @@ def compare_metrics_by_stacktrace_length(
11151125
11161126 # Print metrics for each bucket
11171127 report (f"\n ### Short stacktraces ({ token_col } <= p10 = { p10 :.0f} tokens, { len (short_df )} pairs)\n " )
1118- report (f"label GROUP rate: { ( short_df ['label' ] == 'GROUP' ).mean ():.2%} " )
1128+ report (f"label GROUP rate: { float (( short_df ['label' ] == 'GROUP' ).mean ()) :.2%} " ) # type: ignore[arg-type]
11191129 report (_compute_metrics (short_df , model_names ))
11201130
11211131 report (f"\n ### Long stacktraces ({ token_col } >= p90 = { p90 :.0f} tokens, { len (long_df )} pairs)\n " )
1122- report (f"label GROUP rate: { ( long_df ['label' ] == 'GROUP' ).mean ():.2%} " )
1132+ report (f"label GROUP rate: { float (( long_df ['label' ] == 'GROUP' ).mean ()) :.2%} " ) # type: ignore[arg-type]
11231133 report (_compute_metrics (long_df , model_names ))
11241134
11251135
0 commit comments