diff --git a/rfdetr/detr.py b/rfdetr/detr.py index 03d1f65..42b9c48 100644 --- a/rfdetr/detr.py +++ b/rfdetr/detr.py @@ -114,6 +114,18 @@ def train_from_config(self, config: TrainConfig, **kwargs): callbacks=self.callbacks, ) + complete_config = { + "train_config": config.model_dump(), + "model_config": self.model_config.model_dump(), + "model_config_type": self.model_config.__class__.__name__, + "effective_training_params": all_kwargs, + "class_names": self.model.class_names, + "num_classes": len(self.model.class_names), + } + + with open(os.path.join(config.output_dir, "training_config.json"), "w") as f: + json.dump(complete_config, f, indent=2) + def get_train_config(self, **kwargs): return TrainConfig(**kwargs)