Skip to content

Commit ef002ea

Browse files
committed
temp test uv
1 parent afea89c commit ef002ea

19 files changed

Lines changed: 3733 additions & 83 deletions

.github/workflows/ci.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
8+
concurrency:
9+
group: ${{ github.workflow }}-${{ github.ref }}
10+
cancel-in-progress: true
11+
12+
jobs:
13+
lint:
14+
runs-on: ubuntu-latest
15+
steps:
16+
- uses: actions/checkout@v4
17+
- uses: astral-sh/setup-uv@v5
18+
with:
19+
enable-cache: true
20+
- run: uv tool run ruff check .
21+
- run: uv tool run ruff format --check .
22+
23+
test:
24+
runs-on: ubuntu-latest
25+
steps:
26+
- uses: actions/checkout@v4
27+
- uses: astral-sh/setup-uv@v5
28+
with:
29+
enable-cache: true
30+
python-version: "3.13"
31+
- run: uv sync --extra dev
32+
- uses: actions/cache@v4
33+
with:
34+
path: ~/.cache/huggingface
35+
key: hf-${{ runner.os }}-gte-modernbert-base
36+
- run: uv run pytest -q

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.13

bin/_startup.sh

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#!/bin/bash
22
# Sets up the environment for grouping-trainer instances.
3+
# Normally invoked as root via GCP instance startup. To step through manually
4+
# after SSH'ing in, run `sudo -i` first so $HOME=/root and paths line up.
35
set -euo pipefail
46

5-
apt-get update -y && apt-get install -y python3.12-venv
7+
# Install uv (manages its own Python; respects .python-version in the repo).
8+
curl -LsSf https://astral.sh/uv/install.sh | sh
9+
export PATH="$HOME/.local/bin:$PATH"
610

711
REPO_DIR="/root/grouping-trainer"
812

@@ -18,16 +22,15 @@ cd "$REPO_DIR"
1822
mkdir -p lightonai/modernbert-embed-large
1923
gcloud storage cp -r gs://grouping-data/base_models/lightonai/modernbert-embed-large/* lightonai/modernbert-embed-large
2024

21-
python3.12 -m venv .venv
22-
# shellcheck disable=SC1091
23-
source .venv/bin/activate
24-
pip install --upgrade pip
25-
pip install -e .
25+
uv sync --locked
2626

2727
gcloud storage cp -r gs://grouping-data/final_csvs/ .
2828

29-
# Auto-cd into the repo and activate the venv on `sudo -i`.
30-
echo "cd $REPO_DIR && source .venv/bin/activate" >> /root/.bashrc
29+
# Auto-cd into the repo, put uv on PATH, and activate the venv on `sudo -i`.
30+
{
31+
echo "export PATH=\"\$HOME/.local/bin:\$PATH\""
32+
echo "cd $REPO_DIR && source .venv/bin/activate"
33+
} >> /root/.bashrc
3134

3235
# screen -S run
3336
# ctrl+a d
@@ -46,4 +49,5 @@ if [ -n "$COMMAND" ]; then
4649
eval "$COMMAND" >>"$LOG_FILE" 2>&1 || true
4750
shutdown -h now
4851
fi
49-
# To follow the log: `sudo tail -f /var/log/grouping_trainer_run.log`
52+
# To follow the log:
53+
# sudo tail -f /var/log/grouping_trainer_run.log

bin/set_up_local.sh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,5 @@
22
set -eu
33

44
direnv allow
5-
python3.13 -m venv .venv
6-
# shellcheck source=/dev/null
7-
. .venv/bin/activate
8-
python -m pip install -e ".[dev,sheets]"
9-
pre-commit install
5+
uv sync --extra dev --extra sheets
6+
uv run pre-commit install

eval/acc_across_dims.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@
293293
"name": "python",
294294
"nbconvert_exporter": "python",
295295
"pygments_lexer": "ipython3",
296-
"version": "3.13.12"
296+
"version": "3.13.1"
297297
}
298298
},
299299
"nbformat": 4,

eval/compare.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def plot_metrics_by_platform(df: pl.DataFrame, model_names: list[str]) -> plt.Fi
349349
fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(4 * len(metrics_to_plot), 5))
350350
axes: list[plt.Axes] = list(axes)
351351

352-
for ax, metric in zip(axes, metrics_to_plot):
352+
for ax, metric in zip(axes, metrics_to_plot, strict=True):
353353
pivot_df = metrics_pd.pivot(index="platform", columns="model", values=metric)
354354
pivot_df = pivot_df[model_names] # ensure consistent column order
355355
pivot_df.plot(kind="bar", ax=ax, rot=45, legend=False, color=MODEL_COLORS)
@@ -388,7 +388,7 @@ def plot_similarity_distribution(
388388
if n == 1:
389389
axes = [axes]
390390

391-
for ax, platform in zip(axes, platforms):
391+
for ax, platform in zip(axes, platforms, strict=True):
392392
data = df.filter(pl.col("platform") == platform)[sim_col].to_numpy()
393393
ax.hist(data, bins=bins, edgecolor="none", alpha=0.8)
394394
ax.set_ylabel(platform, rotation=0, labelpad=60, ha="right")
@@ -434,7 +434,7 @@ def plot_dumbbell_by_project(
434434
).sort("_delta")
435435
y_labels = [f"{row['org_id']}|{row['project_id']}" for row in sorted_df.iter_rows(named=True)]
436436

437-
for ax, metric in zip(axes, metrics):
437+
for ax, metric in zip(axes, metrics, strict=True):
438438
col1 = f"{model1}_{metric}"
439439
col2 = f"{model2}_{metric}"
440440

@@ -443,7 +443,7 @@ def plot_dumbbell_by_project(
443443
y = range(len(sorted_df))
444444

445445
# Draw lines colored by direction
446-
for i, (v1, v2) in enumerate(zip(x1, x2)):
446+
for i, (v1, v2) in enumerate(zip(x1, x2, strict=True)):
447447
color = "green" if v2 >= v1 else "red"
448448
ax.hlines(y=i, xmin=min(v1, v2), xmax=max(v1, v2), color=color, alpha=0.6)
449449

@@ -485,7 +485,7 @@ def compare_models(
485485
used for platforms not explicitly listed.
486486
First key = model1 (baseline), second key = model2 (new model).
487487
output_dir: Directory for writing CSVs. Required if write_csvs is True.
488-
min_group_rate_increase: Track projects where model2 GROUP rate is >= this value higher than model1. None to skip.
488+
min_group_rate_increase: Track projects where model2 GROUP rate is >= this value higher than model1. None skips.
489489
min_group_rate_decrease: Track projects where model2 GROUP rate is >= this value lower than model1 (absolute).
490490
E.g., 0.10 means model2 has at least 10pp lower GROUP rate. None to skip.
491491
write_csvs: If True, write new.csv and merged.csv files for each project.
@@ -753,7 +753,7 @@ def compute_stacktrace_token_percentiles(df: pl.DataFrame) -> pl.DataFrame:
753753
def sweep_thresholds(
754754
df: pl.DataFrame,
755755
model_name: str,
756-
thresholds: list[float] = [0.80, 0.85, 0.87, 0.90],
756+
thresholds: list[float] | None = None,
757757
) -> pl.DataFrame:
758758
"""
759759
Show metrics for a single model at multiple similarity thresholds.
@@ -766,6 +766,8 @@ def sweep_thresholds(
766766
Returns:
767767
DataFrame with one row per threshold and metric columns.
768768
"""
769+
if thresholds is None:
770+
thresholds = [0.80, 0.85, 0.87, 0.90]
769771
sim_col = f"cos_sim_{model_name}"
770772
rows = []
771773
for thresh in thresholds:
@@ -788,7 +790,7 @@ def sweep_thresholds(
788790
def sweep_thresholds_by_project(
789791
df: pl.DataFrame,
790792
model_name: str,
791-
thresholds: list[float] = [0.80, 0.85, 0.87, 0.90],
793+
thresholds: list[float] | None = None,
792794
precision_floor: float = 0.8,
793795
harm_threshold: float = 0.05,
794796
thresholds_platform: dict[str, float] | None = None,
@@ -811,8 +813,8 @@ def sweep_thresholds_by_project(
811813
baseline_threshold: Threshold for the baseline model. Can be a float or a
812814
per-platform dict (with a "default" key), same format as thresholds_platform.
813815
"""
814-
sim_col = f"cos_sim_{model_name}"
815-
pred_col = f"pred_{model_name}"
816+
if thresholds is None:
817+
thresholds = [0.80, 0.85, 0.87, 0.90]
816818
thresholds_sorted = sorted(thresholds, reverse=True)
817819

818820
def _compute_project_precisions(model: str, threshold: float) -> pl.DataFrame:

eval/export_for_db.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def export_for_load_test(
220220
export_for_load_test(sys.argv[1], **kwargs)
221221
else:
222222
print(
223-
f"Usage: python {sys.argv[0]} <similarities_dir> [--load-test [keep_fraction_candidates=X] [keep_fraction_queries=X]]"
223+
f"Usage: python {sys.argv[0]} <similarities_dir> "
224+
f"[--load-test [keep_fraction_candidates=X] [keep_fraction_queries=X]]"
224225
)
225226
sys.exit(1)

profile_dataloading.ipynb

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
},
1818
{
1919
"cell_type": "code",
20-
"execution_count": 1,
20+
"execution_count": null,
2121
"id": "34b40b8b",
2222
"metadata": {},
2323
"outputs": [],
2424
"source": [
2525
"import time\n",
2626
"from collections import Counter\n",
27+
"from functools import wraps\n",
2728
"\n",
2829
"import numpy as np\n",
2930
"import seaborn as sns\n",
@@ -46,13 +47,14 @@
4647
},
4748
{
4849
"cell_type": "code",
49-
"execution_count": 3,
50+
"execution_count": null,
5051
"id": "e8b10dc8",
5152
"metadata": {},
5253
"outputs": [],
5354
"source": [
5455
"def record_times(times: list[float]):\n",
5556
" def decorator(func):\n",
57+
" @wraps(func)\n",
5658
" def wrapper(*args, **kwargs):\n",
5759
" start_time = time.monotonic()\n",
5860
" result = func(*args, **kwargs)\n",
@@ -190,46 +192,25 @@
190192
},
191193
{
192194
"cell_type": "code",
193-
"execution_count": 11,
195+
"execution_count": null,
194196
"id": "410a876f",
195197
"metadata": {},
196-
"outputs": [
197-
{
198-
"data": {
199-
"application/vnd.jupyter.widget-view+json": {
200-
"model_id": "42310f0811c14a47afde208aa9d7a941",
201-
"version_major": 2,
202-
"version_minor": 0
203-
},
204-
"text/plain": [
205-
" 0%| | 0/13130 [00:00<?, ?it/s]"
206-
]
207-
},
208-
"metadata": {},
209-
"output_type": "display_data"
210-
},
211-
{
212-
"name": "stdout",
213-
"output_type": "stream",
214-
"text": [
215-
"CPU times: user 37min 10s, sys: 1min 53s, total: 39min 3s\n",
216-
"Wall time: 7min 4s\n"
217-
]
218-
}
219-
],
198+
"outputs": [],
220199
"source": [
221200
"%%time\n",
222201
"train_dataloader = trainer.get_train_dataloader()\n",
223202
"for batch in tqdm(train_dataloader, total=len(train_dataloader)):\n",
224-
" for sub_batch_idx, sub_batch in enumerate(\n",
225-
" gt.train.batch_pairs_by_token_budget(batch, token_budget=training_config_full.per_device_token_budget)\n",
203+
" num_sub_batches = 0\n",
204+
" for sub_batch in gt.train.batch_pairs_by_token_budget(\n",
205+
" batch, token_budget=training_config_full.per_device_token_budget\n",
226206
" ):\n",
227207
" encodings = preprocess(\n",
228208
" trainer.model.encoder,\n",
229209
" sub_batch[\"query_stacktrace_string\"],\n",
230210
" sub_batch[\"candidate_stacktrace_string\"],\n",
231211
" )\n",
232-
" num_sub_batches_per_batch.append(sub_batch_idx + 1)"
212+
" num_sub_batches += 1\n",
213+
" num_sub_batches_per_batch.append(num_sub_batches)"
233214
]
234215
},
235216
{
@@ -370,7 +351,7 @@
370351
"name": "python",
371352
"nbconvert_exporter": "python",
372353
"pygments_lexer": "ipython3",
373-
"version": "3.13.12"
354+
"version": "3.13.1"
374355
}
375356
},
376357
"nbformat": 4,

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"typed-argument-parser==1.11.0",
2525
"wandb==0.23.1",
2626
]
27-
requires-python = ">=3.12"
27+
requires-python = ">=3.13"
2828
authors = [
2929
{ name = "Kush Dubey", email = "kushdubey63@gmail.com" },
3030
]
@@ -61,7 +61,7 @@ indent-width = 4
6161
extend-include = ["*.ipynb"]
6262

6363
[tool.ruff.lint]
64-
select = ["I"]
64+
select = ["I", "F", "E", "B", "UP"]
6565

6666
[project.urls]
6767
Homepage = "https://github.com/getsentry/grouping-trainer"

src/grouping_trainer/compiled.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def compile_and_warm_up(self):
178178

179179
@_set_float32_matmul_precision(_COMPILED_MATMUL_PRECISION)
180180
def forward(self, input: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
181-
# Only use the compiled forward if the sequence length matches one of our buckets. If we used the compiled forward
182-
# for one that doesn't hit the bucket, we create a new CUDA graph for every unique sequence length above
181+
# Only use the compiled forward if the sequence length matches one of our buckets. If we used the compiled
182+
# forward for one that doesn't hit the bucket, we create a new CUDA graph for every unique sequence length above
183183
# 2048, which thrashes the cache.
184184

185185
if self.training:

0 commit comments

Comments
 (0)