Skip to content

Commit 7df9a94

Browse files
authored
Merge pull request #59 from ForgeOpus/claude/modelforge-notebook-cli-Zp0ef
Claude/modelforge notebook cli zp0ef
2 parents 312cabc + 16180e8 commit 7df9a94

45 files changed

Lines changed: 1344 additions & 1710 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

ModelForge/cli.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,31 @@ def check_huggingface_login():
5050
def main():
5151
"""
5252
Main entry point for ModelForge CLI.
53+
54+
Subcommands:
55+
modelforge — start the web server
56+
modelforge cli — launch the interactive CLI wizard
5357
"""
58+
if len(sys.argv) > 1:
59+
subcommand = sys.argv[1]
60+
if subcommand == "cli":
61+
try:
62+
from .notebook_cli.wizard import main as cli_main
63+
except ImportError as e:
64+
logger.error(f"Failed to import CLI wizard or its dependencies: {e}")
65+
print("\nThe interactive CLI requires optional dependencies that are not installed.")
66+
print("Please install the CLI extras and try again, for example:")
67+
print(" pip install \"modelforge-finetuning[cli]\"")
68+
sys.exit(1)
69+
cli_main()
70+
return
71+
else:
72+
print(f"Unknown subcommand: '{subcommand}'")
73+
print("Usage:")
74+
print(" modelforge Start the web server")
75+
print(" modelforge cli Launch the interactive CLI wizard")
76+
sys.exit(1)
77+
5478
print("\n" + "=" * 80)
5579
print(" __ __ _ _ _____ ")
5680
print(" | \\/ | | | | | ___| ")

ModelForge/cli_old.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

ModelForge/evaluation/dataset_validator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,10 @@ def validate_dataset(
8787
if len(required_fields) > 0:
8888
for i, example in enumerate(dataset.select(range(min(10, len(dataset))))):
8989
for field in required_fields:
90-
if not example.get(field):
91-
logger.warning(
92-
f"Example {i} has empty field '{field}'"
90+
value = example.get(field)
91+
if value is None or (isinstance(value, str) and not value.strip()):
92+
raise DatasetValidationError(
93+
f"Example {i} has empty required field '{field}'."
9394
)
9495

9596
logger.info(

ModelForge/evaluation/metrics.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,27 @@ def compute_causal_lm_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
2929
"""
3030
logger.info("Computing causal LM metrics")
3131

32+
import torch
33+
import torch.nn.functional as F
34+
3235
predictions, labels = eval_pred.predictions, eval_pred.label_ids
3336

34-
# For language modeling, predictions are logits
35-
# Loss is computed by the trainer, we just calculate perplexity
36-
if hasattr(eval_pred, 'loss') and eval_pred.loss is not None:
37-
loss = eval_pred.loss
38-
else:
39-
# Fallback: estimate loss from predictions
40-
loss = 0.0
37+
logits = torch.from_numpy(np.array(predictions)).float()
38+
lbl = torch.from_numpy(np.array(labels))
39+
40+
shift_logits = logits[..., :-1, :].contiguous()
41+
shift_labels = lbl[..., 1:].contiguous()
4142

42-
perplexity = np.exp(loss) if loss > 0 else 0.0
43+
loss = F.cross_entropy(
44+
shift_logits.view(-1, shift_logits.size(-1)),
45+
shift_labels.view(-1).long(),
46+
ignore_index=-100,
47+
reduction="mean",
48+
)
49+
perplexity = float(torch.exp(loss))
4350

4451
metrics = {
45-
"perplexity": float(perplexity),
52+
"perplexity": perplexity,
4653
"eval_loss": float(loss),
4754
}
4855

@@ -69,6 +76,9 @@ def compute_seq2seq_metrics(eval_pred: EvalPrediction, tokenizer: Any = None) ->
6976

7077
predictions, labels = eval_pred.predictions, eval_pred.label_ids
7178

79+
if predictions.ndim == 3:
80+
predictions = np.argmax(predictions, axis=-1)
81+
7282
# Decode predictions and labels
7383
if tokenizer is not None:
7484
# Replace -100 in labels (used for padding)
@@ -88,9 +98,9 @@ def compute_seq2seq_metrics(eval_pred: EvalPrediction, tokenizer: Any = None) ->
8898
)
8999

90100
metrics = {
91-
"rouge1": float(result["rouge1"].mid.fmeasure),
92-
"rouge2": float(result["rouge2"].mid.fmeasure),
93-
"rougeL": float(result["rougeL"].mid.fmeasure),
101+
"rouge1": float(result["rouge1"]),
102+
"rouge2": float(result["rouge2"]),
103+
"rougeL": float(result["rougeL"]),
94104
}
95105

96106
logger.info(f"Seq2Seq metrics: {metrics}")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
ModelForge Notebook CLI — public API.
3+
4+
Provides an interactive wizard-style interface that calls the ModelForge
5+
Python APIs directly, without going through the REST/web layer.
6+
7+
Install the required extras with:
8+
pip install modelforge-finetuning[cli]
9+
10+
Usage from a notebook cell:
11+
from ModelForge.notebook_cli import run_cli
12+
run_cli()
13+
14+
Usage from a terminal:
15+
modelforge-nb
16+
"""
17+
from .wizard import ModelForgeWizard, main as run_cli
18+
19+
__all__ = ["ModelForgeWizard", "run_cli"]
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Notebook-aware progress display for ModelForge CLI.
3+
4+
Automatically uses ipywidgets when running inside a Jupyter/Colab/VSCode
5+
notebook environment, and falls back to tqdm for plain terminal use.
6+
"""
7+
from __future__ import annotations
8+
9+
10+
def _is_notebook() -> bool:
11+
"""Return True if running inside a Jupyter-compatible notebook kernel."""
12+
try:
13+
from IPython import get_ipython
14+
shell = get_ipython()
15+
if shell is None:
16+
return False
17+
return shell.__class__.__name__ == "ZMQInteractiveShell"
18+
except ImportError:
19+
return False
20+
21+
22+
class ProgressDisplay:
23+
"""
24+
Unified progress display that works in notebooks and terminals.
25+
26+
Usage::
27+
28+
display = ProgressDisplay()
29+
display.start("Loading model...")
30+
display.update(50, "Halfway there")
31+
display.update(100, "Done!")
32+
display.close()
33+
"""
34+
35+
def __init__(self):
36+
self._notebook = _is_notebook()
37+
self._bar = None
38+
self._label = None
39+
self._tqdm_bar = None
40+
self._last_value = 0
41+
42+
def start(self, message: str = "Starting...") -> None:
43+
"""Initialise and display the progress bar."""
44+
self._last_value = 0
45+
if self._notebook:
46+
self._start_widget(message)
47+
else:
48+
self._start_tqdm(message)
49+
50+
def _start_widget(self, message: str) -> None:
51+
try:
52+
import ipywidgets as widgets
53+
from IPython.display import display
54+
55+
self._bar = widgets.IntProgress(
56+
value=0,
57+
min=0,
58+
max=100,
59+
description="",
60+
bar_style="info",
61+
layout=widgets.Layout(width="100%"),
62+
)
63+
self._label = widgets.Label(value=message)
64+
box = widgets.VBox([self._label, self._bar])
65+
display(box)
66+
except ImportError:
67+
# ipywidgets not available — fall through to tqdm
68+
self._notebook = False
69+
self._start_tqdm(message)
70+
71+
def _start_tqdm(self, message: str) -> None:
72+
from tqdm import tqdm
73+
74+
self._tqdm_bar = tqdm(
75+
total=100,
76+
desc=message,
77+
bar_format="{desc}: {percentage:3.0f}%|{bar}| [{elapsed}]",
78+
dynamic_ncols=True,
79+
)
80+
81+
def update(self, value: int, message: str = "") -> None:
82+
"""
83+
Update progress.
84+
85+
Args:
86+
value: Progress value 0-100.
87+
message: Status message to display alongside the bar.
88+
"""
89+
value = max(0, min(100, value))
90+
if self._notebook:
91+
if self._bar is not None:
92+
self._bar.value = value
93+
if value >= 100:
94+
self._bar.bar_style = "success"
95+
if self._label is not None and message:
96+
self._label.value = message
97+
else:
98+
if self._tqdm_bar is not None:
99+
delta = value - self._last_value
100+
if delta > 0:
101+
self._tqdm_bar.update(delta)
102+
if message:
103+
self._tqdm_bar.set_description(message)
104+
self._last_value = value
105+
106+
def close(self) -> None:
107+
"""Close/finalize the progress display."""
108+
if self._notebook:
109+
if self._bar is not None:
110+
self._bar.bar_style = "success"
111+
else:
112+
if self._tqdm_bar is not None:
113+
self._tqdm_bar.close()
114+
self._tqdm_bar = None
115+
116+
def error(self) -> None:
117+
"""Mark the progress bar as errored."""
118+
if self._notebook:
119+
if self._bar is not None:
120+
self._bar.bar_style = "danger"
121+
if self._label is not None:
122+
self._label.value = "Training failed."
123+
else:
124+
if self._tqdm_bar is not None:
125+
self._tqdm_bar.set_description("Training failed")
126+
self._tqdm_bar.close()
127+
self._tqdm_bar = None

0 commit comments

Comments
 (0)