Commit 6dbc384

bm25
1 parent 417057d commit 6dbc384

11 files changed

Lines changed: 520 additions & 5 deletions

.envrc

Lines changed: 11 additions & 0 deletions
@@ -10,6 +10,17 @@ dotenv
 
 gcloud config set project "$GOOGLE_CLOUD_PROJECT"
 
+if ! gcloud auth print-access-token > /dev/null 2>&1; then
+  printf "gcloud credentials need refresh. Run 'gcloud auth login' now? [Y/n] " > /dev/tty
+  read -r answer < /dev/tty
+  if [ "$answer" != "n" ] && [ "$answer" != "N" ]; then
+    gcloud auth login < /dev/tty
+  else
+    echo "Aborting .envrc — gcloud auth required" > /dev/tty
+    exit 1
+  fi
+fi
+
 WANDB_API_KEY=$(gcloud secrets versions access latest --secret=wandb-api-key)
 export WANDB_API_KEY
 : "${WANDB_API_KEY:?upload to GCP Secret Manager as 'wandb-api-key' (wandb is free) — the remote startup script fetches from there too}"

decisions.md

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ way to select from the combinatorial explosion of triplets without dropping some
 relationships.
 
 In general, non-pairwise losses coerce a jagged but rich similarity structure into a rectangular one. The data already
-intentionally contains hard positives and negatives. Pairwise losses put their faith in the data and accomodate the
+intentionally contains hard positives and hard negatives. Pairwise losses put their faith in the data and accomodate the
 jagged structure by melting it into a rectangular one.
 
 Pairwise losses do have downsides. The statistical one is that we don't have many negatives per positive. To softly

eval/compare.py

Lines changed: 27 additions & 3 deletions
@@ -1132,6 +1132,16 @@ def compare_metrics_by_stacktrace_length(
 COLUMNS_ANONYMIZED_DENYLIST = ("path",)
 
 
+def _parse_threshold_list(value: str | None) -> list[float] | None:
+    """Parse a comma-separated list of floats, e.g. "10,15,20,25" -> [10.0, 15.0, 20.0, 25.0].
+
+    None / empty string returns None (caller falls back to the function's default).
+    """
+    if not value:
+        return None
+    return [float(part.strip()) for part in value.split(",") if part.strip()]
+
+
 def _parse_threshold(value: str) -> float | dict[str, float]:
     """Parse a threshold CLI argument.
 
@@ -1306,6 +1316,8 @@ def _main(
     dim_model2: int = 768,
     threshold_model1: str = "0.99",
     threshold_model2: str = "0.90",
+    sweep_thresholds_model1: str | None = None,
+    sweep_thresholds_model2: str | None = None,
     min_group_rate_increase: float = 0.15,
     min_group_rate_decrease: float = 0.10,
     max_display_projects: int = 30,
@@ -1385,13 +1397,18 @@ def _main(
 
     report("\n## Threshold sweep\n")
 
-    # Threshold sweep for model2
-    sweep_thresholds(df, name_model2)
+    sweep_list_model1 = _parse_threshold_list(sweep_thresholds_model1)
+    sweep_list_model2 = _parse_threshold_list(sweep_thresholds_model2)
+    sweep_lists_by_name = {name_model1: sweep_list_model1, name_model2: sweep_list_model2}
+
+    sweep_thresholds(df, name_model1, thresholds=sweep_list_model1)
+    sweep_thresholds(df, name_model2, thresholds=sweep_list_model2)
     threshold1_parsed = thresholds[name_model1]
     threshold2_parsed = thresholds[name_model2]
     sweep_thresholds_by_project(
         df,
         name_model2,
+        thresholds=sweep_list_model2,
         thresholds_platform=threshold2_parsed if isinstance(threshold2_parsed, dict) else None,
         baseline_model=name_model1,
         baseline_threshold=threshold1_parsed,
@@ -1405,7 +1422,7 @@ def _main(
 
     # Find minimum threshold per platform for each model
     for name in [name_model1, name_model2]:
-        find_threshold_by_platform(df, name)
+        find_threshold_by_platform(df, name, thresholds=sweep_lists_by_name[name])
 
     fig = plot_metrics_by_platform(result.df, result.model_names)
     fig.savefig(dir_output / "metrics_by_platform.png", dpi=150, bbox_inches="tight")
@@ -1509,6 +1526,8 @@ def main(
     dim_model2: int = 768,
     threshold_model1: str = "0.99",
     threshold_model2: str = "0.90",
+    sweep_thresholds_model1: str | None = None,
+    sweep_thresholds_model2: str | None = None,
     min_group_rate_increase: float = 0.15,
     min_group_rate_decrease: float = 0.10,
     max_display_projects: int = 30,
@@ -1542,6 +1561,11 @@ def main(
         or comma-separated platform=value pairs (e.g. "default=0.92,cocoa=0.80,node=0.90").
     threshold_model2
         Cosine similarity threshold for model 2. Same format as threshold_model1.
+    sweep_thresholds_model1
+        Comma-separated thresholds to sweep for model 1's per-threshold metrics and per-platform threshold finder,
+        e.g. "0.95,0.97,0.99". Override the cosine-range defaults when model 1's score isn't in [0, 1].
+    sweep_thresholds_model2
+        Comma-separated thresholds to sweep for model 2. Same format as sweep_thresholds_model1.
     min_group_rate_increase
         Flag projects where model2 GROUP rate exceeds model1 by at least this amount.
     min_group_rate_decrease
