diff --git a/applications/ColossalChat/examples/training_scripts/lora_finetune.py b/applications/ColossalChat/examples/training_scripts/lora_finetune.py
index 851ad6a2d9e3..f52615496a38 100644
--- a/applications/ColossalChat/examples/training_scripts/lora_finetune.py
+++ b/applications/ColossalChat/examples/training_scripts/lora_finetune.py
@@ -43,7 +43,36 @@ def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
     return loss / dist.get_world_size(group)
 
 
+def get_second_latest_subfolder_and_optimizer_file(folder_path):
+    os.makedirs(folder_path, exist_ok=True)
+
+    # Collect all subfolders whose names start with "lora"
+    subfolders = [
+        f for f in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, f)) and f.startswith("lora")
+    ]
+
+    # Need at least two checkpoints to pick the second-newest one
+    if len(subfolders) < 2:
+        return None, None
+
+    # Sort by last modification time, newest first
+    subfolders.sort(key=lambda x: os.path.getmtime(os.path.join(folder_path, x)), reverse=True)
+
+    # Path of the second-newest LoRA subfolder
+    second_latest_subfolder = subfolders[1]
+    second_latest_lora_subfolder_path = os.path.join(folder_path, second_latest_subfolder)
+
+    # Derive the path of the matching optimizer checkpoint ("optimizer_*.pth")
+    second_latest_optimizer_subfolder_path = os.path.join(
+        folder_path, second_latest_subfolder.replace("lora_", "optimizer_") + ".pth"
+    )
+
+    return second_latest_lora_subfolder_path, second_latest_optimizer_subfolder_path
+
+
 def train(args) -> None:
+
     # ==============================
     # Initialize Distributed Training
     # ==============================
@@ -208,7 +237,12 @@ def is_master():
         )
     else:
         lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
-        model = booster.enable_lora(model, lora_config=lora_config)
+        if args.lora_path:
+            coordinator.print_on_master(f"Loading LoRA weights from: {args.lora_path}")
+            model = booster.enable_lora(model, pretrained_dir=args.lora_path)
+        else:
+            model = booster.enable_lora(model, lora_config=lora_config)
+
     model.enable_input_require_grads()  # this is essential, otherwise the grad checkpoint will not work.
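The helper returns the second-newest checkpoint pair rather than the newest, presumably so that a directory still being written when a run dies is never picked up. A minimal usage sketch, not part of the patch; the folder names are illustrative and mirror the lora_epoch{E}_step{S} / optimizer_epoch{E}_step{S}.pth naming used by the save-interval logic further down:

    import os
    import tempfile
    import time

    # assuming the helper is importable from the script (argparse runs under __main__)
    from lora_finetune import get_second_latest_subfolder_and_optimizer_file

    save_dir = tempfile.mkdtemp()
    for name in ("lora_epoch0_step99", "lora_epoch0_step199"):
        os.makedirs(os.path.join(save_dir, name))
        time.sleep(0.01)  # distinct mtimes, so the newest-first sort is deterministic

    lora_path, optim_path = get_second_latest_subfolder_and_optimizer_file(save_dir)
    # lora_path  -> <save_dir>/lora_epoch0_step99 (the second-newest folder)
    # optim_path -> <save_dir>/optimizer_epoch0_step99.pth; the helper only derives
    # this path, so the caller still has to check os.path.exists before loading.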
    model.train()
@@ -257,7 +291,7 @@ def is_master():
         )
 
     torch.set_default_dtype(torch.float)
-    booster.load_model(model, args.pretrained)
+    booster.load_model(model, args.pretrained, strict=False)
 
     coordinator.print_on_master(
         f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -269,6 +303,26 @@ def is_master():
     start_epoch = 0
     start_step = 0
 
+    if not (args.lora_path or args.optimizer_path):
+        args.lora_path, args.optimizer_path = get_second_latest_subfolder_and_optimizer_file(args.save_dir)
+        coordinator.print_on_master(f"LoRA path: {args.lora_path}")
+        coordinator.print_on_master(f"Optimizer path: {args.optimizer_path}")
+
+    # Load checkpoint if available
+    if args.optimizer_path:
+        checkpoint_path = args.optimizer_path
+        if os.path.exists(checkpoint_path):
+            checkpoint = torch.load(checkpoint_path, map_location=get_current_device())
+            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
+            start_epoch = checkpoint["epoch"]
+            start_step = checkpoint["step"]
+            coordinator.print_on_master(f"Resuming optimizer from epoch {start_epoch}, step {start_step}")
+        else:
+            coordinator.print_on_master("Optimizer checkpoint not found, starting training from scratch")
+    else:
+        coordinator.print_on_master("No optimizer checkpoint given, starting optimizer from scratch")
+
     num_steps_per_epoch = len(dataloader) // args.accumulation_steps
 
     for epoch in range(start_epoch, args.num_epochs):
@@ -316,10 +370,12 @@ def is_master():
             dataloader,
             desc=f"Epoch {epoch}",
             disable=not is_master(),
-            initial=start_step // args.accumulation_steps,
+            initial=start_step // args.accumulation_steps,  # start the progress bar at the resume position
         )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
+            if step > num_steps_per_epoch:
+                break
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
 
             batch_output = model(**batch)
@@ -348,9 +404,25 @@ def is_master():
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
                 total_loss.fill_(0.0)
 
+                if (step + 1) % args.save_interval == 0:
+                    if args.lora_rank > 0:
+                        booster.save_lora_as_pretrained(
+                            model, os.path.join(args.save_dir, f"lora_epoch{epoch}_step{step}")
+                        )
+                    checkpoint = {
+                        "epoch": epoch,
+                        "step": step + 1,
+                        "optimizer_state_dict": optimizer.state_dict(),
+                        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
+                    }
+                    torch.save(checkpoint, os.path.join(args.save_dir, f"optimizer_epoch{epoch}_step{step}.pth"))
+                    coordinator.print_on_master(f"Saved checkpoint at epoch {epoch}, step {step + 1}")
+
+            start_step = 0
+
             # Delete cache.
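The periodic checkpoint written above is a plain torch.save dict, so it can be sanity-checked offline before resuming. A minimal sketch, with an assumed illustrative path under --save_dir:

    import torch

    ckpt = torch.load("checkpoints/optimizer_epoch0_step99.pth", map_location="cpu")
    print(sorted(ckpt.keys()))  # ['epoch', 'lr_scheduler_state_dict', 'optimizer_state_dict', 'step']
    print(ckpt["epoch"], ckpt["step"])  # the resume position consumed by the loading code above

Note that the dict stores step + 1 (the next step to run) while both file names use step; the lora_ to optimizer_ name substitution in the helper relies on those two names matching.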
            # del batch, batch_labels, batch_output, loss
             accelerator.empty_cache()
 
@@ -373,10 +445,16 @@ def is_master():
         "-m",
         "--pretrained",
         type=str,
-        required=True,
+        default=None,
         help="Address of the pre-trained model",
     )
-    parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw Jonl dataset for training.")
+    parser.add_argument(
+        "-d",
+        "--dataset",
+        type=str,
+        default=None,
+        help="Raw JSONL dataset for training.",
+    )
     parser.add_argument(
         "-p",
         "--plugin",
@@ -385,15 +463,69 @@ def is_master():
         choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
         help="Choose which plugin to use",
     )
-    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
-    parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
-    parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default=None,
+        help="Checkpoint directory",
+    )
+    parser.add_argument("--save_interval", type=int, default=100, help="Save interval")
+    parser.add_argument(
+        "--lora_path",
+        type=str,
+        default=None,
+        help="LoRA checkpoint directory",
+    )
+    parser.add_argument(
+        "--optimizer_path",
+        type=str,
+        default=None,
+        help="Optimizer checkpoint directory",
+    )
+    parser.add_argument(
+        "--tensorboard_dir",
+        type=str,
+        default="logs",
+        help="Tensorboard directory",
+    )
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        default="training_config.json",
+        help="Config file",
+    )
     # Training parameters
-    parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
-    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
-    parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
-    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
-    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
+    parser.add_argument(
+        "-n",
+        "--num_epochs",
+        type=int,
+        default=2,
+        help="Number of training epochs",
+    )
+    parser.add_argument(
+        "--accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of accumulation steps",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=2,
+        help="Global batch size of each process",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=2e-5,
+        help="Learning rate",
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=256,
+        help="Model max length",
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -401,14 +533,29 @@ def is_master():
         choices=["fp16", "bf16"],
         help="Mixed precision",
     )
-    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
-    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
-    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+    parser.add_argument(
+        "--grad_clip",
+        type=float,
+        default=1.0,
+        help="Gradient clipping value",
+    )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.1,
+        help="Weight decay",
+    )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=8,
+        help="Warmup steps",
+    )
     parser.add_argument(
         "-g",
         "--use_grad_checkpoint",
         action="store_true",
-        default=False,
+        default=True,
         help="Use gradient checkpointing",
     )
     parser.add_argument(
@@ -420,11 +567,37 @@ def is_master():
     )
 
     # Additional arguments for 3d plugin.
- parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.") - parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.") - parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.") - parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.") - parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2]) + parser.add_argument( + "--tp", + type=int, + default=1, + help="TP size, used for 3d plugin.", + ) + parser.add_argument( + "--pp", + type=int, + default=1, + help="PP size, used for 3d plugin.", + ) + parser.add_argument( + "--sp", + type=int, + default=1, + help="SP size, used for 3d plugin.", + ) + parser.add_argument( + "--ep", + type=int, + default=1, + help="EP size, used for moe plugin.", + ) + parser.add_argument( + "--zero_stage", + type=int, + default=1, + help="Zero stage, used for 3d plugin.", + choices=[0, 1, 2], + ) parser.add_argument( "--sp_mode", type=str, @@ -439,13 +612,29 @@ def is_master(): help="Whether to enable SP, used for 3d plugin.", ) parser.add_argument( - "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin." + "--zero_cpu_offload", + default=False, + action="store_true", + help="Whether to use offloading, used for 3d plugin.", + ) + parser.add_argument( + "--microbatch_size", + type=int, + default=1, + help="Batch size for each process in PP, used for 3d plugin.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=8, + help="lora rank when using lora to train.", ) parser.add_argument( - "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin." + "--lora_alpha", + type=int, + default=16, + help="lora alpha when using lora to train.", ) - parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.") - parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.") args = parser.parse_args()