44from typing_extensions import override
55
66import lightning as L
7- from lightning .pytorch .loggers import MLFlowLogger
7+ from lightning .pytorch .loggers import TensorBoardLogger , WandbLogger
88from mlflow .entities import Metric , RunTag , Param
99from mlflow .tracking import MlflowClient # type: ignore[possibly-unbound-import]
1010
1111if TYPE_CHECKING :
12- from lightning . pytorch . loggers import MLFlowLogger
12+ pass
1313
14- from fkat .pytorch .loggers import LightningLogger
14+ from fkat .pytorch .loggers import LightningLogger , _is_logger_type
1515from fkat .utils import assert_not_none
1616from fkat .utils .logging import rank0_logger
1717from fkat .utils .mlflow import broadcast_mlflow_run_id , mlflow_logger
1818
1919log = rank0_logger (__name__ )
2020
2121
22- class MLFlowCallbackLogger :
22+ class MLFlowCallbackLogger ( LightningLogger ) :
2323 """
2424 Mlflow logger class that supports distributed logging of tags, metrics and artifacts.
2525
@@ -86,6 +86,69 @@ def log_artifact(self, local_path: str, artifact_path: str | None = None) -> Non
8686 )
8787
8888
89+ class TensorBoardCallbackLogger (LightningLogger ):
90+ """TensorBoard logger for distributed logging."""
91+
92+ def __init__ (self , logger : TensorBoardLogger ) -> None :
93+ self ._logger = logger
94+
95+ def log_tag (self , key : str , value : str ) -> None :
96+ self ._logger .experiment .add_text (key , value )
97+
98+ def tags (self ) -> dict [str , Any ]:
99+ return {}
100+
101+ def log_batch (
102+ self ,
103+ metrics : dict [str , float ] | None = None ,
104+ params : dict [str , Any ] | None = None ,
105+ tags : dict [str , str ] | None = None ,
106+ timestamp : int | None = None ,
107+ step : int | None = None ,
108+ ) -> None :
109+ if metrics :
110+ for k , v in metrics .items ():
111+ self ._logger .experiment .add_scalar (k , v , step )
112+ if tags :
113+ for k , v in tags .items ():
114+ self ._logger .experiment .add_text (k , v , step )
115+
116+ def log_artifact (self , local_path : str , artifact_path : str | None = None ) -> None :
117+ pass
118+
119+
120+ class WandbCallbackLogger (LightningLogger ):
121+ """WandB logger for distributed logging."""
122+
123+ def __init__ (self , logger : WandbLogger ) -> None :
124+ self ._logger = logger
125+
126+ def log_tag (self , key : str , value : str ) -> None :
127+ self ._logger .experiment .config .update ({key : value })
128+
129+ def tags (self ) -> dict [str , Any ]:
130+ return dict (self ._logger .experiment .config )
131+
132+ def log_batch (
133+ self ,
134+ metrics : dict [str , float ] | None = None ,
135+ params : dict [str , Any ] | None = None ,
136+ tags : dict [str , str ] | None = None ,
137+ timestamp : int | None = None ,
138+ step : int | None = None ,
139+ ) -> None :
140+ log_dict = {}
141+ if metrics :
142+ log_dict .update (metrics )
143+ if tags :
144+ log_dict .update (tags )
145+ if log_dict :
146+ self ._logger .experiment .log (log_dict , step = step )
147+
148+ def log_artifact (self , local_path : str , artifact_path : str | None = None ) -> None :
149+ self ._logger .experiment .save (local_path )
150+
151+
89152class CallbackLogger (LightningLogger ):
90153 """
91154 A wrapper on top of the collection of Logger instances,
@@ -104,9 +167,13 @@ class CallbackLogger(LightningLogger):
104167 def __init__ (self , trainer : "L.Trainer | None" , loggers : list [LightningLogger ] | None = None ) -> None :
105168 if trainer :
106169 self .loggers = []
107- for logger in trainer .logger if isinstance ( trainer . logger , list ) else [ trainer . logger ] :
108- if isinstance (logger , MLFlowLogger ):
170+ for logger in trainer .loggers :
171+ if _is_logger_type (logger , " MLFlowLogger" ):
109172 self .loggers .append (MLFlowCallbackLogger (trainer = trainer ))
173+ elif _is_logger_type (logger , "TensorBoardLogger" ):
174+ self .loggers .append (TensorBoardCallbackLogger (logger = logger )) # type: ignore[arg-type]
175+ elif _is_logger_type (logger , "WandbLogger" ):
176+ self .loggers .append (WandbCallbackLogger (logger = logger )) # type: ignore[arg-type]
110177 else :
111178 assert loggers
112179 self .loggers = loggers
0 commit comments