Skip to content

Commit 2f505d0

Browse files
committed
[onert/python] Add train flag to train_step
This commit adds `train` flag to train_step to toggle weight updates. - Introduce `train: bool = True` parameter in `train_step` - Pass `update_weights=train` to `session.train()` ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent 2c9ff76 commit 2f505d0

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

  • runtime/onert/api/python/package/experimental/train

runtime/onert/api/python/package/experimental/train/session.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,17 @@ def _check_batch_size(self,
246246
f"{data_type} batch size mismatch at index {idx}: "
247247
f"shape[0] = {arr.shape[0]} vs batch size = {batch_size}")
248248

249-
def train_step(self, inputs: List[np.ndarray],
250-
expecteds: List[np.ndarray]) -> Dict[str, Any]:
249+
def train_step(self,
250+
inputs: List[np.ndarray],
251+
expecteds: List[np.ndarray],
252+
train: bool = True) -> Dict[str, Any]:
251253
"""
252-
Train the model for a single batch.
254+
Train (or evaluate) the model for a single batch.
253255
254256
Args:
255257
inputs (list of np.ndarray): List of input arrays for the batch.
256258
expecteds (list of np.ndarray): List of expected output arrays for the batch.
259+
train (bool): If True, update weights; if False, run only forward & metrics.
257260
258261
Returns:
259262
dict: Loss and metrics values, and train_time in ms.
@@ -282,7 +285,7 @@ def train_step(self, inputs: List[np.ndarray],
282285

283286
# Run a single training step
284287
t_start: float = time.perf_counter()
285-
self.session.train(update_weights=True)
288+
self.session.train(update_weights=train)
286289
t_end: float = time.perf_counter()
287290

288291
# Update loss

0 commit comments

Comments
 (0)