@@ -0,0 +1,72 @@
### data
train_dataset_type: messages
eval_dataset_type: messages
train_dataset_path: ./ocr_vl_sft-train_Bengali.jsonl
train_dataset_prob: "1.0"
eval_dataset_path: ./ocr_vl_sft-test_Bengali.jsonl
eval_dataset_prob: "1.0"
max_seq_len: 8192
padding_free: False
packing: False
truncate_packing: False
dataset_type: map
dataloader_num_workers: 8
mix_strategy: concat
template_backend: custom
template: deepseek_ocr2

### model
model_name_or_path: deepseek-ai/DeepSeek-OCR-2
_attn_implementation: flashmask
copy_custom_file_list: "configuration_deepseek_v2.py conversation.py deepencoderv2.py modeling_deepseekocr2.py modeling_deepseekv2.py"

### finetuning
# base
stage: VL-SFT
fine_tuning: full
seed: 42
do_train: true
do_eval: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
num_train_epochs: 2
max_steps: -1
Collaborator: Please add a document that briefly introduces this recipe and provides a way to download the dataset ./ocr_vl_sft-train_Bengali.jsonl.

max_estimate_samples: 500
eval_steps: 400
evaluation_strategy: steps
save_steps: 400
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 8
logging_dir: ./Deepseek-OCR2-Bengali/visualdl_logs/
output_dir: ./Deepseek-OCR2-SFT-Bengali
disable_tqdm: true
eval_accumulation_steps: 16

# train
lr_scheduler_type: cosine
warmup_ratio: 0.01
learning_rate: 5.0e-6
min_lr: 5.0e-7

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

# performance
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
sharding: stage1
recompute_granularity: full
recompute_method: uniform
recompute_num_layers: 1
bf16: true
fp16_opt_level: O2
# pre_alloc_memory: 42

# save
unified_checkpoint: False
save_checkpoint_format: "flex_checkpoint"
load_checkpoint_format: "flex_checkpoint"
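Both recipes consume a chat-style "messages" JSONL dataset (train_dataset_type: messages). As a rough sketch only (the exact schema is defined by the PaddleFormers messages loader, and every field name below is an assumption for illustration), one line of ./ocr_vl_sft-train_Bengali.jsonl could look like:

import json

# Hypothetical record layout; align field names with the loader's real schema.
sample = {
    "messages": [
        {"role": "user", "content": "<image>\nTranscribe the text in this image."},
        {"role": "assistant", "content": "<Bengali transcription of the page>"},
    ],
    "images": ["images/page_00001.png"],  # assumed image-path field
}
print(json.dumps(sample, ensure_ascii=False))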
@@ -0,0 +1,75 @@
### data
train_dataset_type: messages
eval_dataset_type: messages
train_dataset_path: ./ocr_vl_sft-train_Bengali.jsonl
train_dataset_prob: "1.0"
eval_dataset_path: ./ocr_vl_sft-test_Bengali.jsonl
eval_dataset_prob: "1.0"
max_seq_len: 8192
padding_free: False
packing: False
truncate_packing: False
dataset_type: map
dataloader_num_workers: 8
mix_strategy: concat
template_backend: custom
template: deepseek_ocr2

### model
model_name_or_path: deepseek-ai/DeepSeek-OCR-2
_attn_implementation: flashmask
lora: true
lora_rank: 8
lora_alpha: 32
copy_custom_file_list: "configuration_deepseek_v2.py conversation.py deepencoderv2.py modeling_deepseekocr2.py modeling_deepseekv2.py"

### finetuning
# base
stage: VL-SFT
fine_tuning: lora
seed: 42
do_train: true
do_eval: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
num_train_epochs: 1
max_steps: -1
max_estimate_samples: 500
eval_steps: 400
evaluation_strategy: steps
save_steps: 400
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 8
logging_dir: ./Deepseek-OCR2-Bengali-lora/visualdl_logs/
output_dir: ./Deepseek-OCR2-SFT-Bengali-lora
disable_tqdm: true
eval_accumulation_steps: 16

# train
lr_scheduler_type: cosine
warmup_ratio: 0.01
learning_rate: 5.0e-4
min_lr: 5.0e-5

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

# performance
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
sharding: stage1
recompute_granularity: full
recompute_method: uniform
recompute_num_layers: 1
bf16: true
fp16_opt_level: O2
# pre_alloc_memory: 45

# save
unified_checkpoint: False
save_checkpoint_format: "flex_checkpoint"
load_checkpoint_format: "flex_checkpoint"
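With lora_rank: 8 and lora_alpha: 32, the effective LoRA scaling factor is alpha / rank = 4. Note that the learning rate is also two orders of magnitude higher than in the full-SFT recipe (5.0e-4 vs 5.0e-6), which is common when only the small adapter matrices are trained. A minimal sketch of the update these two hyperparameters control (illustrative only, not PaddleFormers' implementation):

import paddle

rank, alpha = 8, 32
scaling = alpha / rank  # 4.0

in_dim, out_dim = 1024, 1024
W = paddle.randn([in_dim, out_dim])      # frozen base weight
A = paddle.randn([in_dim, rank]) * 0.01  # trainable down-projection
B = paddle.zeros([rank, out_dim])        # trainable up-projection, zero-init

def lora_linear(x):
    # y = x W + (alpha / rank) * x A B; only A and B receive gradients.
    return x @ W + scaling * ((x @ A) @ B)

y = lora_linear(paddle.randn([2, in_dim]))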
11 changes: 9 additions & 2 deletions paddleformers/cli/train/sft/workflow.py
@@ -31,7 +31,11 @@
     check_data_split,
 )
 from paddleformers.data.indexed_dataset import SFTMMapIndexedDatasetBuilder
-from paddleformers.datasets.collate import collate_fn, mm_collate_fn
+from paddleformers.datasets.collate import (
+    collate_fn,
+    mm_collate_fn,
+    mm_collate_fn_ds_ocr2,
+)
 from paddleformers.datasets.data_utils import estimate_training
 from paddleformers.datasets.loader import create_dataset as create_dataset_sft
 from paddleformers.datasets.loader import create_indexed_dataset
@@ -631,8 +635,11 @@ def fetch_and_serialize(generator, dtype):
     logger.info(f"Setting max_seq_len to {max_seq_len} using PaddleFormers Model.")
     if data_args.dataset_type != "pretrain":
         if "VL" in model_args.stage:
+            cur_mm_collate_fn = mm_collate_fn
+            if model_config.model_type == "deepseek_ocr2":
+                cur_mm_collate_fn = mm_collate_fn_ds_ocr2
             data_collator = partial(
-                mm_collate_fn,
+                cur_mm_collate_fn,
                 template=template_instance,
                 processor=processor,
                 tokenizer=tokenizer,
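The dispatch above keeps the default mm_collate_fn for all other VL models and selects the dedicated mm_collate_fn_ds_ocr2 collator only when model_config.model_type is "deepseek_ocr2", so existing recipes are unaffected.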
27 changes: 27 additions & 0 deletions paddleformers/cli/utils/llm_utils.py
@@ -420,6 +420,33 @@ def get_lora_target_modules(model):
"model.visual.blocks.*mlp.up_proj.*",
"model.visual.blocks.*mlp.down_proj.*",
]
elif model.config.model_type == "deepseek_ocr2":
target_modules = [
# Language Model (DeepseekV3)
".*model.*q_proj.*",
".*model.*q_a_proj.*",
".*model.*q_b_proj.*",
".*model.*kv_a_proj_with_mqa.*",
".*model.*kv_b_proj.*",
".*model.*k_proj.*",
".*model.*v_proj.*",
".*model.*o_proj.*",
".*model.*mlp.gate_proj.*",
".*model.*mlp.up_proj.*",
".*model.*mlp.down_proj.*",
# SAM Vision Encoder
"sam_model.*attn.qkv.*",
"sam_model.*attn.proj.*",
"sam_model.*mlp.lin1.*",
"sam_model.*mlp.lin2.*",
# Qwen2 Encoder-as-Decoder
"qwen2_model.*self_attn.qkv_proj.*",
Collaborator: Do we need to train the ViT when fine-tuning with LoRA?
"qwen2_model.*self_attn.o_proj.*",
"qwen2_model.*mlp.up_gate_proj.*",
"qwen2_model.*mlp.down_proj.*",
# Projector
"projector.*",
]
else:
raise ValueError(f"Unknown base_model_prefix: {model.config.model_type}.")
return target_modules
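Each entry above is a regular expression matched against parameter names, covering the DeepseekV3 language model's attention and MLP projections, the SAM vision encoder, the Qwen2 encoder blocks, and the projector. A small sketch of how such patterns select parameters (illustrative; the parameter names below are hypothetical, and PaddleFormers' actual matching logic is not shown here):

import re

names = [
    "model.layers.0.self_attn.q_a_proj.weight",
    "sam_model.blocks.3.attn.qkv.weight",
    "qwen2_model.layers.1.mlp.down_proj.weight",
    "model.layers.0.input_layernorm.weight",  # should not be selected
]
patterns = [".*model.*q_a_proj.*", "sam_model.*attn.qkv.*", "qwen2_model.*mlp.down_proj.*"]
selected = [n for n in names if any(re.match(p, n) for p in patterns)]
print(selected)  # everything except the layernorm weight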
134 changes: 134 additions & 0 deletions paddleformers/datasets/collate.py
@@ -740,6 +740,140 @@ def mm_collate_fn(
return input_dict


def mm_collate_fn_ds_ocr2(
batch: List[List[Sequence]],
template,
processor,
tokenizer,
training_args,
model_args,
max_seq_len: int,
padding_free: bool,
model,
):
"""Convert batch of sequences into training tensors.

Args:
batch (List[List[Sequence]]): Batch of input sequences
tokenizer: Tokenizer for text conversion
model_args: Model configuration parameters
max_seq_len (int): Maximum sequence length for padding
padding_free (bool): Whether to flatten the data within a batch to avoid padding

Returns:
dict: Dictionary containing:
- input_ids: Padded token IDs
- labels: Shifted labels for prediction
- loss_mask: Mask for computing loss
"""

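    # Unwrap the LoRA wrapper to reach the base model.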
if isinstance(model, LoRAModel):
model = model.model.base_model

input_keys = ["input_ids", "labels", "position_ids", "images_spatial_crop", "images_seq_mask"]

if training_args.num_nextn_predict_layers > 0:
input_keys.append("nbatch_pack_offset")
if model_args.use_attn_mask_startend_row_indices:
input_keys.append("attn_mask_startend_row_indices")
else:
input_keys.append("attention_mask")

return_list = []
return_images_list = []
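    # Padding-free: pack the whole batch into a single flattened sequence and
    # size max_seq_len to the total token count across the batch.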
if padding_free:
batch = [sum(batch, [])]
max_seq_len = sum(len(item.token_ids) for sequence in batch for item in sequence)
if not max_seq_len:
max_seq_len = max(sum(len(item.token_ids) for item in sequence) for sequence in batch)
max_seq_len = calc_padding_size(max_seq_len, training_args)
if training_args.num_nextn_predict_layers > 0:
max_seq_len += training_args.num_nextn_predict_layers

for batch_sequence in batch:
original_token_ids = []
original_position_ids = []
images_list = []
images_spatial_crop_list = []
images_seq_mask_list = []
for seq in batch_sequence:
original_token_ids.append(seq.token_ids)
original_position_ids.append(seq.position_ids)
mm_inputs = seq.mm_inputs

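            # Vision inputs: the local crops are stored before the global
            # image for each sequence.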
cur_image = mm_inputs["images"]
cur_images_crop = mm_inputs["images_crop"]
images_list.extend((cur_images_crop, cur_image))
images_spatial_crop_list.extend(mm_inputs["images_spatial_crop"])
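            # Boolean mask that is True at image placeholder token positions.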
images_seq_mask = (
paddle.to_tensor(seq.token_ids)
== tokenizer.encode(template.mm_plugin.image_token, add_special_tokens=False)[0]
)
images_seq_mask_list.append(images_seq_mask)

if original_position_ids:
position_ids = [np.concatenate(original_position_ids)]
padded_position_ids = pad_batch_data(position_ids, pad_idx=0, max_seq_len=max_seq_len)
else:
padded_position_ids = []

token_ids = [np.concatenate(original_token_ids)]
labels = [np.concatenate([seq.labels for seq in batch_sequence])]
# padding
padded_token_ids = pad_batch_data(token_ids, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
padded_labels = pad_batch_data(labels, pad_idx=-100, max_seq_len=max_seq_len)
return_list.append(
[
padded_token_ids,
padded_labels,
]
)

images_seq_mask_list = [np.concatenate(images_seq_mask_list)]
padded_images_seq_mask = pad_batch_data(images_seq_mask_list, pad_idx=False, max_seq_len=max_seq_len)
return_list[-1].extend(
[
padded_position_ids,
images_spatial_crop_list,
padded_images_seq_mask,
]
)
return_images_list.append(images_list)

if training_args.num_nextn_predict_layers > 0:
        # Mark the final token of each packed sub-sequence with a 1.
batch_sequence_len = [len(sequence) for sequence in original_token_ids]
nbatch_pack_offset = [0] * sum(batch_sequence_len)
prefix_sum = 0
for sequence_len in batch_sequence_len[:-1]:
prefix_sum += sequence_len
nbatch_pack_offset[prefix_sum - 1] = 1
padded_nbatch_pack_offset = pad_batch_data([nbatch_pack_offset], pad_idx=0, max_seq_len=max_seq_len)
return_list[-1].append(padded_nbatch_pack_offset)

if model_args.use_attn_mask_startend_row_indices:
return_list[-1].append(
gen_attn_mask_startend_row_indices(original_token_ids, max_seq_len, model_args.use_global_causal_attn)
)
else:
return_list[-1].append(
gen_self_attn_mask(original_token_ids, max_seq_len, model_args.use_global_causal_attn)
)

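    # Transpose per-sample rows into per-field columns, then concatenate each
    # field across the batch.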
transposed_list = list(zip(*return_list))
input_dict = {}
for key, tensors in zip(input_keys, transposed_list):
filtered_tensors = [paddle.to_tensor(x) for x in tensors if x is not None and len(x) > 0]
if filtered_tensors:
value = paddle.concat(filtered_tensors, axis=0)
else:
value = paddle.to_tensor([])
if len(value) > 0:
input_dict[key] = value
input_dict["images"] = return_images_list
return input_dict


def pad_batch_data(
insts,
pad_idx=0,