diff --git a/cookbooks/distillation.py b/cookbooks/distillation.py
index a392378..15eeb47 100644
--- a/cookbooks/distillation.py
+++ b/cookbooks/distillation.py
@@ -1,6 +1,8 @@
import asyncio
import json
import os
+import time
+from multiprocessing import Pool, freeze_support
from camel.datasets.static_dataset import StaticDataset
from camel.datasets.few_shot_generator import FewShotGenerator
@@ -13,109 +15,140 @@
from camel.environments import SingleStepEnv, Action
from camel.logger import get_logger, set_log_level
-# Set up logger
+# Logger setup
logger = get_logger(__name__)
set_log_level('INFO')
-if not os.environ["OPENAI_API_KEY"]:
- raise RuntimeError("No OpenAI API key found")
-
-DEEPSEEK_API_KEY = "ENTER API KEY HERE"
-
-if DEEPSEEK_API_KEY == "ENTER API KEY HERE":
- raise RuntimeError("Please enter your API key.")
-
-# Enable DeepSeek reasoning content
-os.environ["GET_REASONING_CONTENT"] = "true"
-
-OUTPUT_FILE = "math_dataset.json"
-ALL_RESPONSES_FILE = "all_responses.txt"
-
-# Load existing dataset if it exists
-if os.path.exists(OUTPUT_FILE):
- with open(OUTPUT_FILE, 'r') as f:
- dataset = json.load(f)
- logger.info(f"Loaded existing dataset with {len(dataset)} examples")
-else:
- dataset = []
- logger.info("Starting new dataset")
-
-logger.info("Loading advanced math dataset...")
-# Load the advanced math dataset and filter for Level 4 and 5
-with open('data/advanced_math/seed_dataset.json', 'r') as f:
- seed_data = json.load(f)
-
-# Filter for Level 4 and 5 questions
-filtered_seed_data = [
- example for example in seed_data
- if example.get('metadata', {}).get('level') in ['Level 4', 'Level 5']
-]
-
-
-
-logger.info(f"Filtered seed dataset from {len(seed_data)} to {len(filtered_seed_data)} examples (Level 4 and 5 only)")
-logger.info(f"Level 4: {sum(1 for x in filtered_seed_data if x['metadata']['level'] == 'Level 4')}")
-logger.info(f"Level 5: {sum(1 for x in filtered_seed_data if x['metadata']['level'] == 'Level 5')}")
-
-seed_dataset = StaticDataset(filtered_seed_data)
-
-logger.info(f"Loaded seed dataset with {len(seed_data)} examples")
-
-logger.info("Initializing models...")
-# Initialize models
-model_4o = ModelFactory.create(
- model_platform=ModelPlatformType.OPENAI,
- model_type=ModelType.GPT_4O_MINI,
- model_config_dict=ChatGPTConfig().as_dict(),
- timeout=1000
-)
-
-model_deepseek = ModelFactory.create(
- model_platform=ModelPlatformType.DEEPSEEK,
- model_type=ModelType.DEEPSEEK_REASONER,
- api_key=DEEPSEEK_API_KEY,
- timeout = 1000
-)
-logger.info("Models initialized successfully")
-
-logger.info("Setting up extractors and verifiers...")
-# Initialize extractors and verifiers
-extractor = BaseExtractor([[BoxedStrategy()]])
-asyncio.run(extractor.setup())
-
-# Python verifier for FewShotGenerator
-python_verifier = PythonVerifier(required_packages=["sympy"])
-asyncio.run(python_verifier.setup(uv=False))
-
-# Math verifier for final answer comparison
-math_verifier = MathVerifier(
- extractor=extractor,
- float_rounding=6,
- numeric_precision=15,
- enable_wrapping=True
-)
-asyncio.run(math_verifier.setup())
-logger.info("Extractors and verifiers setup complete")
-
-logger.info("Initializing generator and environment...")
-# Initialize generator with seed dataset using PythonVerifier
-generator = FewShotGenerator(
- buffer=10,
- seed_dataset=seed_dataset,
- verifier=python_verifier, # Use Python verifier here
- model=model_4o
-)
+# Function to process a single record
+def process_record(process_id, user_prompt, api_key, seed_file_path, output_file):
+ # Create a new environment with all necessary components for this process
+ try:
+ # Set up DeepSeek environment variable
+ os.environ["GET_REASONING_CONTENT"] = "true"
+
+ # Initialize model for this process
+ model_process = ModelFactory.create(
+ model_platform=ModelPlatformType.DEEPSEEK,
+ model_type=ModelType.DEEPSEEK_REASONER,
+ api_key=api_key,
+ timeout=1000
+ )
+
+ # Initialize agent
+ local_agent = ChatAgent(model=model_process)
+
+ # Set up environment components
+ extractor = BaseExtractor([[BoxedStrategy()]])
+ asyncio.run(extractor.setup())
+
+ # Python verifier
+ python_verifier = PythonVerifier(required_packages=["sympy"])
+ asyncio.run(python_verifier.setup(uv=True))
+
+ # Math verifier
+ math_verifier = MathVerifier(
+ extractor=extractor,
+ float_rounding=6,
+ numeric_precision=15,
+ enable_wrapping=True
+ )
+ asyncio.run(math_verifier.setup())
+
+ # Load seed dataset
+ with open(seed_file_path, 'r') as f:
+ seed_data = json.load(f)
+
+ seed_dataset = StaticDataset(seed_data)
+
+ # Initialize generator
+ model_4o = ModelFactory.create(
+ model_platform=ModelPlatformType.OPENAI,
+ model_type=ModelType.GPT_4O_MINI,
+ model_config_dict=ChatGPTConfig().as_dict(),
+ timeout=1000
+ )
+
+ generator = FewShotGenerator(
+ buffer=10,
+ seed_dataset=seed_dataset,
+ verifier=python_verifier,
+ model=model_4o
+ )
+
+ # Create environment
+ env = SingleStepEnv(generator, math_verifier)
+ asyncio.run(env.setup())
+
+ # Reset environment and get question
+ obs = asyncio.run(env.reset())
+ question = obs.question
+
+ # This is the bottleneck operation we're parallelizing
+ deepseek_response = local_agent.step(user_prompt + question).msgs[0].content
+
+ # Split the response into reasoning and answer parts
+ reasoning_part = ""
+ answer_part = deepseek_response
+
+        if "<think>" in deepseek_response and "</think>" in deepseek_response:
+            parts = deepseek_response.split("</think>")
+            if len(parts) > 1:
+                reasoning_part = parts[0].replace("<think>", "").strip()
+ answer_part = parts[1].strip()
+
+ # Verify the result
+ next_obs, reward, done, info = asyncio.run(env.step(Action(index=0, llm_response=deepseek_response)))
+
+ # Create data entry
+ data_entry = {
+ "question": question,
+ "answer": info['state'].final_answer if 'state' in info else '',
+ "response": answer_part,
+ "long_cot": reasoning_part,
+ "shots": obs.metadata.get('shots'),
+ "verified": reward > 0
+ }
+
+ # Return the result - we'll handle shared state synchronization in the main process
+ return data_entry, reward > 0, process_id
+
+ except Exception as e:
+ logger.error(f"Error processing record {process_id}: {str(e)}")
+ return None, False, process_id
-# Create environment with MathVerifier for final comparison
-env = SingleStepEnv(generator, math_verifier) # Use Math verifier here
-asyncio.run(env.setup())
-logger.info("Generator and environment initialized")
+def main():
+ start_time = time.time()
+
+ # Check API keys
+ if not os.environ.get("OPENAI_API_KEY"):
+ raise RuntimeError("No OpenAI API key found")
+
+ if not os.environ.get("DEEPSEEK_API_KEY"):
+ raise RuntimeError("No DeepSeek API key found")
+
+ # Setup file paths
+ OUTPUT_FILE = "math_dataset.json"
+ SEED_FILE_PATH = '/Users/enrei/Desktop/camel0209/camel/camel/verifiers/seed_dataset_first_20.json'
+ BATCH_SIZE = 50
+
+ # Load existing dataset if it exists
+ if os.path.exists(OUTPUT_FILE):
+ with open(OUTPUT_FILE, 'r') as f:
+ dataset = json.load(f)
+ logger.info(f"Loaded existing dataset with {len(dataset)} examples")
+ else:
+ dataset = []
+ logger.info("Starting new dataset")
-# Initialize agent for CoT generation
-agent = ChatAgent(model=model_deepseek)
+ logger.info("Loading advanced math dataset...")
+ # Load the seed dataset
+ with open(SEED_FILE_PATH, 'r') as f:
+ seed_data = json.load(f)
-# Define the prompt for CoT generation
-USER_PROMPT = """You are an agent designed to answer mathematical questions with clarity and precision. Your task is to provide a step-by-step explanation for
+ logger.info(f"Loaded seed dataset with {len(seed_data)} examples")
+
+ # Define the prompt for CoT generation
+ USER_PROMPT = """You are an agent designed to answer mathematical questions with clarity and precision. Your task is to provide a step-by-step explanation for
any mathematical problem posed by the user, ensuring the response is easy to follow. Adhere to these guidelines:
Analyze the mathematical question carefully and break down the solution process into clear, logical steps.
Use natural language to explain each step, incorporating LaTeX notation (e.g., $x + 2$)
@@ -126,57 +159,74 @@
The question you should answer is: """
-num_rejected = 0
-target_size = 1000
-
-logger.info("Starting generation and verification loop...")
-
-while sum(1 for entry in dataset if entry["verified"]) < target_size:
- logger.info(f"Current verified count: {sum(1 for entry in dataset if entry['verified'])}/{target_size}")
-
- obs = asyncio.run(env.reset())
- deepseek_response = agent.step(USER_PROMPT + obs.question).msgs[0].content
-
- # Split the response into reasoning and answer parts
- reasoning_part = ""
- answer_part = deepseek_response
-
-    if "<think>" in deepseek_response and "</think>" in deepseek_response:
-        parts = deepseek_response.split("</think>")
-        if len(parts) > 1:
-            reasoning_part = parts[0].replace("<think>", "").strip()
- answer_part = parts[1].strip()
-
- next_obs, reward, done, info = asyncio.run(env.step(Action(index=0, llm_response=deepseek_response)))
-
- # Create and save data entry (both verified and unverified)
- data_entry = {
- "question": obs.question,
- "answer": info['state'].final_answer if 'state' in info else '',
- "response": answer_part,
- "long_cot": reasoning_part,
- "shots": obs.metadata.get('shots'),
- "verified": reward > 0
- }
-
- # Add entry to dataset (both verified and unverified)
- dataset.append(data_entry)
- # Save immediately
- with open(OUTPUT_FILE, 'w') as f:
- json.dump(dataset, f, indent=2)
-
+ num_rejected = 0
+ target_size = 1000
verified_count = sum(1 for entry in dataset if entry["verified"])
- if reward > 0:
- logger.info(f"Verification successful - Added verified entry ({verified_count}/{target_size} verified)")
- else:
- num_rejected += 1
- logger.warning(f"Verification failed - Added unverified entry ({verified_count}/{target_size} verified)")
- agent.reset()
-
-# At the end, log statistics
-total_entries = len(dataset)
-verified_entries = sum(1 for entry in dataset if entry["verified"])
-logger.info(f"Generation complete. Total entries: {total_entries}")
-logger.info(f"Verified entries: {verified_entries}")
-logger.info(f"Rejected entries: {num_rejected}")
+ logger.info("Starting generation and verification loop...")
+
+ # Main processing loop with proper process pool management
+ while verified_count < target_size:
+ logger.info(f"Current verified count: {verified_count}/{target_size}")
+
+ # Determine batch size
+ remaining = target_size - verified_count
+ batch_size = min(BATCH_SIZE, remaining)
+
+ # Create arguments for each process
+ process_args = [
+ (i, USER_PROMPT, os.environ.get("DEEPSEEK_API_KEY"), SEED_FILE_PATH, OUTPUT_FILE)
+ for i in range(batch_size)
+ ]
+
+ logger.info(f"Processing batch with {batch_size} processes...")
+ with Pool(processes=batch_size) as pool:
+ # Use starmap to pass multiple arguments to process_record
+ results = pool.starmap(process_record, process_args)
+
+ # Close and join pool properly
+ pool.close()
+ pool.join()
+
+ # Process results from completed processes
+ newly_verified = 0
+ for data_entry, is_verified, proc_id in results:
+ if data_entry is not None:
+ # Add to dataset
+ dataset.append(data_entry)
+
+ # Update counters
+ if is_verified:
+ newly_verified += 1
+ verified_count += 1
+ logger.info(f"Process {proc_id}: Verification successful - Added verified entry ({verified_count}/{target_size} verified)")
+ else:
+ num_rejected += 1
+ logger.warning(f"Process {proc_id}: Verification failed - Added unverified entry ({verified_count}/{target_size} verified)")
+
+ # Save after each batch
+ with open(OUTPUT_FILE, 'w') as f:
+ json.dump(dataset, f, indent=2)
+
+ logger.info(f"Batch complete - Added {newly_verified} verified entries in this batch")
+
+ # Final statistics
+ total_entries = len(dataset)
+ verified_entries = sum(1 for entry in dataset if entry["verified"])
+ logger.info(f"Generation complete. Total entries: {total_entries}")
+ logger.info(f"Verified entries: {verified_entries}")
+ logger.info(f"Rejected entries: {num_rejected}")
+
+ end_time = time.time()
+
+ # Calculate total elapsed time
+ elapsed_time = end_time - start_time
+ logger.info(f"Total elapsed time: {elapsed_time:.2f}s")
+ logger.info(f"Average time per record: {elapsed_time/total_entries:.2f}s")
+    logger.info(f"Processed using up to {BATCH_SIZE} worker processes.")
+
+# Required for multiprocessing on macOS (and Windows)
+if __name__ == '__main__':
+ # Add freeze_support to properly handle multiprocessing in frozen executables
+ freeze_support()
+ main()