Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions runtime/onert/api/python/package/experimental/train/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,39 @@ def train_step(self, inputs: List[np.ndarray],
dict: Loss and metrics values, and train_time in ms.
"""
if self.optimizer is None or self.loss is None:
raise RuntimeError(
"The training session is not properly configured. "
"Please call `compile(optimizer, loss)` before calling `train_step()`.")
raise RuntimeError("Call `compile()` before `train_step()`")
result = self._batch_step(inputs, expecteds, update_weights=True)
result["train_time"] = result.pop("time_ms")
return result

def eval_step(self, inputs: List[np.ndarray],
expecteds: List[np.ndarray]) -> Dict[str, Any]:
"""
Run one evaluation batch: forward only (no weight update).

Args:
inputs (list of np.ndarray): Inputs for this batch.
expecteds (list of np.ndarray): Ground‑truth outputs for this batch.

Returns:
dict: {
"loss": list of float losses,
"metrics": dict of metric_name -> value,
"eval_time": float (ms)
}
"""
if self.optimizer is None or self.loss is None:
raise RuntimeError("Call `compile()` before `eval_step()`")
result = self._batch_step(inputs, expecteds, update_weights=False)
result["eval_time"] = result.pop("time_ms")
return result

def _batch_step(self, inputs: List[np.ndarray], expecteds: List[np.ndarray],
update_weights: bool) -> Dict[str, Any]:
"""
Common logic for one batch: bind data, run, collect loss & metrics.
Returns a dict with keys "loss", "metrics", "time_ms".
"""
# Validate batch sizes
self._check_batch_size(inputs, self.train_info.batch_size, "input")
self._check_batch_size(expecteds, self.train_info.batch_size, "expected")
Expand All @@ -282,22 +311,17 @@ def train_step(self, inputs: List[np.ndarray],

# Run a single training step
t_start: float = time.perf_counter()
self.session.train(update_weights=True)
self.session.train(update_weights=update_weights)
t_end: float = time.perf_counter()

# Update loss
losses: List[float] = [
self.session.train_get_loss(i) for i in range(len(expecteds))
]
losses = [self.session.train_get_loss(i) for i in range(len(expecteds))]

# Update metrics
metric_results: Dict[str, float] = {}
for metric in self.metrics:
metric.update_state(outputs, expecteds)
metric_results[metric.__class__.__name__] = metric.result()

return {
"loss": losses,
"metrics": metric_results,
"train_time": (t_end - t_start) * 1000
}
metrics: Dict[str, float] = {}
for m in self.metrics:
m.update_state(outputs, expecteds)
metrics[m.__class__.__name__] = m.result()
m.reset_state()

return {"loss": losses, "metrics": metrics, "time_ms": (t_end - t_start) * 1000}