Skip to content

Commit 43fc0ca

Browse files
Andrew Patrikalakis (anrp)
authored and committed
Write README.md for LoRA adapter as well
1 parent 294e551 commit 43fc0ca

3 files changed

Lines changed: 140 additions & 84 deletions

File tree

src/heretic/main.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def _is_help_invocation() -> bool:
4343
import torch
4444
import torch.nn.functional as F
4545
import transformers
46-
from huggingface_hub import ModelCard, ModelCardData
4746
from lm_eval.models.huggingface import HFLM
4847
from optuna import Trial, TrialPruned
4948
from optuna.exceptions import ExperimentalWarning
@@ -61,10 +60,10 @@ def _is_help_invocation() -> bool:
6160
from .config import QuantizationMethod
6261
from .evaluator import Evaluator
6362
from .model import AbliterationParameters, Model, get_model_class
63+
from .model_card_utils import get_model_card
6464
from .system import empty_cache, get_accelerator_info
6565
from .utils import (
6666
format_duration,
67-
get_readme_intro,
6867
get_trial_parameters,
6968
is_hf_path,
7069
load_prompts,
@@ -779,6 +778,7 @@ def count_completed_trials() -> int:
779778
save_directory,
780779
max_shard_size=settings.max_shard_size,
781780
)
781+
card = get_model_card(settings, trial, "", True)
782782
else:
783783
print("Saving merged model...")
784784
merged_model = model.get_merged_model()
@@ -789,6 +789,10 @@ def count_completed_trials() -> int:
789789
del merged_model
790790
empty_cache()
791791
model.tokenizer.save_pretrained(save_directory)
792+
card = get_model_card(settings, trial, "", False)
793+
794+
if card is not None:
795+
card.save(f"{save_directory}/README.md")
792796

793797
print(f"Model saved to [bold]{save_directory}[/].")
794798

@@ -882,6 +886,12 @@ def count_completed_trials() -> int:
882886
max_shard_size=settings.max_shard_size,
883887
token=token,
884888
)
889+
card = get_model_card(
890+
settings,
891+
trial,
892+
reproducibility_information,
893+
True,
894+
)
885895
else:
886896
print("Uploading merged model...")
887897
merged_model = model.get_merged_model()
@@ -899,37 +909,14 @@ def count_completed_trials() -> int:
899909
token=token,
900910
)
901911

902-
if is_hf_path(settings.model):
903-
card = ModelCard.load(settings.model)
904-
else:
905-
card_path = (
906-
Path(settings.model)
907-
/ huggingface_hub.constants.REPOCARD_NAME
912+
card = get_model_card(
913+
settings,
914+
trial,
915+
reproducibility_information,
916+
False,
908917
)
909-
if card_path.exists():
910-
card = ModelCard.load(card_path)
911-
else:
912-
card = None
913918

914919
if card is not None:
915-
if card.data is None:
916-
card.data = ModelCardData()
917-
if card.data.tags is None:
918-
card.data.tags = []
919-
card.data.tags.append("heretic")
920-
card.data.tags.append("uncensored")
921-
card.data.tags.append("decensored")
922-
card.data.tags.append("abliterated")
923-
if reproducibility_information != "none":
924-
card.data.tags.append("reproducible")
925-
card.text = (
926-
get_readme_intro(
927-
settings,
928-
trial,
929-
reproducibility_information != "none",
930-
)
931-
+ card.text
932-
)
933920
card.push_to_hub(repo_id, token=token)
934921

935922
if reproducibility_information != "none":

src/heretic/model_card_utils.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# SPDX-License-Identifier: AGPL-3.0-or-later
2+
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
3+
4+
from pathlib import Path
5+
6+
import huggingface_hub
7+
from huggingface_hub import ModelCard, ModelCardData
8+
from optuna import Trial
9+
10+
from .config import RowNormalization, Settings
11+
from .system import (
12+
get_heretic_version_info,
13+
)
14+
from .utils import get_trial_parameters, is_hf_path
15+
16+
17+
def get_readme_intro(
    settings: Settings,
    trial: Trial,
    contains_reproducibility_information: bool,
    is_lora: bool,
) -> str:
    """Compose the Markdown introduction that is prepended to a model card.

    The text consists of a heading naming the source model and the Heretic
    version, an optional tip pointing at the ``reproduce`` directory, a table
    of the abliteration parameters returned by ``get_trial_parameters(trial)``,
    and a performance table built from ``trial.user_attrs``.

    Args:
        settings: Run settings; only ``settings.model`` is read here.
        trial: Trial whose parameters and ``user_attrs`` entries
            (``kl_divergence``, ``refusals``, ``n_bad_prompts``,
            ``base_refusals``) are rendered.
        contains_reproducibility_information: Whether to include the
            "This model is reproducible!" tip.
        is_lora: Selects the wording "adapter" instead of "version".

    Returns:
        The Markdown text, ending with a horizontal rule and a blank line.
    """
    # Only Hub identifiers are linked. A local path may contain private
    # information, so it is replaced by a generic phrase.
    model_link = (
        f"[{settings.model}](https://huggingface.co/{settings.model})"
        if is_hf_path(settings.model)
        else "a model"
    )

    heretic_version = get_heretic_version_info().version

    reproducibility_instructions = ""
    if contains_reproducibility_information:
        reproducibility_instructions = (
            "\n"
            "> [!TIP]\n"
            "> **This model is reproducible!**\n"
            ">\n"
            "> See the [README](reproduce/README.md) in the `reproduce` directory"
            " for more information.\n"
        )

    artifact_kind = "adapter" if is_lora else "version"

    # One Markdown table row per abliteration parameter.
    parameter_rows = "\n".join(
        f"| **{name}** | {value} |"
        for name, value in get_trial_parameters(trial).items()
    )

    attrs = trial.user_attrs
    kl_divergence = attrs["kl_divergence"]
    refusals = attrs["refusals"]
    n_bad_prompts = attrs["n_bad_prompts"]
    base_refusals = attrs["base_refusals"]

    return f"""# This is a decensored {artifact_kind} of {model_link}, made using [Heretic](https://github.com/p-e-w/heretic) v{heretic_version}
{reproducibility_instructions}
## Abliteration parameters

| Parameter | Value |
| :-------- | :---: |
{parameter_rows}

## Performance

| Metric | This model | Original model ({model_link}) |
| :----- | :--------: | :---------------------------: |
| **KL divergence** | {kl_divergence:.4f} | 0 *(by definition)* |
| **Refusals** | {refusals}/{n_bad_prompts} | {base_refusals}/{n_bad_prompts} |

-----

"""
70+
71+
72+
def get_model_card(
    settings: Settings,
    trial: Trial,
    reproducibility_information: str,
    is_lora: bool,
) -> ModelCard | None:
    """Build the model card for a decensored model or LoRA adapter.

    The card of the source model is loaded, tagged, and prefixed with the
    introduction produced by ``get_readme_intro``.

    Args:
        settings: Run settings. ``settings.model`` locates the source card;
            ``orthogonalize_direction`` and ``row_normalization`` decide
            whether the "mpoa" tag is added.
        trial: Trial whose results are rendered into the introduction.
        reproducibility_information: Kind of reproducibility information
            that accompanies the model. ``"none"`` and the empty string both
            mean that there is none.
        is_lora: ``True`` when the card describes a LoRA adapter rather than
            a merged model.

    Returns:
        The updated card, or ``None`` when ``settings.model`` is a local
        path that contains no model card file.
    """
    # The empty string carries no reproducibility information either, so it
    # must not mark the model as reproducible (a plain `!= "none"` check
    # would do so).
    is_reproducible = reproducibility_information not in ("", "none")

    # A Hugging Face Hub identifier has its card fetched from the Hub.
    # Anything else is treated as a local directory, which may or may not
    # contain a card file.
    if is_hf_path(settings.model):
        card = ModelCard.load(settings.model)
    else:
        card_path = Path(settings.model) / huggingface_hub.constants.REPOCARD_NAME
        if not card_path.exists():
            return None
        card = ModelCard.load(card_path)

    if card.data is None:
        card.data = ModelCardData()
    if card.data.tags is None:
        card.data.tags = []

    new_tags = ["heretic", "uncensored", "decensored", "abliterated"]
    if (
        settings.orthogonalize_direction
        and settings.row_normalization == RowNormalization.FULL
    ):
        new_tags.append("mpoa")
    if is_reproducible:
        new_tags.append("reproducible")

    # The source card may already carry some of these tags (for example when
    # the source model was itself produced by Heretic); avoid duplicates.
    for tag in new_tags:
        if tag not in card.data.tags:
            card.data.tags.append(tag)

    # A base model reference is only recorded for Hub identifiers.
    if is_hf_path(settings.model):
        card.data.base_model = settings.model
        card.data.base_model_relation = "adapter" if is_lora else "finetuned"

    card.text = (
        get_readme_intro(
            settings,
            trial,
            is_reproducible,
            is_lora,
        )
        + card.text
    )

    return card

src/heretic/utils.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -272,60 +272,6 @@ def get_trial_parameters(trial: Trial) -> dict[str, str]:
272272
return params
273273

274274

275-
def get_readme_intro(
276-
settings: Settings,
277-
trial: Trial,
278-
contains_reproducibility_information: bool,
279-
) -> str:
280-
if is_hf_path(settings.model):
281-
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
282-
else:
283-
# Hide the path, which may contain private information.
284-
model_link = "a model"
285-
286-
version_info = get_heretic_version_info()
287-
288-
if contains_reproducibility_information:
289-
reproducibility_instructions = """
290-
> [!TIP]
291-
> **This model is reproducible!**
292-
>
293-
> See the [README](reproduce/README.md) in the `reproduce` directory for more information.
294-
"""
295-
else:
296-
reproducibility_instructions = ""
297-
298-
return f"""# This is a decensored version of {
299-
model_link
300-
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
301-
{reproducibility_instructions}
302-
## Abliteration parameters
303-
304-
| Parameter | Value |
305-
| :-------- | :---: |
306-
{
307-
chr(10).join(
308-
[
309-
f"| **{name}** | {value} |"
310-
for name, value in get_trial_parameters(trial).items()
311-
]
312-
)
313-
}
314-
315-
## Performance
316-
317-
| Metric | This model | Original model ({model_link}) |
318-
| :----- | :--------: | :---------------------------: |
319-
| **KL divergence** | {trial.user_attrs["kl_divergence"]:.4f} | 0 *(by definition)* |
320-
| **Refusals** | {trial.user_attrs["refusals"]}/{trial.user_attrs["n_bad_prompts"]} | {
321-
trial.user_attrs["base_refusals"]
322-
}/{trial.user_attrs["n_bad_prompts"]} |
323-
324-
-----
325-
326-
"""
327-
328-
329275
def generate_config_toml(settings: Settings) -> str:
330276
"""Serializes the full Settings object to TOML."""
331277

0 commit comments

Comments
 (0)