Skip to content

Commit 0a60211

Browse files
authored
feat: add suggest_integrations flag to toggle integration suggestions in Trainer logs (default True) (#21632)
feat: add `suggest_integrations` flag to `Trainer` for optional integration suggestions
1 parent 4a548c9 commit 0a60211

5 files changed

Lines changed: 33 additions & 2 deletions

File tree

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
### Added
1212

13+
- Added `suggest_integrations` flag to `Trainer` to control whether optional integration suggestions (e.g., litmodels, litlogger) are shown in logs ([#21632](https://github.com/Lightning-AI/pytorch-lightning/pull/21632))
14+
1315
- Added `log_key_prefix` parameter to `LearningRateMonitor` callback for prefixing logged metric names ([#21612](https://github.com/Lightning-AI/pytorch-lightning/issues/21612))
1416

1517
### Changed

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _maybe_show_litmodels_tip(self) -> None:
175175
This is called after loggers are set up, so we can reliably check for LitLogger.
176176
177177
"""
178-
if not self._pending_litmodels_tip:
178+
if not self._pending_litmodels_tip or not self.trainer.suggest_integrations:
179179
return
180180
self._pending_litmodels_tip = False # Only show once
181181

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non
8787
else:
8888
self.trainer.loggers = [logger]
8989

90-
if not any(isinstance(logger, LitLogger) for logger in self.trainer.loggers):
90+
if (
91+
not any(isinstance(logger, LitLogger) for logger in self.trainer.loggers)
92+
and self.trainer.suggest_integrations
93+
):
9194
rank_zero_info(
9295
"💡 Tip: For seamless cloud logging and experiment tracking,"
9396
" try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger,"

src/lightning/pytorch/trainer/trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
default_root_dir: Optional[_PATH] = None,
133133
enable_autolog_hparams: bool = True,
134134
model_registry: Optional[str] = None,
135+
suggest_integrations: bool = True,
135136
) -> None:
136137
r"""Customize every aspect of training via flags.
137138
@@ -308,6 +309,10 @@ def __init__(
308309
309310
model_registry: The name of the model being uploaded to Model hub.
310311
312+
suggest_integrations: Whether to display suggestions for optional Lightning integrations.
313+
Default: ``True``.
314+
315+
311316
Raises:
312317
TypeError:
313318
If ``gradient_clip_val`` is not an int or float.
@@ -324,6 +329,7 @@ def __init__(
324329

325330
# remove version if accidentally passed
326331
self._model_registry = model_registry.split(":")[0] if model_registry else None
332+
self.suggest_integrations = suggest_integrations
327333

328334
self.barebones = barebones
329335
if barebones:

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,3 +2138,23 @@ def test_expand_home_trainer():
21382138
assert trainer.default_root_dir == str(home_root / "trainer")
21392139
trainer = Trainer(default_root_dir=Path("~/trainer"))
21402140
assert trainer.default_root_dir == str(home_root / "trainer")
2141+
2142+
2143+
@pytest.mark.parametrize("suggest_integrations", [True, False])
2144+
def test_trainer_integration_suggestions(tmp_path, caplog, suggest_integrations):
2145+
caplog.set_level("INFO", logger="lightning.pytorch.utilities.rank_zero")
2146+
2147+
trainer = Trainer(
2148+
default_root_dir=tmp_path,
2149+
max_epochs=1,
2150+
enable_progress_bar=False,
2151+
suggest_integrations=suggest_integrations,
2152+
)
2153+
2154+
trainer.fit(BoringModel())
2155+
2156+
messages = [r.getMessage() for r in caplog.records if r.name == "lightning.pytorch.utilities.rank_zero"]
2157+
2158+
has_suggestion = any("litmodels" in m or "litlogger" in m for m in messages)
2159+
2160+
assert has_suggestion == suggest_integrations

0 commit comments

Comments
 (0)