Skip to content

Commit e14fb1c

Browse files
authored
Fix/wandb checkpoints (#229)
* fix: save output as files not artifacts * ci: add mypy check of examples Co-authored-by: Albert Sawczyn <[email protected]>
1 parent ed24367 commit e14fb1c

6 files changed: +62 additions, −47 deletions

embeddings/pipeline/lightning_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,16 @@ def run(self, run_name: Optional[str] = None) -> EvaluationResult:
4848
self._save_artifacts()
4949
model_result = self.model.execute(data=self.datamodule, run_name=run_name)
5050
result = self.evaluator.evaluate(model_result)
51-
self._finish_logging(run_name)
51+
self._finish_logging()
5252
return result
5353

5454
def _save_artifacts(self) -> None:
5555
srsly.write_json(self.output_path.joinpath("packages.json"), get_installed_packages())
5656

57-
def _finish_logging(self, run_name: Optional[str] = None) -> None:
57+
def _finish_logging(self) -> None:
5858
if self.logging_config.use_wandb():
5959
logger = WandbWrapper()
60-
logger.log_output(self.output_path, run_name)
60+
logger.log_output(
61+
self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"}
62+
)
6163
logger.finish_logging()

embeddings/pipeline/pipelines_metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ class FlairClassificationEvaluationPipelineMetadata(FlairEvaluationPipelineMetad
6363
class LightningPipelineMetadata(EmbeddingPipelineBaseMetadata):
6464
embedding_name_or_path: T_path
6565
dataset_name_or_path: T_path
66-
input_column_name: Union[str, Sequence[str]]
6766
target_column_name: str
6867
config: LightningConfig
6968
devices: Optional[Union[List[int], str, int]]
@@ -74,10 +73,11 @@ class LightningPipelineMetadata(EmbeddingPipelineBaseMetadata):
7473

7574

7675
class LightningClassificationPipelineMetadata(LightningPipelineMetadata):
77-
pass
76+
input_column_name: Union[str, Sequence[str]]
7877

7978

8079
class LightningSequenceLabelingPipelineMetadata(LightningPipelineMetadata):
80+
input_column_name: str
8181
evaluation_mode: EvaluationMode
8282
tagging_scheme: Optional[TaggingScheme]
8383

embeddings/utils/loggers.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from dataclasses import dataclass, field
55
from pathlib import Path
6-
from typing import Any, Dict, List, Optional, Union
6+
from typing import Any, Dict, Iterable, List, Optional, Union
77

88
import wandb
99
from pytorch_lightning import loggers as pl_loggers
@@ -129,12 +129,14 @@ def finish_logging(self) -> None:
129129

130130

131131
class WandbWrapper(ExperimentLogger):
132-
def log_output(self, output_path: T_path, run_name: Optional[str] = None) -> None:
133-
wandb.log_artifact(
134-
str(output_path),
135-
name=run_name,
136-
type="output",
137-
)
132+
def log_output(
133+
self,
134+
output_path: T_path,
135+
ignore: Optional[Iterable[str]] = None,
136+
) -> None:
137+
for entry in os.scandir(output_path):
138+
if not ignore or entry.name not in ignore:
139+
wandb.save(entry.path, output_path)
138140

139141
def finish_logging(self) -> None:
140142
wandb.finish()

examples/hps_lightning_document_classification.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import typer
44

5+
from embeddings.config.lighting_config_space import LightingTextClassificationConfigSpace
56
from embeddings.defaults import RESULTS_PATH
6-
from embeddings.hyperparameter_search.lighting_configspace import (
7-
LightingTextClassificationConfigSpace,
8-
)
97
from embeddings.pipeline.lightning_classification import LightningClassificationPipeline
108
from embeddings.pipeline.lightning_hps_pipeline import OptimizedLightingClassificationPipeline
9+
from embeddings.utils.loggers import LightningLoggingConfig
1110
from embeddings.utils.utils import build_output_path
1211

1312
app = typer.Typer()
@@ -32,20 +31,11 @@ def run(
3231
wandb: bool = typer.Option(False, help="Flag for using wandb."),
3332
tensorboard: bool = typer.Option(False, help="Flag for using tensorboard."),
3433
csv: bool = typer.Option(False, help="Flag for using csv."),
35-
wandb_project: Optional[str] = typer.Option(None, help="Name of wandb project."),
34+
tracking_project_name: Optional[str] = typer.Option(None, help="Name of wandb project."),
3635
wandb_entity: Optional[str] = typer.Option(None, help="Name of entity project"),
3736
) -> None:
3837
if not run_name:
3938
run_name = embedding_name_or_path
40-
41-
logging_kwargs = {
42-
"use_tensorboard": tensorboard,
43-
"use_wandb": wandb,
44-
"use_csv": csv,
45-
"wandb_project": wandb_project,
46-
"wandb_entity": wandb_entity,
47-
}
48-
4939
output_path = build_output_path(root, embedding_name_or_path, dataset_name)
5040
config_space = LightingTextClassificationConfigSpace(
5141
embedding_name_or_path=embedding_name_or_path,
@@ -55,15 +45,30 @@ def run(
5545
dataset_name_or_path=dataset_name,
5646
input_column_name=input_column_name,
5747
target_column_name=target_column_name,
58-
logging_kwargs=logging_kwargs,
48+
logging_config=LightningLoggingConfig.from_flags(
49+
wandb=wandb,
50+
tensorboard=tensorboard,
51+
csv=csv,
52+
tracking_project_name=tracking_project_name,
53+
wandb_entity=wandb_entity,
54+
),
5955
n_trials=n_trials,
6056
).persisting(
6157
best_params_path=output_path.joinpath("best_params.yaml"),
6258
log_path=output_path.joinpath("hps_log.pickle"),
6359
)
6460
df, metadata = pipeline.run(run_name=f"search-{run_name}")
61+
del pipeline
6562

66-
pipeline = LightningClassificationPipeline(
67-
output_path=output_path, logging_kwargs=logging_kwargs, **metadata
63+
metadata["output_path"] = output_path
64+
retrain_pipeline = LightningClassificationPipeline(
65+
logging_config=LightningLoggingConfig.from_flags(
66+
wandb=wandb,
67+
tensorboard=tensorboard,
68+
csv=csv,
69+
tracking_project_name=tracking_project_name,
70+
wandb_entity=wandb_entity,
71+
),
72+
**metadata,
6873
)
69-
pipeline.run(run_name=f"best-params-retrain-{run_name}")
74+
retrain_pipeline.run(run_name=f"best-params-retrain-{run_name}")

examples/hps_lightning_sequence_labeling.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
import typer
44

5+
from embeddings.config.lighting_config_space import LightingSequenceLabelingConfigSpace
56
from embeddings.defaults import RESULTS_PATH
6-
from embeddings.hyperparameter_search.lighting_configspace import (
7-
LightingSequenceLabelingConfigSpace,
8-
)
97
from embeddings.pipeline.lightning_hps_pipeline import OptimizedLightingSequenceLabelingPipeline
108
from embeddings.pipeline.lightning_sequence_labeling import LightningSequenceLabelingPipeline
9+
from embeddings.utils.loggers import LightningLoggingConfig
1110
from embeddings.utils.utils import build_output_path
1211

1312
app = typer.Typer()
@@ -32,20 +31,11 @@ def run(
3231
wandb: bool = typer.Option(False, help="Flag for using wandb."),
3332
tensorboard: bool = typer.Option(False, help="Flag for using tensorboard."),
3433
csv: bool = typer.Option(False, help="Flag for using csv."),
35-
wandb_project: Optional[str] = typer.Option(None, help="Name of wandb project."),
34+
tracking_project_name: Optional[str] = typer.Option(None, help="Name of wandb project."),
3635
wandb_entity: Optional[str] = typer.Option(None, help="Name of entity project"),
3736
) -> None:
3837
if not run_name:
3938
run_name = embedding_name_or_path
40-
41-
logging_kwargs = {
42-
"use_tensorboard": tensorboard,
43-
"use_wandb": wandb,
44-
"use_csv": csv,
45-
"wandb_project": wandb_project,
46-
"wandb_entity": wandb_entity,
47-
}
48-
4939
output_path = build_output_path(root, embedding_name_or_path, dataset_name)
5040
config_space = LightingSequenceLabelingConfigSpace(
5141
embedding_name_or_path=embedding_name_or_path,
@@ -55,15 +45,30 @@ def run(
5545
dataset_name_or_path=dataset_name,
5646
input_column_name=input_column_name,
5747
target_column_name=target_column_name,
58-
logging_kwargs=logging_kwargs,
48+
logging_config=LightningLoggingConfig.from_flags(
49+
wandb=wandb,
50+
tensorboard=tensorboard,
51+
csv=csv,
52+
tracking_project_name=tracking_project_name,
53+
wandb_entity=wandb_entity,
54+
),
5955
n_trials=n_trials,
6056
).persisting(
6157
best_params_path=output_path.joinpath("best_params.yaml"),
6258
log_path=output_path.joinpath("hps_log.pickle"),
6359
)
6460
df, metadata = pipeline.run(run_name=f"search-{run_name}")
61+
del pipeline
6562

66-
pipeline = LightningSequenceLabelingPipeline(
67-
output_path=output_path, logging_kwargs=logging_kwargs, **metadata
63+
metadata["output_path"] = output_path
64+
retrain_pipeline = LightningSequenceLabelingPipeline(
65+
logging_config=LightningLoggingConfig.from_flags(
66+
wandb=wandb,
67+
tensorboard=tensorboard,
68+
csv=csv,
69+
tracking_project_name=tracking_project_name,
70+
wandb_entity=wandb_entity,
71+
),
72+
**metadata,
6873
)
69-
pipeline.run(run_name=f"best-params-retrain-{run_name}")
74+
retrain_pipeline.run(run_name=f"best-params-retrain-{run_name}")

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ black_fix = "black ./"
6363
isort_fix = "isort . "
6464
pyflakes = "pyflakes embeddings"
6565
mypy = "mypy -p embeddings"
66+
mypy_examples = "mypy examples"
6667
coverage_test = "coverage run -m pytest"
6768
coverage_report = "coverage report -mi"
6869
test = ["coverage_test", "coverage_report"]
69-
check = ["black", "isort", "mypy", "pyflakes"]
70+
check = ["black", "isort", "pyflakes", "mypy", "mypy_examples"]
7071
fix = ["black_fix", "isort_fix"]
7172
all = ["check", "test"]
7273

0 commit comments

Comments (0)