10 changes: 9 additions & 1 deletion bergson/config.py
@@ -46,7 +46,7 @@ class DataConfig(Serializable):
and optionally `doc_to_target` and `doc_to_choice`. MCQA YAML
available at `bergson/templates/mcqa.yaml`."""

data_args: str = ""
data_kwargs: str = ""
"""Arguments to pass to the dataset constructor in the format
arg1=val1,arg2=val2."""

@@ -136,6 +136,14 @@ class ModelConfig(ABC):
fsdp: bool = False
"""Whether to use PyTorch Fully Sharded Data Parallel (FSDP)"""

peft_init_kwargs: str = ""
"""peft.LoraConfig arguments for initializing a PEFT adapter on the
base model in the format 'arg1=val1,arg2=val2'.
Use | to separate list values, e.g. target_modules=q_proj|k_proj|v_proj."""

model_kwargs: str = ""
"""HF Model kwargs for in the format 'arg1=val1,arg2=val2'."""


@dataclass
class LRScheduleConfig(Serializable):
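All three new string-valued fields (`data_kwargs`, `peft_init_kwargs`, `model_kwargs`) share the same comma-separated `key=val` syntax, parsed by `simple_parse_kwargs_string` further down in this diff. A quick sketch of plausible values, purely for illustration (the field names come from this diff; the concrete values are hypothetical):

# Hypothetical values showing the expected string formats.
data_kwargs = "trust_remote_code=true"                                        # forwarded to load_dataset
peft_init_kwargs = "r=8,lora_alpha=16,target_modules=q_proj|k_proj|v_proj"    # fresh LoRA adapter
model_kwargs = "attn_implementation=eager"                                    # forwarded to from_pretrained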
6 changes: 3 additions & 3 deletions bergson/data.py
@@ -26,7 +26,7 @@
from .config import DataConfig
from .utils.utils import (
assert_type,
simple_parse_args_string,
simple_parse_kwargs_string,
)


@@ -425,7 +425,7 @@ def load_data_string(
data_str: str,
split: str = "train",
subset: str | None = None,
data_args: str = "",
data_kwargs: str = "",
) -> Dataset:
"""Load a dataset from a string identifier or path."""
if data_str.endswith(".csv"):
@@ -438,7 +438,7 @@
ds = ds[split]
else:
try:
kwargs = simple_parse_args_string(data_args)
kwargs = simple_parse_kwargs_string(data_kwargs)
ds = load_dataset(data_str, subset, split=split, **kwargs)

if isinstance(ds, DatasetDict) or isinstance(ds, IterableDatasetDict):
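With the rename, `load_data_string` takes the raw `data_kwargs` string and parses it just before calling `load_dataset`. A minimal sketch of a call, assuming a hypothetical dataset id and kwargs string:

from bergson.data import load_data_string

# Hypothetical dataset id and kwargs; integer and boolean values are coerced
# by handle_arg_string before being forwarded to load_dataset.
ds = load_data_string(
    "allenai/c4",
    split="train",
    subset="en",
    data_kwargs="num_proc=4",
)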
2 changes: 1 addition & 1 deletion bergson/query/query_index.py
@@ -82,7 +82,7 @@ def query(
index_cfg.data.dataset,
index_cfg.data.split,
index_cfg.data.subset,
index_cfg.data.data_args,
index_cfg.data.data_kwargs,
)

faiss_cfg = FaissConfig() if query_cfg.faiss else None
59 changes: 39 additions & 20 deletions bergson/utils/utils.py
@@ -55,32 +55,51 @@ def setup_reproducibility():


def handle_arg_string(arg: str):
if arg.lower() == "true":
return True
elif arg.lower() == "false":
return False
elif arg.isnumeric():
# Handle lists
if "|" in arg:
return [handle_arg_string(v) for v in arg.split("|")]

# Handle integers
try:
return int(arg)
except ValueError:
pass

# Handle floats
try:
return float(arg)
except ValueError:
return arg
pass

# Handle booleans
match arg.lower():
case "true":
return True
case "false":
return False
case _:
return arg


def simple_parse_args_string(args_string: str) -> dict[str, Any]:
"""
Parses something like
args1=val1,arg2=val2
into a dictionary.
"""
args_string = args_string.strip()
if not args_string:
return {}
arg_list = [arg for arg in args_string.split(",") if arg]
args_dict = {
kv[0]: handle_arg_string("=".join(kv[1:]))
for kv in [arg.split("=") for arg in arg_list]
}
def simple_parse_kwargs_string(args_string: str) -> dict:
"""Parses something like `args1=val1,arg2=val2` into a dictionary."""
args_dict = {}

for elem in args_string.split(","):
lvalue, sep, rvalue = elem.partition("=")

# Ignore whitespace
lvalue = lvalue.strip()
rvalue = rvalue.strip()

if not (lvalue and sep):
raise ValueError(f"Invalid argument: '{elem}'. Expected format key=value.")

if not lvalue.isidentifier():
raise ValueError(f"Invalid key: '{lvalue}'. Must be a valid identifier.")

args_dict[lvalue] = handle_arg_string(rvalue)

return args_dict


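The new parser is stricter than the old one: keys must be valid Python identifiers and every element needs an `=`. A short sketch of the value coercion done by `handle_arg_string`, with results inferred from the code above:

from bergson.utils.utils import simple_parse_kwargs_string

# ints, floats, and booleans are coerced; "|" splits list values;
# anything else stays a string.
parsed = simple_parse_kwargs_string(
    "r=8,lora_alpha=16.0,use_rslora=true,target_modules=q_proj|v_proj"
)
# parsed == {"r": 8, "lora_alpha": 16.0, "use_rslora": True,
#            "target_modules": ["q_proj", "v_proj"]}

simple_parse_kwargs_string("r8")        # ValueError: missing "="
simple_parse_kwargs_string("lr-max=1")  # ValueError: key is not an identifier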
110 changes: 66 additions & 44 deletions bergson/utils/worker_utils.py
@@ -9,7 +9,14 @@
Dataset,
IterableDataset,
)
from peft import PeftConfig, PeftModel, get_peft_model_state_dict
from peft import (
PeftConfig,
PeftModel,
PeftType,
get_peft_model,
get_peft_model_state_dict,
)
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from torch.distributed.fsdp import fully_shard
from transformers import (
AutoConfig,
@@ -28,6 +35,7 @@
from bergson.format import apply_format
from bergson.gradients import GradientProcessor, Normalizer
from bergson.utils import assert_type, get_layer_list, weighted_causal_lm_ce
from bergson.utils.utils import simple_parse_kwargs_string

BIG_NUM = np.iinfo(np.int64).max

@@ -107,6 +115,24 @@ def apply_force_math_sdp(cfg: ModelConfig) -> None:
print("force_math_sdp: disabled flash and memory-efficient SDPA backends")


def extract_peft_target_modules(model) -> set[str]:
"""Extract adapter module names from a PeftModel."""
target_modules: set[str] = set()
peft_state_dict = get_peft_model_state_dict(model=model)
for adapter in model.peft_config.keys(): # type: ignore
for name in list(peft_state_dict.keys()):
prefix = name.removesuffix(".weight")
processed_name = f"{prefix}.{adapter}".removeprefix("base_model.")
try:
model.get_submodule(processed_name)
target_modules.add(processed_name)
except AttributeError:
print(
f"Adapter parameter '{processed_name}'" " not found in the model."
)
return target_modules


def setup_model_and_peft(
cfg: ModelConfig,
device_map_auto: bool = False,
@@ -115,6 +141,7 @@ def setup_model_and_peft(
) -> tuple[PreTrainedModel | PeftModel, set | None]:
"""Handle model loading, quantization, FSDP, and PEFT detection"""
apply_force_math_sdp(cfg)

local_rank = cfg.distributed.local_rank

match cfg.precision:
@@ -150,57 +177,52 @@
bnb_4bit_use_double_quant=True,
)

# Try to detect PEFT model
# Determine base model path and whether we're loading a pretrained adapter
try:
peft_config = PeftConfig.from_pretrained(cfg.model)
pretrained_peft_config = PeftConfig.from_pretrained(cfg.model)
except ValueError:
peft_config = None
pretrained_peft_config = None

if peft_config is None:
# Load regular model
model = AutoModelForCausalLM.from_pretrained(
cfg.model,
device_map=device_map,
quantization_config=quantization_config,
dtype=dtype,
revision=cfg.revision,
**model_kwargs,
)
model.loss_function = weighted_causal_lm_ce
target_modules = None
else:
# Load PEFT model
base_model = AutoModelForCausalLM.from_pretrained(
peft_config.base_model_name_or_path, # type: ignore
device_map=device_map,
quantization_config=quantization_config,
dtype=dtype,
revision=cfg.revision,
**model_kwargs,
)
base_model.loss_function = weighted_causal_lm_ce
assert not (cfg.peft_init_kwargs and pretrained_peft_config), (
f"peft_init_args is set but '{cfg.model}' is already a" " PEFT adapter."
)

base_model_path = (
pretrained_peft_config.base_model_name_or_path # type: ignore
if pretrained_peft_config
else cfg.model
)
assert base_model_path is not None

model_kwargs.update(simple_parse_kwargs_string(cfg.model_kwargs))

model = AutoModelForCausalLM.from_pretrained(
base_model_path,
device_map=device_map,
quantization_config=quantization_config,
dtype=dtype,
revision=cfg.revision,
**model_kwargs,
)
model.loss_function = weighted_causal_lm_ce
target_modules = None

if cfg.peft_init_kwargs:
# Initialize a fresh PEFT adapter
peft_kwargs = simple_parse_kwargs_string(cfg.peft_init_kwargs)
peft_type = PeftType(peft_kwargs.pop("peft_type", "LORA"))
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
model = get_peft_model(model, peft_config_cls(**peft_kwargs))
target_modules = extract_peft_target_modules(model)
elif pretrained_peft_config:
# Load pretrained PEFT adapter
model = PeftModel.from_pretrained(
base_model,
model,
cfg.model,
device_map=device_map,
autocast_adapter_dtype=False,
)

# Extract target modules
target_modules = set()
peft_state_dict = get_peft_model_state_dict(model=model)
for adapter in model.peft_config.keys():
for name in list(peft_state_dict.keys()):
prefix = name.removesuffix(".weight")
processed_name = f"{prefix}.{adapter}".removeprefix("base_model.")
try:
model.get_submodule(processed_name)
target_modules.add(processed_name)
except AttributeError:
print(
f"Adapter parameter '{processed_name}' not found in the model."
)
target_modules = extract_peft_target_modules(model) # type: ignore

# Configure gradients
model.requires_grad_(False)
@@ -317,7 +339,7 @@ def setup_data_pipeline(
data_cfg = data_cfg or cfg.data

ds = load_data_string(
data_cfg.dataset, data_cfg.split, data_cfg.subset, data_cfg.data_args
data_cfg.dataset, data_cfg.split, data_cfg.subset, data_cfg.data_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer or cfg.model)
max_model_length = max_tokens_for_model(tokenizer, cfg.model, cfg.revision)
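Taken together, the fresh-adapter branch in `setup_model_and_peft` boils down to a handful of calls. A condensed sketch of that path, using a hypothetical base model and kwargs string (everything else mirrors the diff above):

from peft import PeftType, get_peft_model
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from transformers import AutoModelForCausalLM

from bergson.utils.utils import simple_parse_kwargs_string

# Hypothetical base model and peft_init_kwargs string.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
peft_kwargs = simple_parse_kwargs_string("r=8,lora_alpha=16,target_modules=q_proj|v_proj")

peft_type = PeftType(peft_kwargs.pop("peft_type", "LORA"))  # LORA unless overridden
config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]         # e.g. LoraConfig
model = get_peft_model(model, config_cls(**peft_kwargs))

# extract_peft_target_modules then resolves the adapter parameters into
# submodule paths so gradient collection knows which modules to hook.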
2 changes: 1 addition & 1 deletion examples/trainer_grad_collection.py
@@ -120,7 +120,7 @@ def main(args: IndexConfig):
dataset = load_data_string(
args.data.dataset,
args.data.split,
data_args=args.data.data_args,
data_kwargs=args.data.data_kwargs,
)
dataset = dataset.map(
tokenize,