Why does PaddleOCR-VL inference become significantly slower after full SFT training? #1427

@HWChatGPT4

Description


The training dataset is a self-made table dataset in ostl format.
Training followed https://github.com/PaddlePaddle/ERNIE/blob/release/v1.5/docs/paddleocr_vl_sft_zh.md
Deployment of the trained model followed the serving deployment method in https://www.paddleocr.ai/latest/version3.x/pipeline_usage/PaddleOCR-VL.html#4
Inference is done by calling the deployed endpoint (http://localhost:8080/layout-parsing).
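A minimal sketch of how a single request against the deployed endpoint can be timed (the base64 `file`/`fileType` payload shape is my assumption based on the usual PaddleX serving convention; adjust it to the actual API if it differs):

```python
import base64
import json
import time
import urllib.request

SERVER_URL = "http://localhost:8080/layout-parsing"

def build_payload(image_bytes: bytes) -> dict:
    # Assumed payload shape: the file content is sent as a base64 string,
    # and fileType 1 marks it as an image (PaddleX serving convention).
    return {"file": base64.b64encode(image_bytes).decode("ascii"), "fileType": 1}

def time_request(image_path: str) -> float:
    """Return the round-trip latency of one layout-parsing request, in seconds."""
    with open(image_path, "rb") as f:
        data = json.dumps(build_payload(f.read())).encode("utf-8")
    req = urllib.request.Request(
        SERVER_URL, data=data, headers={"Content-Type": "application/json"}
    )
    start = time.perf_counter()
    with urllib.request.urlopen(req, timeout=120) as resp:
        resp.read()
    return time.perf_counter() - start
```

Averaging `time_request(...)` over several calls is how the before/after numbers below were compared.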

Inference time before training

(screenshot: about 5 s per request)

Inference time after training

(screenshot: about 9-10 s per request)

Parameter settings
non-default args: {'api_server_count': 4, 'host': '0.0.0.0', 'port': 8080, 'chat_template': '/usr/local/lib/python3.10/site-packages/paddlex/inference/genai/chat_templates/PaddleOCR-VL-0.9B.jinja', 'model': '/home/paddleocr/paddleocrVL_sft_v1', 'trust_remote_code': True, 'max_model_len': 8192, 'served_model_name': ['PaddleOCR-VL-0.9B'], 'gpu_memory_utilization': 0.7, 'max_num_batched_tokens': 131072, 'max_num_seqs': 128}

I have tried adjusting max_model_len, max_num_seqs, gpu_memory_utilization, and max_num_batched_tokens, but the time stays around 9-10 s, far from the ~5 s before training.

Inference GPU
RTX 3080 Ti

Training GPU
A100-80GB

Training config file

train_dataset_type: "erniekit"
eval_dataset_type: "erniekit"
train_dataset_path: "./0112train/merged_output.jsonl"
train_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
use_pic_id: False
sft_replace_ids: True
sft_image_normalize: True
sft_image_rescale: True
image_dtype: "float32"

model_name_or_path: ./PaddleOCR-VL

fine_tuning: Full

multimodal: True
use_flash_attention: True
use_sparse_flash_attn: True

stage: OCR-VL-SFT
seed: 23
do_train: True
distributed_dataloader: False
dataloader_num_workers: 8
prefetch_factor: 10

batch_size: 8
packing_size: 1
gradient_accumulation_steps: 8

packing: True
padding: False
num_train_epochs: 2
max_steps: 100
save_steps: 20
save_total_limit: 5
save_strategy: steps
logging_steps: 1
release_grads: True

logging_dir: ./0112PaddleOCR-VL-SFT-table/tensorboard_logs/
output_dir: ./0112PaddleOCR-VL-SFT-table
disable_tqdm: True

warmup_steps: 10
learning_rate: 5.0e-6
lr_scheduler_type: cosine
min_lr: 5.0e-7
layerwise_lr_decay_bound: 1.0
from_scratch: 0

weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

tensor_parallel_degree: 1

pipeline_parallel_degree: 1
sharding_parallel_degree: 1

sharding: stage1

sequence_parallel: False
pipeline_parallel_config: enable_delay_scale_loss enable_release_grads disable_partial_send_recv
recompute: True
recompute_granularity: "full"
recompute_use_reentrant: True
compute_type: bf16
fp16_opt_level: O2
disable_ckpt_quant: True

amp_custom_white_list:
  - lookup_table
  - lookup_table_v2
  - flash_attn
  - matmul
  - matmul_v2
  - fused_gemm_epilogue
amp_custom_black_list:
  - reduce_sum
  - softmax_with_cross_entropy
  - c_softmax_with_cross_entropy
  - elementwise_div
  - sin
  - cos

unified_checkpoint: True
convert_from_hf: True
save_to_hf: True
