How to pass non primitive data to a custom hook #1365

Open
@FrancescoSaverioZuppichini

Description

Hi guys,

I hope you are doing great! So, I have the following custom hook

from pprint import pprint

from mmcv.runner.hooks import HOOKS, TextLoggerHook
from gust import TrainHandler


@HOOKS.register_module()
class TrainHandlerCallback(TextLoggerHook):
    def __init__(self, train_handler: TrainHandler, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Non-primitive object that must be passed in from outside the config.
        self.train_handler = train_handler

    def _dump_log(self, log_dict, runner):
        pprint(log_dict)
        mode = log_dict.get("mode")
        # Forward the loss to the handler when the log entry comes from validation.
        if mode == "val":
            self.train_handler.log_metric("train/loss", log_dict["loss"])

I need to set train_handler, but the config system doesn't support non-primitive data. I tried to inject it into my config using the merge dict option:

"log_config.hooks": [
    dict(type="TrainHandlerCallback", train_handler=train_handler)
],

As expected, I got a "serialization" error:

  File "/opt/conda/lib/python3.7/ast.py", line 35, in parse
    return compile(source, filename, mode, PyCF_ONLY_AST)
  File "<unknown>", line 267
    train_handler=<gust.train_handler.TrainHandler object at 0x7f4bfbf4dd50>)])
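
Judging only from the traceback, my guess is that the merged config gets rendered as Python source text and parsed again, so the default repr of the object ends up in the source and is not valid syntax. The sketch below reproduces that failure with the standard library alone. The TrainHandler class here is a stand-in for gust.TrainHandler, and the way the source string is built is an assumption for illustration, not the config system's actual code.

```python
import ast


class TrainHandler:
    """Stand-in for gust.TrainHandler; it keeps the default object repr."""


handler = TrainHandler()
hook_cfg = dict(type="TrainHandlerCallback", train_handler=handler)

# Assumption: the config is turned into source text roughly like this.
source = f"hooks = [{hook_cfg!r}]"

parse_error = None
try:
    # The repr looks like <...TrainHandler object at 0x...>, which is not Python syntax.
    ast.parse(source)
except SyntaxError as exc:
    parse_error = exc

print(source)
print("parse failed:", parse_error is not None)
```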

Unfortunately, the documentation doesn't help here.
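
One possible workaround, sketched here under assumptions and not taken from the docs, is to keep only a string key in the config and look the real object up when the hook is constructed. Everything named below (the _HANDLERS dict, register_handler, FakeTrainHandler, TrainHandlerCallbackSketch, the train_handler_key argument) is hypothetical. In the real hook the class would still subclass TextLoggerHook and be registered with @HOOKS.register_module() as above.

```python
import ast

# Hypothetical module-level registry; not part of mmcv.
_HANDLERS = {}


def register_handler(name, handler):
    """Store the object under a string key and return the key."""
    _HANDLERS[name] = handler
    return name


class FakeTrainHandler:
    """Stand-in for gust.TrainHandler; it only records what it is asked to log."""

    def __init__(self):
        self.logged = []

    def log_metric(self, key, value):
        self.logged.append((key, value))


class TrainHandlerCallbackSketch:
    """Plain-class stand-in for the hook; the mmcv base class is omitted here."""

    def __init__(self, train_handler_key):
        # Resolve the string key from the config to the real object.
        self.train_handler = _HANDLERS[train_handler_key]

    def _dump_log(self, log_dict, runner=None):
        if log_dict.get("mode") == "val":
            self.train_handler.log_metric("train/loss", log_dict["loss"])


handler = FakeTrainHandler()
key = register_handler("my_handler", handler)

# The config entry now contains only strings, so it survives a repr/parse round trip.
hook_cfg = dict(type="TrainHandlerCallback", train_handler_key=key)
round_tripped = ast.literal_eval(repr(hook_cfg))

kwargs = {k: v for k, v in round_tripped.items() if k != "type"}
hook = TrainHandlerCallbackSketch(**kwargs)
hook._dump_log({"mode": "val", "loss": 0.5})
```

The obvious downside is that the registry is global state, so the handler has to be registered before the config is built. I would prefer a supported way to do this if one exists.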

Thanks

Francesco
