[onert/python] Introduce eval_step#15198
Conversation
def train_step(self,
               inputs: List[np.ndarray],
               expecteds: List[np.ndarray],
               train: bool = True) -> Dict[str, Any]:
Just a question,
if train is false, what is the operation of this function (train_step)?
It acts like an "eval step": we can get losses and metrics, but the model's weights stay unchanged :)
It's a little weird to me, because the train function acts like an eval() function in some cases.
But I'm not sure whether this is normal in this kind of code.
It might be a bit confusing to have the train_step() function behave differently based on the train flag. In standard practice, it's better to separate the concerns by defining two distinct functions:
- train_step(): This is for the training phase, where weights are updated.
- eval_step(): This function is for inference or validation, where no weight updates happen and you likely want to disable gradients (using torch.no_grad()).
This mirrors how frameworks like PyTorch and PyTorch Lightning handle the training and evaluation phases. In PyTorch, model.train() and model.eval() toggle the mode for training and inference, respectively, and it's common to have separate functions for each phase. Instead of toggling behavior with a flag inside train_step(), it's cleaner to have separate functions to keep things clear and consistent.
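The separation being suggested can be sketched in plain Python. This is a minimal illustration of the design point only (a toy numpy model, not onert's actual API): `train_step()` is the only method that mutates weights, while `eval_step()` runs the same forward/loss computation but leaves the model untouched.

```python
import numpy as np


class TinyModel:
    """Toy model illustrating separate train_step/eval_step methods
    (hypothetical sketch, not onert's real session API)."""

    def __init__(self):
        self.w = np.zeros(1)  # single trainable weight
        self.lr = 0.1

    def _loss_and_grad(self, x, y):
        # Shared forward pass: mean-squared error and its gradient w.r.t. w.
        pred = self.w * x
        loss = float(np.mean((pred - y) ** 2))
        grad = np.mean(2.0 * (pred - y) * x)
        return loss, grad

    def train_step(self, x, y):
        # Training phase: compute loss, then update weights.
        loss, grad = self._loss_and_grad(x, y)
        self.w = self.w - self.lr * grad
        return {"loss": loss}

    def eval_step(self, x, y):
        # Evaluation phase: same forward pass, but weights are never touched.
        loss, _ = self._loss_and_grad(x, y)
        return {"loss": loss}
```

With this split, a caller can never accidentally update weights during validation, which is the confusion the `train` flag invited.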
I modified this PR to introduce eval_step() instead of adding the train flag.
This commit introduces `eval_step` and factors common batch logic into `_batch_step`.
- Introduce private helper `_batch_step(inputs, expecteds, update_weights)` to handle binding inputs, running the session, collecting losses/metrics, and timing
- Update `train_step` to call `_batch_step(..., update_weights=True)` and expose `"train_time"`
- Add `eval_step` for forward-only inference, calling `_batch_step(..., update_weights=False)` and exposing `"eval_time"`

ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
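The refactor described in the commit message can be sketched as follows. This is a hypothetical stand-in (a toy numpy model instead of a real onert session; the helper and key names mirror the commit message but the internals are invented): `_batch_step` owns the shared forward/loss/timing logic, and the two public methods differ only in the `update_weights` flag and the name of the timing key they expose.

```python
import time
from typing import Any, Dict, List

import numpy as np


class Trainer:
    """Hypothetical sketch of the _batch_step refactor; the real
    implementation binds inputs to an onert session instead."""

    def __init__(self):
        self.w = np.zeros(1)  # single trainable weight
        self.lr = 0.1

    def _batch_step(self,
                    inputs: List[np.ndarray],
                    expecteds: List[np.ndarray],
                    update_weights: bool) -> Dict[str, Any]:
        # Shared batch logic: run forward pass, collect losses, measure time.
        start = time.perf_counter()
        losses = []
        for x, y in zip(inputs, expecteds):
            pred = self.w * x
            losses.append(float(np.mean((pred - y) ** 2)))
            if update_weights:
                # Only the training path touches the weights.
                grad = np.mean(2.0 * (pred - y) * x)
                self.w = self.w - self.lr * grad
        elapsed = time.perf_counter() - start
        return {"losses": losses, "elapsed": elapsed}

    def train_step(self, inputs, expecteds) -> Dict[str, Any]:
        result = self._batch_step(inputs, expecteds, update_weights=True)
        result["train_time"] = result.pop("elapsed")
        return result

    def eval_step(self, inputs, expecteds) -> Dict[str, Any]:
        # Forward-only inference: same batch logic, no weight updates.
        result = self._batch_step(inputs, expecteds, update_weights=False)
        result["eval_time"] = result.pop("elapsed")
        return result
```

Keeping the flag private to `_batch_step` gives callers the clean two-function API while still avoiding duplicated batch code.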
For #15172 (comment)

ONE-DCO-1.0-Signed-off-by: ragmani ragmani0216@gmail.com