@@ -349,7 +349,7 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
349349 fig , axes = plt .subplots (1 , len (metrics_to_plot ), figsize = (4 * len (metrics_to_plot ), 5 ))
350350 axes : list [plt .Axes ] = list (axes )
351351
352- for ax , metric in zip (axes , metrics_to_plot ):
352+ for ax , metric in zip (axes , metrics_to_plot , strict = True ):
353353 pivot_df = metrics_pd .pivot (index = "platform" , columns = "model" , values = metric )
354354 pivot_df = pivot_df [model_names ] # ensure consistent column order
355355 pivot_df .plot (kind = "bar" , ax = ax , rot = 45 , legend = False , color = MODEL_COLORS )
@@ -388,7 +388,7 @@ def plot_similarity_distribution(
388388 if n == 1 :
389389 axes = [axes ]
390390
391- for ax , platform in zip (axes , platforms ):
391+ for ax , platform in zip (axes , platforms , strict = True ):
392392 data = df .filter (pl .col ("platform" ) == platform )[sim_col ].to_numpy ()
393393 ax .hist (data , bins = bins , edgecolor = "none" , alpha = 0.8 )
394394 ax .set_ylabel (platform , rotation = 0 , labelpad = 60 , ha = "right" )
@@ -434,7 +434,7 @@ def plot_dumbbell_by_project(
434434 ).sort ("_delta" )
435435 y_labels = [f"{ row ['org_id' ]} |{ row ['project_id' ]} " for row in sorted_df .iter_rows (named = True )]
436436
437- for ax , metric in zip (axes , metrics ):
437+ for ax , metric in zip (axes , metrics , strict = True ):
438438 col1 = f"{ model1 } _{ metric } "
439439 col2 = f"{ model2 } _{ metric } "
440440
@@ -443,7 +443,7 @@ def plot_dumbbell_by_project(
443443 y = range (len (sorted_df ))
444444
445445 # Draw lines colored by direction
446- for i , (v1 , v2 ) in enumerate (zip (x1 , x2 )):
446+ for i , (v1 , v2 ) in enumerate (zip (x1 , x2 , strict = True )):
447447 color = "green" if v2 >= v1 else "red"
448448 ax .hlines (y = i , xmin = min (v1 , v2 ), xmax = max (v1 , v2 ), color = color , alpha = 0.6 )
449449
@@ -485,7 +485,7 @@ def compare_models(
485485 used for platforms not explicitly listed.
486486 First key = model1 (baseline), second key = model2 (new model).
487487 output_dir: Directory for writing CSVs. Required if write_csvs is True.
488- min_group_rate_increase: Track projects where model2 GROUP rate is >= this value higher than model1. None to skip .
488+ min_group_rate_increase: Track projects where model2 GROUP rate is >= this value higher than model1. None skips .
489489 min_group_rate_decrease: Track projects where model2 GROUP rate is >= this value lower than model1 (absolute).
490490 E.g., 0.10 means model2 has at least 10pp lower GROUP rate. None to skip.
491491 write_csvs: If True, write new.csv and merged.csv files for each project.
@@ -753,7 +753,7 @@ def compute_stacktrace_token_percentiles(df: pl.DataFrame) -> pl.DataFrame:
753753def sweep_thresholds (
754754 df : pl .DataFrame ,
755755 model_name : str ,
756- thresholds : list [float ] = [ 0.80 , 0.85 , 0.87 , 0.90 ] ,
756+ thresholds : list [float ] | None = None ,
757757) -> pl .DataFrame :
758758 """
759759 Show metrics for a single model at multiple similarity thresholds.
@@ -766,6 +766,8 @@ def sweep_thresholds(
766766 Returns:
767767 DataFrame with one row per threshold and metric columns.
768768 """
769+ if thresholds is None :
770+ thresholds = [0.80 , 0.85 , 0.87 , 0.90 ]
769771 sim_col = f"cos_sim_{ model_name } "
770772 rows = []
771773 for thresh in thresholds :
@@ -788,7 +790,7 @@ def sweep_thresholds(
788790def sweep_thresholds_by_project (
789791 df : pl .DataFrame ,
790792 model_name : str ,
791- thresholds : list [float ] = [ 0.80 , 0.85 , 0.87 , 0.90 ] ,
793+ thresholds : list [float ] | None = None ,
792794 precision_floor : float = 0.8 ,
793795 harm_threshold : float = 0.05 ,
794796 thresholds_platform : dict [str , float ] | None = None ,
@@ -811,8 +813,8 @@ def sweep_thresholds_by_project(
811813 baseline_threshold: Threshold for the baseline model. Can be a float or a
812814 per-platform dict (with a "default" key), same format as thresholds_platform.
813815 """
814- sim_col = f"cos_sim_ { model_name } "
815- pred_col = f"pred_ { model_name } "
816+ if thresholds is None :
817+ thresholds = [ 0.80 , 0.85 , 0.87 , 0.90 ]
816818 thresholds_sorted = sorted (thresholds , reverse = True )
817819
818820 def _compute_project_precisions (model : str , threshold : float ) -> pl .DataFrame :
0 commit comments