Skip to content

Commit ccd82c6

Browse files
committed
added yaml (multi)representer for PretrainedConfig object types
1 parent 742c729 commit ccd82c6

File tree

6 files changed

+33
-14
lines changed

6 files changed

+33
-14
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

7+
## [2.5.1] - 2025-03-27
8+
9+
### Added
10+
11+
- Support for Lightning ``2.5.1``
12+
- added (multi)representer for ``PretrainedConfig`` object types
13+
714
## [2.5.0] - 2024-12-20
815

916
### Added

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ RUN \
8484
else \
8585
# or target a specific cuda build, by specifying a particular index url w/...
8686
# ... default channel
87-
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124; \
87+
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124; \
8888
# ... pytorch patch version
8989
# pip install torch==1.11.1+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html; \
9090
# ... pytorch nightly dev version

requirements/base.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#lightning>=2.5.0,<2.5.1
1+
lightning>=2.5.0,<2.5.2
22
# the below is uncommented when master is targeting a specific pl dev master commit
3-
git+https://github.com/Lightning-AI/lightning.git@878ecf56b06d5ae3f482e146e78accabc685bfc7#egg=lightning
3+
# git+https://github.com/Lightning-AI/lightning.git@878ecf56b06d5ae3f482e146e78accabc685bfc7#egg=lightning
44
torch>=2.2.0

requirements/standalone_base.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#pytorch-lightning>=2.5.0,<2.5.1
1+
pytorch-lightning>=2.5.0,<2.5.2
22
# the below is uncommented when master is targeting a specific pl dev master commit
3-
git+https://github.com/Lightning-AI/pytorch-lightning.git@878ecf56b06d5ae3f482e146e78accabc685bfc7#egg=pytorch-lightning
3+
# git+https://github.com/Lightning-AI/pytorch-lightning.git@878ecf56b06d5ae3f482e146e78accabc685bfc7#egg=pytorch-lightning
44
torch>=2.2.0

setup.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,15 @@ def _setup_args(standalone: bool = False) -> Dict[str, Any]:
131131
)
132132

133133
base_reqs = "standalone_base.txt" if standalone else "base.txt"
134-
# install_requires = setup_tools._load_requirements(
135-
# _INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
136-
# )
137134
install_requires = setup_tools._load_requirements(
138-
_INSTALL_PATHS["require"],
139-
file_name=base_reqs,
140-
standalone=standalone,
141-
pl_commit="878ecf56b06d5ae3f482e146e78accabc685bfc7",
135+
_INSTALL_PATHS["require"], file_name=base_reqs, standalone=standalone
142136
)
137+
# install_requires = setup_tools._load_requirements(
138+
# _INSTALL_PATHS["require"],
139+
# file_name=base_reqs,
140+
# standalone=standalone,
141+
# pl_commit="878ecf56b06d5ae3f482e146e78accabc685bfc7",
142+
# )
143143
base_setup["install_requires"] = install_requires
144144
return base_setup
145145

src/fts_examples/cfg_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict, List, Any, Union
44
from dataclasses import dataclass, field, asdict, fields
55

6+
from transformers import PretrainedConfig
67
from lightning.pytorch.utilities.exceptions import MisconfigurationException
78

89

@@ -40,8 +41,19 @@ def optimizer_cfg_mapping_representer(dumper, data):
4041
def lr_scheduler_cfg_mapping_representer(dumper, data):
4142
return dumper.represent_mapping('tag:yaml.org,2002:map', asdict(data))
4243

43-
yaml.SafeDumper.add_representer(OptimizerCfg, optimizer_cfg_mapping_representer)
44-
yaml.SafeDumper.add_representer(LRSchedulerCfg, lr_scheduler_cfg_mapping_representer)
44+
def pretrained_cfg_mapping_representer(dumper, data):
45+
return dumper.represent_mapping('tag:yaml.org,2002:map', data.to_dict())
46+
47+
# Register all custom representers to both base dumper classes
48+
representers = {
49+
PretrainedConfig: pretrained_cfg_mapping_representer,
50+
OptimizerCfg: optimizer_cfg_mapping_representer,
51+
LRSchedulerCfg: lr_scheduler_cfg_mapping_representer
52+
}
53+
54+
for dumper_cls in [yaml.Dumper, yaml.SafeDumper]:
55+
for cls, representer in representers.items():
56+
dumper_cls.add_multi_representer(cls, representer)
4557

4658
def _is_overridden(dataclass_instance) -> bool:
4759
is_overridden = False

0 commit comments

Comments
 (0)