Skip to content

Commit 7333547

Browse files
committed
mypy
1 parent 57b090e commit 7333547

19 files changed

Lines changed: 405 additions & 119 deletions

.github/workflows/ci.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ jobs:
2020
- run: uv tool run ruff check .
2121
- run: uv tool run ruff format --check .
2222

23+
typecheck:
24+
runs-on: ubuntu-latest
25+
steps:
26+
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
27+
- uses: astral-sh/setup-uv@eac588ad8def6316056a12d4907a9d4d84ff7a3b # v7.3.0
28+
with:
29+
enable-cache: true
30+
python-version: "3.13"
31+
- run: uv sync --extra dev
32+
- run: uv run mypy
33+
2334
test:
2435
runs-on: ubuntu-latest
2536
steps:

.pre-commit-config.yaml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
repos:
2-
# Ruff lint + format. Config lives in pyproject.toml.
32
- repo: https://github.com/astral-sh/ruff-pre-commit
43
rev: v0.15.6
54
hooks:
@@ -17,3 +16,13 @@ repos:
1716
- id: end-of-file-fixer
1817
- id: trailing-whitespace
1918
- id: debug-statements
19+
20+
- repo: local
21+
hooks:
22+
- id: mypy
23+
name: mypy
24+
entry: uv run mypy
25+
language: system
26+
types: [python]
27+
pass_filenames: false
28+
require_serial: true

.vscode/extensions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"recommendations": [
3+
"charliermarsh.ruff",
4+
"ms-python.mypy-type-checker"
5+
]
6+
}

.vscode/settings.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
"rewrap.autoWrap.enabled": true,
1616
"notebook.formatOnSave.enabled": true,
1717
"notebook.defaultFormatter": "charliermarsh.ruff",
18-
"python.defaultInterpreterPath": ".venv/bin/python"
18+
"python.defaultInterpreterPath": ".venv/bin/python",
19+
"mypy-type-checker.reportingScope": "workspace",
20+
"mypy-type-checker.importStrategy": "fromEnvironment",
21+
"mypy-type-checker.preferDaemon": true
1922
}

benchmark/compare_compiled_uncompiled.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import numpy as np
2727
import polars as pl
28+
import torch
2829
from sentence_transformers.util import pairwise_cos_sim
2930
from tap import tapify
3031

@@ -53,7 +54,7 @@ def _load_side(dir_local: Path) -> tuple[pl.DataFrame, np.ndarray, np.ndarray]:
5354

5455

5556
def _cos_sim_truncated(a: np.ndarray, b: np.ndarray, dim: int) -> np.ndarray:
56-
return pairwise_cos_sim(a[..., :dim], b[..., :dim]).detach().cpu().numpy()
57+
return pairwise_cos_sim(torch.as_tensor(a[..., :dim]), torch.as_tensor(b[..., :dim])).detach().cpu().numpy()
5758

5859

5960
def _percentiles(x: np.ndarray) -> dict[str, float]:
@@ -249,8 +250,8 @@ def main(
249250

250251
n_pairs = len(df_joined)
251252

252-
def _fmt_count_pct(n: int) -> str:
253-
return f"{n:,} ({n / n_pairs * 100:.3g}%)"
253+
def _fmt_count_pct(n: int | float) -> str:
254+
return f"{int(n):,} ({n / n_pairs * 100:.3g}%)"
254255

255256
df_threshold = pl.DataFrame(
256257
{

benchmark/report.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ def _write_report(
7474
path_out: Path,
7575
) -> None:
7676
speedup = df["time_base_sec"] / df["time_compiled_sec"]
77-
p10 = float(speedup.quantile(0.1))
78-
p50 = float(speedup.median())
79-
p90 = float(speedup.quantile(0.9))
80-
fraction_wins = float((speedup >= 1.0).mean())
77+
p10 = float(speedup.quantile(0.1)) # type: ignore[arg-type]
78+
p50 = float(speedup.median()) # type: ignore[arg-type]
79+
p90 = float(speedup.quantile(0.9)) # type: ignore[arg-type]
80+
fraction_wins = float((speedup >= 1.0).mean()) # type: ignore[arg-type]
8181

82-
median_compiled_ms = float(df["time_compiled_sec"].median()) * 1000
83-
median_base_ms = float(df["time_base_sec"].median()) * 1000
82+
median_compiled_ms = float(df["time_compiled_sec"].median()) * 1000 # type: ignore[arg-type]
83+
median_base_ms = float(df["time_base_sec"].median()) * 1000 # type: ignore[arg-type]
8484

8585
df_worst = (
8686
df.with_columns(speedup=pl.col("time_base_sec") / pl.col("time_compiled_sec"))

eval/compare.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from dataclasses import dataclass
4747
from itertools import zip_longest
4848
from pathlib import Path
49+
from typing import Any
4950
from unittest.mock import patch
5051

5152
import gspread
@@ -313,7 +314,7 @@ def _compute_metrics_for_model(df: pl.DataFrame, model_name: str) -> dict:
313314
}
314315

315316

316-
def _compute_metrics_avg_over_projects(df: pl.DataFrame, model_name: str) -> dict:
317+
def _compute_metrics_avg_over_projects(df: pl.DataFrame, model_name: str) -> dict[str, float]:
317318
"""Compute metrics averaged over projects so large projects don't dominate."""
318319
metrics_per_project = []
319320
for _, df_project in df.group_by("project_id"):
@@ -350,14 +351,16 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
350351
metrics_to_plot = None
351352
for (platform,), platform_df in df.group_by("platform"):
352353
for model_name in model_names:
353-
project_metrics_list = []
354+
project_metrics_list: list[dict[str, Any]] = []
354355
for _, proj_df in platform_df.group_by("project_id"):
355356
project_metrics_list.append(_compute_metrics_for_model(proj_df, model_name))
356357
if metrics_to_plot is None:
357358
metrics_to_plot = list(project_metrics_list[0].keys())
358359
avg_metrics = {
359-
k: sum(m[k] for m in project_metrics_list if m[k] == m[k])
360-
/ sum(1 for m in project_metrics_list if m[k] == m[k])
360+
k: (
361+
sum(m[k] for m in project_metrics_list if m[k] == m[k])
362+
/ sum(1 for m in project_metrics_list if m[k] == m[k])
363+
)
361364
for k in project_metrics_list[0]
362365
}
363366
avg_metrics["platform"] = platform
@@ -367,9 +370,10 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
367370
metrics_df = pl.DataFrame(metrics_rows)
368371

369372
# Convert to pandas and pivot for plotting
373+
assert metrics_to_plot is not None, "No platforms in df"
370374
metrics_pd = metrics_df.to_pandas()
371-
fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(4 * len(metrics_to_plot), 5))
372-
axes: list[plt.Axes] = list(axes)
375+
fig, axes_arr = plt.subplots(1, len(metrics_to_plot), figsize=(4 * len(metrics_to_plot), 5))
376+
axes: list[plt.Axes] = list(axes_arr)
373377

374378
for ax, metric in zip(axes, metrics_to_plot, strict=True):
375379
pivot_df = metrics_pd.pivot(index="platform", columns="model", values=metric)
@@ -382,7 +386,7 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
382386
# Single legend for the whole figure (top center)
383387
handles, labels = axes[0].get_legend_handles_labels()
384388
fig.legend(handles, labels, loc="upper center", ncol=len(model_names), bbox_to_anchor=(0.5, 1.02))
385-
plt.tight_layout(rect=[0, 0, 1, 0.95]) # make room for legend on top
389+
plt.tight_layout(rect=(0, 0, 1, 0.95)) # make room for legend on top
386390
return fig
387391

388392

@@ -443,10 +447,8 @@ def plot_dumbbell_by_project(
443447
metrics = [c.replace(f"{model1}_", "") for c in project_metrics_df.columns if c.startswith(f"{model1}_")]
444448

445449
n_metrics = len(metrics)
446-
fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, max(8, len(project_metrics_df) * 0.15)))
447-
if n_metrics == 1:
448-
axes = [axes]
449-
axes: list[plt.Axes] = list(axes)
450+
fig, axes_arr = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, max(8, len(project_metrics_df) * 0.15)))
451+
axes: list[plt.Axes] = [axes_arr] if n_metrics == 1 else list(axes_arr)
450452

451453
# Sort once by pred_GROUP_rate delta, use same order for all subplots
452454
group_rate_col1 = f"{model1}_pred_GROUP_rate"
@@ -483,7 +485,7 @@ def plot_dumbbell_by_project(
483485
handles, labels = axes[0].get_legend_handles_labels()
484486
fig.legend(handles, labels, loc="upper center", ncol=len(model_names), bbox_to_anchor=(0.5, 1.02))
485487
fig.suptitle("Metrics by Project (org_id|project_id)", fontsize=14, y=1.05)
486-
plt.tight_layout(rect=[0, 0, 1, 0.98])
488+
plt.tight_layout(rect=(0, 0, 1, 0.98))
487489
return fig
488490

489491

@@ -557,11 +559,13 @@ def compare_models(
557559
# Compute conditional probabilities (reported later)
558560
prod_group = df.filter(pl.col(pred1_col) == "GROUP")
559561
prod_separate = df.filter(pl.col(pred1_col) == "SEPARATE")
560-
p_group_given_group = (prod_group[pred2_col] == "GROUP").mean() if len(prod_group) > 0 else float("nan")
561-
p_group_given_separate = (prod_separate[pred2_col] == "GROUP").mean() if len(prod_separate) > 0 else float("nan")
562+
p_group_given_group = float((prod_group[pred2_col] == "GROUP").mean()) if len(prod_group) > 0 else float("nan") # type: ignore[arg-type]
563+
p_group_given_separate = (
564+
float((prod_separate[pred2_col] == "GROUP").mean()) if len(prod_separate) > 0 else float("nan") # type: ignore[arg-type]
565+
)
562566
df_close = df.filter(pl.col("distance") < 0.005)
563567
close_group = df_close.filter(pl.col(pred1_col) == "GROUP")
564-
p_close = (close_group[pred2_col] == "GROUP").mean() if len(close_group) > 0 else float("nan")
568+
p_close = float((close_group[pred2_col] == "GROUP").mean()) if len(close_group) > 0 else float("nan") # type: ignore[arg-type]
565569

566570
# Columns to keep in output
567571
output_cols = [
@@ -589,6 +593,7 @@ def compare_models(
589593
df_sorted = df.sort(["org_id", "project_id"])
590594
for (org_id, project_id), group_df in df_sorted.group_by(["org_id", "project_id"], maintain_order=True):
591595
total_projects += 1
596+
assert output_dir is not None
592597
proj_dir = output_dir / f"org_{org_id}" / f"project_{project_id}"
593598

594599
# Compute metrics for each model on this project
@@ -675,7 +680,7 @@ def compare_models(
675680

676681
report("\n### Distance distribution\n")
677682
report(df["distance"].describe())
678-
report(f"\nGROUP rate: {(df['label'] == 'GROUP').mean():.2%}")
683+
report(f"\nGROUP rate: {float((df['label'] == 'GROUP').mean()):.2%}") # type: ignore[arg-type]
679684

680685
platform_stats = (
681686
df.group_by("platform")
@@ -754,7 +759,7 @@ def compute_stacktrace_token_percentiles(df: pl.DataFrame) -> pl.DataFrame:
754759

755760
rows = []
756761
for col in token_cols:
757-
row = {"metric": col}
762+
row: dict[str, Any] = {"metric": col}
758763
row["min"] = df[col].min()
759764
row["mean"] = df[col].mean()
760765
for p in percentiles:
@@ -896,6 +901,7 @@ def _compute_project_precisions_per_platform(model: str, thresholds_platform: di
896901
)
897902
else:
898903
baseline_key = f"{baseline_model}@{baseline_threshold}"
904+
assert isinstance(baseline_threshold, float)
899905
project_precisions[baseline_key] = _compute_project_precisions(baseline_model, baseline_threshold)
900906
else:
901907
baseline_key = str(thresholds_sorted[0])
@@ -916,7 +922,7 @@ def _compute_project_precisions_per_platform(model: str, thresholds_platform: di
916922
{
917923
"platform": platform,
918924
"n_projects": len(prec),
919-
"median_pairs": int(platform_df["n_pairs"].median()),
925+
"median_pairs": int(platform_df["n_pairs"].median()), # type: ignore[arg-type]
920926
"mean": prec.mean(),
921927
"p5": prec.quantile(0.05),
922928
"p10": prec.quantile(0.10),
@@ -971,7 +977,8 @@ def metrics_by_platform(
971977
)
972978

973979
rows = []
974-
for (platform,), platform_df in df_t.group_by("platform"):
980+
for (platform_obj,), platform_df in df_t.group_by("platform"):
981+
platform = str(platform_obj)
975982
avg_metrics = _compute_metrics_avg_over_projects(platform_df, model_name)
976983
platform_threshold = threshold.get(platform, threshold["default"]) if isinstance(threshold, dict) else threshold
977984
rows.append(
@@ -1025,12 +1032,15 @@ def find_threshold_by_platform(
10251032
precision_by_platform = min_precision if isinstance(min_precision, dict) else None
10261033

10271034
rows = []
1028-
for (platform,), platform_df in df.group_by("platform"):
1035+
for (platform_obj,), platform_df in df.group_by("platform"):
1036+
platform = str(platform_obj)
10291037
n_pairs = len(platform_df)
10301038
n_projects = platform_df["project_id"].n_unique()
10311039
label_group_rate = (platform_df["label"] == "GROUP").mean()
10321040
threshold_found = None
1033-
target_precision = precision_by_platform[platform] if precision_by_platform else min_precision
1041+
target_precision: float = (
1042+
precision_by_platform[platform] if precision_by_platform else min_precision # type: ignore[assignment]
1043+
)
10341044

10351045
# Walk thresholds from low to high; first one meeting precision is the minimum
10361046
# Precision is averaged over projects to avoid large projects dominating
@@ -1039,11 +1049,11 @@ def find_threshold_by_platform(
10391049
pl.when(pl.col(sim_col) > thresh).then(pl.lit("GROUP")).otherwise(pl.lit("SEPARATE")).alias(pred_col)
10401050
)
10411051
# Compute per-project precision, then average
1042-
project_precisions = []
1052+
project_precisions: list[float] = []
10431053
for _, proj_df in df_t.group_by("project_id"):
10441054
pred_group = proj_df.filter(pl.col(pred_col) == "GROUP")
10451055
if len(pred_group) > 0:
1046-
project_precisions.append((pred_group["label"] == "GROUP").mean())
1056+
project_precisions.append(float((pred_group["label"] == "GROUP").mean())) # type: ignore[arg-type]
10471057
if not project_precisions:
10481058
continue
10491059
precision = sum(project_precisions) / len(project_precisions)
@@ -1115,11 +1125,11 @@ def compare_metrics_by_stacktrace_length(
11151125

11161126
# Print metrics for each bucket
11171127
report(f"\n### Short stacktraces ({token_col} <= p10 = {p10:.0f} tokens, {len(short_df)} pairs)\n")
1118-
report(f"label GROUP rate: {(short_df['label'] == 'GROUP').mean():.2%}")
1128+
report(f"label GROUP rate: {float((short_df['label'] == 'GROUP').mean()):.2%}") # type: ignore[arg-type]
11191129
report(_compute_metrics(short_df, model_names))
11201130

11211131
report(f"\n### Long stacktraces ({token_col} >= p90 = {p90:.0f} tokens, {len(long_df)} pairs)\n")
1122-
report(f"label GROUP rate: {(long_df['label'] == 'GROUP').mean():.2%}")
1132+
report(f"label GROUP rate: {float((long_df['label'] == 'GROUP').mean()):.2%}") # type: ignore[arg-type]
11231133
report(_compute_metrics(long_df, model_names))
11241134

11251135

eval/eval_poller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def evaluate_baseline(
154154
return
155155

156156
logger.info("Evaluating base model.")
157+
loss: gt.loss.PairwiseLoss
157158
if loss_type == "sigmoid":
158159
loss = gt.loss.SigmoidPairwiseLoss()
159160
elif loss_type == "contrastive":

eval/save_embeddings.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,36 @@ def main(
178178
text_prefix=text_prefix,
179179
)
180180
logger.info(f"{st_class.__name__} loaded in {time.monotonic() - start:.1f}s")
181-
if use_compiled:
181+
if isinstance(model, gt.compiled.SentenceTransformer):
182182
model.compile_and_warm_up()
183183
else:
184184
_ = model.encode("warm up")
185185
logger.info(f"{st_class.__name__} loading and warming up done in {time.monotonic() - start:.1f}s")
186186

187187
logger.info("Encoding queries")
188188
texts_query = df["query_stacktrace_string"].to_list()
189-
embeddings_query: np.ndarray = model.encode(
189+
embeddings_query = model.encode(
190190
texts_query, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True
191191
)
192192
logger.info("Encoding candidates")
193193
texts_candidate = df["candidate_stacktrace_string"].to_list()
194-
embeddings_candidate: np.ndarray = model.encode(
194+
embeddings_candidate = model.encode(
195195
texts_candidate, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=True
196196
)
197197

198198
if truncate_dims is None:
199199
truncate_dims = (embeddings_query.shape[-1],)
200200

201201
for dim in truncate_dims:
202-
cos_sims = pairwise_cos_sim(embeddings_query[..., :dim], embeddings_candidate[..., :dim]).detach().cpu().numpy()
202+
cos_sims = (
203+
pairwise_cos_sim(
204+
torch.as_tensor(embeddings_query[..., :dim]),
205+
torch.as_tensor(embeddings_candidate[..., :dim]),
206+
)
207+
.detach()
208+
.cpu()
209+
.numpy()
210+
)
203211
df = df.with_columns(pl.Series(name=f"cos_sim_{dim}", values=cos_sims))
204212

205213
with tempfile.TemporaryDirectory() as dir_tmp_output:

eval/save_gemini_embeddings.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
import polars as pl
23+
import torch
2324
from google import genai
2425
from google.genai import types
2526
from sentence_transformers.util import pairwise_cos_sim
@@ -148,7 +149,15 @@ def main(
148149
embeddings_candidate = all_embeddings[n_q:]
149150

150151
for dim in truncate_dims:
151-
cos_sims = pairwise_cos_sim(embeddings_query[..., :dim], embeddings_candidate[..., :dim]).detach().cpu().numpy()
152+
cos_sims = (
153+
pairwise_cos_sim(
154+
torch.as_tensor(embeddings_query[..., :dim]),
155+
torch.as_tensor(embeddings_candidate[..., :dim]),
156+
)
157+
.detach()
158+
.cpu()
159+
.numpy()
160+
)
152161
df = df.with_columns(pl.Series(name=f"cos_sim_{dim}", values=cos_sims))
153162

154163
with tempfile.TemporaryDirectory() as dir_tmp_output:

0 commit comments

Comments
 (0)