Skip to content

Commit 86b1053

Browse files
committed
Fix merge LoRA adapters
1 parent 48dcbba commit 86b1053

2 files changed

Lines changed: 7 additions & 10 deletions

File tree

fine-tune.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def main():
4949
parser.add_argument("--lora_alpha", default=1.0, type=float)
5050
parser.add_argument("--activation_checkpointing", action="store_true")
5151
parser.add_argument("--eval_interval", default=1, type=int)
52-
parser.add_argument("--eval_ratio", default=0.1, type=float)
52+
parser.add_argument("--num_eval_samples", default=2048, type=int)
5353
parser.add_argument("--checkpoint_interval", default=1, type=int)
5454
parser.add_argument(
5555
"--checkpoint_path", default="./checkpoints/checkpoint.pt", type=str
@@ -89,11 +89,6 @@ def main():
8989
f"Eval interval must be greater than 0, {args.eval_interval} given."
9090
)
9191

92-
if args.eval_ratio < 0 or args.eval_ratio > 1:
93-
raise ValueError(
94-
f"Eval ratio must be between 0 and 1, {args.eval_ratio} given."
95-
)
96-
9792
if args.checkpoint_interval < 1:
9893
raise ValueError(
9994
f"Checkpoint interval must be greater than 0, {args.checkpoint_interval} given."
@@ -151,9 +146,9 @@ def main():
151146

152147
dataset = ConcatDataset(datasets)
153148

154-
training_ratio = 1.0 - args.eval_ratio
149+
n_train_samples = len(dataset) - args.num_eval_samples
155150

156-
training, testing = random_split(dataset, (training_ratio, args.eval_ratio))
151+
training, testing = random_split(dataset, [n_train_samples, args.num_eval_samples])
157152

158153
right_pad_collate = partial(
159154
pad_collate,
@@ -198,7 +193,7 @@ def main():
198193

199194
model.add_lora_parameters(**lora_args)
200195

201-
print("LoRA parameters added")
196+
print("Added LoRA adapters")
202197

203198
print(f"Model has {model.num_trainable_params:,} trainable parameters")
204199

src/nope_gpt/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def merge_lora_parameters(self) -> None:
126126
if not hasattr(module, "parametrizations"):
127127
continue
128128

129-
for name in module.parametrizations.keys():
129+
lora_params = [name for name in module.parametrizations.keys()]
130+
131+
for name in lora_params:
130132
remove_parametrizations(module, name)
131133

132134
def forward(self, x: Tensor) -> Tensor:

0 commit comments

Comments
 (0)