1515The main entry point to run the PPO algorithm
1616"""
1717
18+ from copy import deepcopy
1819from typing import Literal , Optional , Union
1920
2021import numpy as np
2122import psutil
2223import torch
2324import torch .distributed as dist
24- from copy import deepcopy
2525from accelerate import init_empty_weights
2626from codetiming import Timer
2727from torch .distributed .device_mesh import init_device_mesh
4242from ..single_controller .base import Worker
4343from ..single_controller .base .decorator import Dispatch , register
4444from ..utils .checkpoint .fsdp_checkpoint_manager import FSDPCheckpointManager
45+ from ..utils .dataset import process_image
4546from ..utils .flops_counter import FlopsCounter
4647from ..utils .fsdp_utils import (
4748 get_fsdp_wrap_policy ,
5152 offload_fsdp_model ,
5253 offload_fsdp_optimizer ,
5354)
54- from ..utils .dataset import process_image
5555from ..utils .model_utils import print_gpu_memory_usage , print_model_size
5656from ..utils .tokenizer import get_processor , get_tokenizer
5757from ..utils .torch_dtypes import PrecisionType
@@ -436,10 +436,9 @@ def preprocess_multi_modal_data(self, data: DataProto):
436436 processed_images = []
437437 for multi_modal_data in multi_modal_data_copy :
438438 processed_per_query_images = []
439- for image in multi_modal_data ['image' ]:
440- processed_per_query_images .append (
441- process_image (image , min_pixels = min_pixels , max_pixels = max_pixels )
442- )
439+ for image in multi_modal_data ["image" ]:
440+ processed_per_query_images .append (process_image (image , min_pixels = min_pixels , max_pixels = max_pixels ))
441+
443442 processed_images .append (processed_per_query_images )
444443
445444 # Note: Using the alternative (commented) code below to process images can lead to subtle resize issues:
@@ -454,17 +453,20 @@ def preprocess_multi_modal_data(self, data: DataProto):
454453 # for j, image in enumerate(per_query_images):
455454 # images[i][j] = process_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
456455
457- multi_modal_inputs = np .array ([
458- dict (self .processor .image_processor (images = per_query_images , videos = None ))
459- for per_query_images in processed_images
460- ], dtype = object )
456+ multi_modal_inputs = np .array (
457+ [
458+ dict (self .processor .image_processor (images = per_query_images , videos = None ))
459+ for per_query_images in processed_images
460+ ],
461+ dtype = object ,
462+ )
461463 data .non_tensor_batch ["multi_modal_inputs" ] = multi_modal_inputs
462464
463465 @register (dispatch_mode = Dispatch .DP_COMPUTE_PROTO )
464466 def update_actor (self , data : DataProto ):
465467 assert self ._is_actor
466468 if "multi_modal_inputs" in self ._cache :
467- data .non_tensor_batch [' multi_modal_inputs' ] = deepcopy (self ._cache [' multi_modal_inputs' ])
469+ data .non_tensor_batch [" multi_modal_inputs" ] = deepcopy (self ._cache [" multi_modal_inputs" ])
468470 elif "multi_modal_data" in data .non_tensor_batch :
469471 self .preprocess_multi_modal_data (data )
470472
@@ -545,12 +547,14 @@ def generate_sequences(self, prompts: DataProto):
545547 cached_multi_modal_data = None
546548 if "multi_modal_data" in prompts .non_tensor_batch :
547549 cached_multi_modal_data = deepcopy (prompts .non_tensor_batch ["multi_modal_data" ])
548- min_pixels = prompts .meta_info [' min_pixels' ]
549- max_pixels = prompts .meta_info [' max_pixels' ]
550+ min_pixels = prompts .meta_info [" min_pixels" ]
551+ max_pixels = prompts .meta_info [" max_pixels" ]
550552 processed_images = []
551553 for i , multi_modal_data in enumerate (prompts .non_tensor_batch ["multi_modal_data" ]):
552554 for j , image in enumerate (multi_modal_data ["image" ]):
553- multi_modal_data ['image' ][j ] = process_image (image , min_pixels = min_pixels , max_pixels = max_pixels )
555+ multi_modal_data ["image" ][j ] = process_image (
556+ image , min_pixels = min_pixels , max_pixels = max_pixels
557+ )
554558 processed_images .append (multi_modal_data )
555559 prompts .non_tensor_batch ["multi_modal_data" ] = processed_images
556560
@@ -562,7 +566,9 @@ def generate_sequences(self, prompts: DataProto):
562566 output .non_tensor_batch ["multi_modal_data" ] = cached_multi_modal_data
563567 if sampling_n > 1 :
564568 output .non_tensor_batch ["multi_modal_data" ] = np .repeat (
565- output .non_tensor_batch ["multi_modal_data" ], repeats = sampling_n , axis = 0 ,
569+ output .non_tensor_batch ["multi_modal_data" ],
570+ repeats = sampling_n ,
571+ axis = 0 ,
566572 )
567573
568574 output = self .rollout_sharding_manager .postprocess_data (output )
@@ -577,7 +583,7 @@ def compute_log_probs(self, data: DataProto):
577583 if "multi_modal_data" in data .non_tensor_batch :
578584 self .preprocess_multi_modal_data (data )
579585 # create cache for multi_modal_inputs
580- self ._cache [' multi_modal_inputs' ] = deepcopy (data .non_tensor_batch [' multi_modal_inputs' ])
586+ self ._cache [" multi_modal_inputs" ] = deepcopy (data .non_tensor_batch [" multi_modal_inputs" ])
581587
582588 data = data .to (torch .cuda .current_device ())
583589 if self ._use_param_offload :
@@ -611,7 +617,7 @@ def compute_ref_log_probs(self, data: DataProto):
611617 # not in the ref_policy's or critic's caches.
612618 assert self ._is_ref
613619 if "multi_modal_inputs" in self ._cache :
614- data .non_tensor_batch [' multi_modal_inputs' ] = deepcopy (self ._cache [' multi_modal_inputs' ])
620+ data .non_tensor_batch [" multi_modal_inputs" ] = deepcopy (self ._cache [" multi_modal_inputs" ])
615621 elif "multi_modal_data" in data .non_tensor_batch :
616622 self .preprocess_multi_modal_data (data )
617623
@@ -643,7 +649,7 @@ def compute_values(self, data: DataProto):
643649 # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache,
644650 # not in the ref_policy's or critic's caches.
645651 if "multi_modal_inputs" in self ._cache :
646- data .non_tensor_batch [' multi_modal_inputs' ] = deepcopy (self ._cache [' multi_modal_inputs' ])
652+ data .non_tensor_batch [" multi_modal_inputs" ] = deepcopy (self ._cache [" multi_modal_inputs" ])
647653 elif "multi_modal_data" in data .non_tensor_batch :
648654 self .preprocess_multi_modal_data (data )
649655
@@ -668,7 +674,7 @@ def update_critic(self, data: DataProto):
668674 # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache,
669675 # not in the ref_policy's or critic's caches.
670676 if "multi_modal_inputs" in self ._cache :
671- data .non_tensor_batch [' multi_modal_inputs' ] = deepcopy (self ._cache [' multi_modal_inputs' ])
677+ data .non_tensor_batch [" multi_modal_inputs" ] = deepcopy (self ._cache [" multi_modal_inputs" ])
672678 elif "multi_modal_data" not in data .non_tensor_batch :
673679 self .preprocess_multi_modal_data (data )
674680
0 commit comments