diff --git a/lazyllm/tools/data/operators/code_sandbox_op.py b/lazyllm/tools/data/operators/code_sandbox_op.py new file mode 100644 index 000000000..68dd629e6 --- /dev/null +++ b/lazyllm/tools/data/operators/code_sandbox_op.py @@ -0,0 +1,79 @@ +from typing import List, Tuple, Dict +from lazyllm import LOG +from ..base_data import data_register +from .python_executor import PythonExecutor +# ————后续修改———— + +CodeGenOps = data_register.new_group('codegen_ops') + +class CodeSandboxSampleEvaluator(CodeGenOps): + + def __init__(self, language: str = 'python', timeout_length: int = 15, use_process_isolation: bool = True, + input_code_key: str = 'generated_code', output_status_key: str = 'sandbox_status', + output_log_key: str = 'sandbox_log', **kwargs): + super().__init__(**kwargs) + self.language = language + self.timeout_length = timeout_length + self.use_process_isolation = use_process_isolation + self.input_code_key = input_code_key + self.output_status_key = output_status_key + self.output_log_key = output_log_key + + LOG.info(f'Initializing {self.__class__.__name__}...') + + if PythonExecutor is None: + raise ImportError("PythonExecutor not found. Please ensure 'python_executor.py' is in the same directory.") + + self.executor = PythonExecutor( + get_answer_from_stdout=True, + timeout_length=timeout_length, + use_process_isolation=use_process_isolation + ) + self.score_name = 'SandboxValidationScore' + LOG.info(f'{self.__class__.__name__} initialized.') + + def _execute_code_batch(self, code_list: List[str]) -> List[Tuple[str, str]]: + results_with_reports = self.executor.batch_apply(code_list, messages=[]) + + processed_results = [] + for (result, report) in results_with_reports: + if report == 'Done': + status = 'PASS' + log = result.get('text', '') if isinstance(result, dict) else result + else: + status = 'FAIL' + log = report + + processed_results.append((status, log)) + + return processed_results + + def forward_batch_input(self, inputs: List[Dict], **kwargs) -> List[Dict]: + LOG.info(f'Evaluating {self.score_name}...') + + if not inputs: + return [] + + code_list = [] + for i, item in enumerate(inputs): + if self.input_code_key not in item: + raise ValueError(f"Missing required key '{self.input_code_key}' in item at index {i}") + code_list.append(item[self.input_code_key]) + + execution_results = self._execute_code_batch(code_list) + + results = [] + for item, (status, log) in zip(inputs, execution_results): + new_item = item.copy() + # Check for conflicts + if self.output_status_key in new_item: + raise ValueError(f"The key '{self.output_status_key}' already exists and would be overwritten.") + if self.output_log_key in new_item: + raise ValueError(f"The key '{self.output_log_key}' already exists and would be overwritten.") + + new_item[self.output_status_key] = status + new_item[self.output_log_key] = log + results.append(new_item) + + LOG.info('Evaluation complete!') + return results diff --git a/lazyllm/tools/data/operators/python_executor.py b/lazyllm/tools/data/operators/python_executor.py new file mode 100644 index 000000000..bef383987 --- /dev/null +++ b/lazyllm/tools/data/operators/python_executor.py @@ -0,0 +1,590 @@ +import io +import queue +import regex +import traceback +import copy +import multiprocessing +from multiprocessing import Process +from typing import Any, Dict, Optional, Tuple, List, Union +from tqdm import tqdm +from concurrent.futures import TimeoutError +from contextlib import redirect_stdout +import base64 +from io import BytesIO +from lazyllm import LOG, thirdparty +Image = thirdparty.PIL.Image +try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + MATPLOTLIB_AVAILABLE = True +except ImportError: + MATPLOTLIB_AVAILABLE = False + plt = None + + +def encode_image(image_path): + '''Encode an image file to base64 string.''' + with open(image_path, 'rb') as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + +def base64_to_image( + base64_str: str, + remove_prefix: bool = True, + convert_mode: Optional[str] = 'RGB' +) -> Union[Image.Image, None]: + ''' + Convert a Base64-encoded image string to a PIL Image object. + + Args: + base64_str: Base64-encoded image string (can include data: prefix) + remove_prefix: Whether to automatically remove the 'data:image/...' prefix (default True) + convert_mode: Convert to the specified mode (such as 'RGB'/'RGBA', None means no conversion) + + Returns: + PIL.Image.Image object, or None if decoding fails + + Examples: + >>> img = base64_to_image('data:image/png;base64,iVBORw0KGg...') + >>> img = base64_to_image('iVBORw0KGg...', remove_prefix=False) + ''' + try: + # 1. Handle Base64 prefix + if remove_prefix and ',' in base64_str: + base64_str = base64_str.split(',')[1] + + # 2. Decode Base64 + image_data = base64.b64decode(base64_str) + + # 3. Convert to PIL Image + image = Image.open(BytesIO(image_data)) + + # 4. Optional mode conversion + if convert_mode: + image = image.convert(convert_mode) + + return image + + except (base64.binascii.Error, Exception) as e: + LOG.error(f'Base64 decode failed: {str(e)}') + return None + + +class PersistentWorker: + '''Persistent worker process.''' + + # Runtime class registry for pickle-safe serialization + RUNTIME_REGISTRY = { + 'ImageRuntime': None, # Will be set later to avoid circular import + 'DateRuntime': None, + 'ColorObjectRuntime': None, + 'GenericRuntime': None, + } + + @classmethod + def _get_runtime_class(cls, runtime_identifier): + '''Get runtime class from identifier (class name or class object).''' + if isinstance(runtime_identifier, str): + # String identifier - look up in registry + if runtime_identifier in cls.RUNTIME_REGISTRY: + return cls.RUNTIME_REGISTRY[runtime_identifier] + else: + # Default to ImageRuntime if not found + return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime) + elif isinstance(runtime_identifier, type): + # Class object - get its name and look up + class_name = runtime_identifier.__name__ + return cls.RUNTIME_REGISTRY.get(class_name, runtime_identifier) + else: + # Default fallback + return cls.RUNTIME_REGISTRY.get('ImageRuntime', ImageRuntime) + + @classmethod + def _get_runtime_identifier(cls, runtime_class): + '''Convert runtime class to pickle-safe identifier.''' + if runtime_class is None: + return 'ImageRuntime' + elif isinstance(runtime_class, str): + return runtime_class + else: + return runtime_class.__name__ + + def __init__(self): + self.input_queue = multiprocessing.Queue() + self.output_queue = multiprocessing.Queue() + self.process = None + self.start() + + def start(self): + '''Start the worker process.''' + self.process = Process(target=self._worker_loop) + self.process.daemon = True + self.process.start() + + def _worker_loop(self): + '''Main loop for the worker process.''' + runtime = None + runtime_class = None + + while True: + try: + # Get task + task = self.input_queue.get() + + if task is None: # Termination signal + break + + task_type = task.get('type') + if task_type == 'init': + runtime, runtime_class, response = self._handle_init(task) + elif task_type == 'execute': + runtime, runtime_class, response = self._handle_execute(task, runtime, runtime_class) + elif task_type == 'reset': + runtime, runtime_class, response = self._handle_reset(task) + else: + continue + + self.output_queue.put(response) + + except Exception as e: + self.output_queue.put({ + 'status': 'error', + 'error': f'Worker error: {str(e)}', + 'traceback': traceback.format_exc() + }) + + def _create_runtime(self, task): + messages = task.get('messages', []) + runtime_identifier = task.get('runtime_class', 'ImageRuntime') + runtime_class = self._get_runtime_class(runtime_identifier) + runtime = runtime_class(messages) + return runtime, runtime_class + + def _handle_init(self, task): + runtime, runtime_class = self._create_runtime(task) + return runtime, runtime_class, {'status': 'success', 'result': 'Initialized'} + + def _handle_reset(self, task): + runtime, runtime_class = self._create_runtime(task) + return runtime, runtime_class, {'status': 'success', 'result': 'Reset'} + + def _handle_execute(self, task, runtime, runtime_class): + if runtime is None: + runtime, runtime_class = self._create_runtime(task) + + code = task.get('code') + get_answer_from_stdout = task.get('get_answer_from_stdout', True) + answer_symbol = task.get('answer_symbol') + answer_expr = task.get('answer_expr') + + try: + # Record the number of images before execution + pre_figures_count = len(runtime._global_vars.get('_captured_figures', [])) + + if get_answer_from_stdout: + program_io = io.StringIO() + with redirect_stdout(program_io): + runtime.exec_code('\n'.join(code)) + program_io.seek(0) + result = program_io.read() + elif answer_symbol: + runtime.exec_code('\n'.join(code)) + result = runtime._global_vars.get(answer_symbol, '') + elif answer_expr: + runtime.exec_code('\n'.join(code)) + result = runtime.eval_code(answer_expr) + else: + if len(code) > 1: + runtime.exec_code('\n'.join(code[:-1])) + result = runtime.eval_code(code[-1]) + else: + runtime.exec_code('\n'.join(code)) + result = '' + + # Get newly generated images + all_figures = runtime._global_vars.get('_captured_figures', []) + new_figures = all_figures[pre_figures_count:] + + # Build result + if new_figures: + result = {'text': result, 'images': new_figures} if result else {'images': new_figures} + else: + result = {'text': result} if result else {} + + response = {'status': 'success', 'result': result, 'report': 'Done'} + except Exception as e: + error_msg = f'{type(e).__name__}: {str(e)}' + response = { + 'status': 'error', + 'error': error_msg, + 'traceback': traceback.format_exc(), + 'report': f'Error: {error_msg}' + } + return runtime, runtime_class, response + + def execute(self, code: List[str], messages: list = None, runtime_class=None, + get_answer_from_stdout=True, answer_symbol=None, answer_expr=None, timeout: int = 30): + '''Execute code.''' + # Convert runtime class to pickle-safe identifier + runtime_identifier = self._get_runtime_identifier(runtime_class) + self.input_queue.put({ + 'type': 'execute', + 'code': code, + 'messages': messages, + 'runtime_class': runtime_identifier, + 'get_answer_from_stdout': get_answer_from_stdout, + 'answer_symbol': answer_symbol, + 'answer_expr': answer_expr + }) + + try: + result = self.output_queue.get(timeout=timeout) + return result + except queue.Empty: + return { + 'status': 'error', + 'error': 'Execution timeout', + 'report': 'Timeout Error' + } + + def init_runtime(self, messages: list, runtime_class=None): + '''Initialize runtime.''' + # Convert runtime class to pickle-safe identifier + runtime_identifier = self._get_runtime_identifier(runtime_class) + self.input_queue.put({ + 'type': 'init', + 'messages': messages, + 'runtime_class': runtime_identifier + }) + return self.output_queue.get() + + def reset_runtime(self, messages: list = None, runtime_class=None): + '''Reset runtime.''' + # Convert runtime class to pickle-safe identifier + runtime_identifier = self._get_runtime_identifier(runtime_class) + self.input_queue.put({ + 'type': 'reset', + 'messages': messages, + 'runtime_class': runtime_identifier + }) + return self.output_queue.get() + + def terminate(self): + '''Terminate the worker process.''' + if self.process and self.process.is_alive(): + self.input_queue.put(None) + self.process.join(timeout=5) + if self.process.is_alive(): + self.process.terminate() + + +class GenericRuntime: + GLOBAL_DICT = {} + LOCAL_DICT = None + HEADERS = [] + + def __init__(self): + self._global_vars = copy.copy(self.GLOBAL_DICT) + self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None + self._captured_figures = [] + + for c in self.HEADERS: + self.exec_code(c) + + def exec_code(self, code_piece: str) -> None: + # Security check + if regex.search(r'(\s|^)?(input|os\.system|subprocess)\(', code_piece): + raise RuntimeError('Forbidden function calls detected') + + # Detect and modify plt.show() calls + if 'plt.show()' in code_piece and MATPLOTLIB_AVAILABLE: + modified_code = code_piece.replace('plt.show()', ''' +# Capture current figure +buf = io.BytesIO() +plt.savefig(buf, format='png') +buf.seek(0) +_captured_image = base64.b64encode(buf.read()).decode('utf-8') +_captured_figures.append(_captured_image) +plt.close() +''') + # Ensure _captured_figures variable exists + if '_captured_figures' not in self._global_vars: + self._global_vars['_captured_figures'] = [] + + exec(modified_code, self._global_vars) + else: + exec(code_piece, self._global_vars) + + def eval_code(self, expr: str) -> Any: + return eval(expr, self._global_vars) + + def inject(self, var_dict: Dict[str, Any]) -> None: + for k, v in var_dict.items(): + self._global_vars[k] = v + + @property + def answer(self): + return self._global_vars.get('answer', None) + + @property + def captured_figures(self): + return self._global_vars.get('_captured_figures', []) + + +class ImageRuntime(GenericRuntime): + HEADERS = [ + '''try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + MATPLOTLIB_AVAILABLE = True +except ImportError: + MATPLOTLIB_AVAILABLE = False + plt = None +''', + 'from lazyllm import thirdparty; Image = thirdparty.PIL.Image', + 'import io', + 'import base64', + 'from lazyllm import thirdparty; np = thirdparty.numpy', + '_captured_figures = []', # Initialize image capture list + ] + + def __init__(self, messages): + super().__init__() + + image_var_dict = {} + image_var_idx = 0 + init_captured_figures = [] + + for message_item in messages: + content = message_item['content'] + for item in content: + if isinstance(item, dict): + item_type = item.get('type') + if item_type == 'image_url': + item_image_url = item['image_url']['url'] + image = base64_to_image(item_image_url) + if image: + image_var_dict[f'image_clue_{image_var_idx}'] = image + init_captured_figures.append(base64.b64encode( + BytesIO(image.tobytes()).getvalue()).decode('utf-8')) + image_var_idx += 1 + + image_var_dict['_captured_figures'] = init_captured_figures + self.inject(image_var_dict) + + +class DateRuntime(GenericRuntime): + GLOBAL_DICT = {} + HEADERS = [ + 'import datetime', + 'from dateutil.relativedelta import relativedelta', + 'timedelta = relativedelta' + ] + + +class CustomDict(dict): + def __iter__(self): + return list(super().__iter__()).__iter__() + + +class ColorObjectRuntime(GenericRuntime): + GLOBAL_DICT = {'dict': CustomDict} + + +class PythonExecutor: + def __init__( + self, + runtime_class=None, + get_answer_symbol: Optional[str] = None, + get_answer_expr: Optional[str] = None, + get_answer_from_stdout: bool = True, + timeout_length: int = 20, + use_process_isolation: bool = True, + ) -> None: + self.runtime_class = runtime_class if runtime_class else ImageRuntime + self.answer_symbol = get_answer_symbol + self.answer_expr = get_answer_expr + self.get_answer_from_stdout = get_answer_from_stdout + self.timeout_length = timeout_length + self.use_process_isolation = use_process_isolation + self.persistent_worker = None + + def _ensure_worker(self): + '''Ensure the worker process exists.''' + if self.persistent_worker is None: + self.persistent_worker = PersistentWorker() + + def process_generation_to_code(self, gens: str): + return [g.split('\n') for g in gens] + + def execute( + self, + code, + messages, + get_answer_from_stdout=True, + runtime_class=None, + answer_symbol=None, + answer_expr=None, + ) -> Tuple[Union[str, Dict[str, Any]], str]: + + if self.use_process_isolation: + # Ensure worker process exists + self._ensure_worker() + + # Execute code + result = self.persistent_worker.execute( + code, + messages, + runtime_class or self.runtime_class, + get_answer_from_stdout, + answer_symbol, + answer_expr, + timeout=self.timeout_length + ) + + if result['status'] == 'success': + return result['result'], result.get('report', 'Done') + else: + error_msg = result.get('error', 'Unknown error') + error_result = { + 'error': error_msg, + 'traceback': result.get('traceback', '') + } + return error_result, result.get('report', f'Error: {error_msg}') + else: + # Non-isolation mode (for backward compatibility) + runtime = runtime_class(messages) if runtime_class else self.runtime_class(messages) + + try: + if get_answer_from_stdout: + program_io = io.StringIO() + with redirect_stdout(program_io): + runtime.exec_code('\n'.join(code)) + program_io.seek(0) + result = program_io.read() + elif answer_symbol: + runtime.exec_code('\n'.join(code)) + result = runtime._global_vars.get(answer_symbol, '') + elif answer_expr: + runtime.exec_code('\n'.join(code)) + result = runtime.eval_code(answer_expr) + else: + if len(code) > 1: + runtime.exec_code('\n'.join(code[:-1])) + result = runtime.eval_code(code[-1]) + else: + runtime.exec_code('\n'.join(code)) + result = '' + + # Check for captured figures + captured_figures = runtime.captured_figures + if captured_figures: + result = { + 'text': result, + 'images': captured_figures + } if result else {'images': captured_figures} + else: + result = {'text': result} if result else {} + + report = 'Done' + + except Exception as e: + error_msg = f'{type(e).__name__}: {str(e)}' + result = { + 'error': error_msg, + 'traceback': traceback.format_exc() + } + report = f'Error: {error_msg}' + + return result, report + + def apply(self, code, messages): + return self.batch_apply([code], messages)[0] + + @staticmethod + def truncate(s, max_length=400): + if isinstance(s, dict): + # If it is a dict (with images), truncate only the text part + if 'text' in s: + half = max_length // 2 + if len(s['text']) > max_length: + s['text'] = s['text'][:half] + '...' + s['text'][-half:] + return s + else: + half = max_length // 2 + if isinstance(s, str) and len(s) > max_length: + s = s[:half] + '...' + s[-half:] + return s + + def batch_apply(self, batch_code, messages): + all_code_snippets = self.process_generation_to_code(batch_code) + + timeout_cnt = 0 + all_exec_results = [] + + if len(all_code_snippets) > 100: + progress_bar = tqdm(total=len(all_code_snippets), desc='Execute') + else: + progress_bar = None + + for code in all_code_snippets: + try: + result = self.execute( + code, + messages=messages, + get_answer_from_stdout=self.get_answer_from_stdout, + runtime_class=self.runtime_class, + answer_symbol=self.answer_symbol, + answer_expr=self.answer_expr, + ) + all_exec_results.append(result) + except TimeoutError as error: + LOG.warning(error) + all_exec_results.append(('', 'Timeout Error')) + timeout_cnt += 1 + except Exception as error: + LOG.error(f'Error in batch_apply: {error}') + all_exec_results.append(('', f'Error: {str(error)}')) + + if progress_bar is not None: + progress_bar.update(1) + + if progress_bar is not None: + progress_bar.close() + + batch_results = [] + for _code, (res, report) in zip(all_code_snippets, all_exec_results): + # Handle results + if isinstance(res, dict): + # If result contains images, special handling + if 'text' in res: + res['text'] = str(res['text']).strip() + res['text'] = self.truncate(res['text']) + report = str(report).strip() + report = self.truncate(report) + else: + # Normal text result + res = str(res).strip() + res = self.truncate(res) + report = str(report).strip() + report = self.truncate(report) + batch_results.append((res, report)) + return batch_results + + def reset(self, messages=None): + '''Reset executor state.''' + if self.use_process_isolation and self.persistent_worker: + self.persistent_worker.reset_runtime(messages, self.runtime_class) + + def __del__(self): + '''Clean up resources.''' + if self.persistent_worker: + self.persistent_worker.terminate() + + +# Initialize runtime registry after all classes are defined +PersistentWorker.RUNTIME_REGISTRY['ImageRuntime'] = ImageRuntime +PersistentWorker.RUNTIME_REGISTRY['DateRuntime'] = DateRuntime +PersistentWorker.RUNTIME_REGISTRY['ColorObjectRuntime'] = ColorObjectRuntime +PersistentWorker.RUNTIME_REGISTRY['GenericRuntime'] = GenericRuntime