Skip to content

feat: improve data distill performance by using batch process #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
348 changes: 199 additions & 149 deletions cookbooks/distillation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import json
import os
import time
from multiprocessing import Pool, freeze_support

from camel.datasets.static_dataset import StaticDataset
from camel.datasets.few_shot_generator import FewShotGenerator
Expand All @@ -13,109 +15,140 @@
from camel.environments import SingleStepEnv, Action
from camel.logger import get_logger, set_log_level

# Logger setup
logger = get_logger(__name__)
set_log_level('INFO')

# Use .get() so a missing variable raises our RuntimeError below instead of
# a bare KeyError from os.environ["..."] subscripting.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("No OpenAI API key found")

DEEPSEEK_API_KEY = "ENTER API KEY HERE"

# Fail fast while the placeholder key is still in place.
if DEEPSEEK_API_KEY == "ENTER API KEY HERE":
    raise RuntimeError("Please enter your API key.")

# Enable DeepSeek reasoning content
os.environ["GET_REASONING_CONTENT"] = "true"

# Output locations for the distilled dataset and raw model responses.
OUTPUT_FILE = "math_dataset.json"
# NOTE(review): ALL_RESPONSES_FILE is never referenced in the visible code —
# confirm it is used elsewhere or remove it.
ALL_RESPONSES_FILE = "all_responses.txt"

# Load existing dataset if it exists, so interrupted runs resume instead of
# starting over.
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, 'r') as f:
        dataset = json.load(f)
    logger.info(f"Loaded existing dataset with {len(dataset)} examples")
else:
    dataset = []
    logger.info("Starting new dataset")

logger.info("Loading advanced math dataset...")
# Load the advanced math dataset and filter for Level 4 and 5
with open('data/advanced_math/seed_dataset.json', 'r') as f:
    seed_data = json.load(f)

# Keep only the hardest questions (Level 4 and Level 5); entries without a
# 'metadata.level' field are dropped as well.
filtered_seed_data = [
    example for example in seed_data
    if example.get('metadata', {}).get('level') in ['Level 4', 'Level 5']
]

logger.info(f"Filtered seed dataset from {len(seed_data)} to {len(filtered_seed_data)} examples (Level 4 and 5 only)")
logger.info(f"Level 4: {sum(1 for x in filtered_seed_data if x['metadata']['level'] == 'Level 4')}")
logger.info(f"Level 5: {sum(1 for x in filtered_seed_data if x['metadata']['level'] == 'Level 5')}")

# The generator below draws few-shot examples from this static dataset.
seed_dataset = StaticDataset(filtered_seed_data)

# NOTE(review): this logs the PRE-filter count (len(seed_data)), not the
# filtered size logged above — likely leftover; confirm intent.
logger.info(f"Loaded seed dataset with {len(seed_data)} examples")

logger.info("Initializing models...")
# Initialize models: GPT-4o-mini drives question generation, DeepSeek
# Reasoner produces the chain-of-thought answers being distilled.
model_4o = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI,
    model_type=ModelType.GPT_4O_MINI,
    model_config_dict=ChatGPTConfig().as_dict(),
    timeout=1000
)

model_deepseek = ModelFactory.create(
    model_platform=ModelPlatformType.DEEPSEEK,
    model_type=ModelType.DEEPSEEK_REASONER,
    api_key=DEEPSEEK_API_KEY,
    timeout = 1000
)
logger.info("Models initialized successfully")

logger.info("Setting up extractors and verifiers...")
# Initialize extractors and verifiers. BoxedStrategy pulls the final answer
# out of the model response for comparison.
extractor = BaseExtractor([[BoxedStrategy()]])
asyncio.run(extractor.setup())

# Python verifier for FewShotGenerator (executes candidate solutions;
# requires sympy in its sandbox).
python_verifier = PythonVerifier(required_packages=["sympy"])
asyncio.run(python_verifier.setup(uv=False))

# Math verifier for final answer comparison (numeric tolerance settings
# control how loosely floats are matched).
math_verifier = MathVerifier(
    extractor=extractor,
    float_rounding=6,
    numeric_precision=15,
    enable_wrapping=True
)
asyncio.run(math_verifier.setup())
logger.info("Extractors and verifiers setup complete")

logger.info("Initializing generator and environment...")
# Initialize generator with seed dataset using PythonVerifier
generator = FewShotGenerator(
    buffer=10,
    seed_dataset=seed_dataset,
    verifier=python_verifier,  # Use Python verifier here
    model=model_4o
)
# Worker entry point: each process builds a fully independent pipeline so the
# slow DeepSeek call can run in parallel across a multiprocessing Pool.
def process_record(process_id, user_prompt, api_key, seed_file_path, output_file):
    """Generate one question, answer it with DeepSeek, and verify the answer.

    Args:
        process_id: Index of this worker within the batch (used in logs).
        user_prompt: Prompt prefix prepended to the generated question.
        api_key: DeepSeek API key, passed explicitly because spawned worker
            processes do not share the parent's module-level state.
        seed_file_path: Path to the JSON seed dataset for few-shot generation.
        output_file: Currently unused; kept so the caller's argument tuples
            stay valid (the parent process persists results itself).

    Returns:
        Tuple ``(data_entry, verified, process_id)``. ``data_entry`` is a dict
        describing the question/answer (or ``None`` when processing failed)
        and ``verified`` is True when the verifier awarded a positive reward.
    """
    try:
        # Ask DeepSeek to emit its chain-of-thought alongside the answer.
        os.environ["GET_REASONING_CONTENT"] = "true"

        # Model that answers the question — the expensive, parallelized call.
        model_process = ModelFactory.create(
            model_platform=ModelPlatformType.DEEPSEEK,
            model_type=ModelType.DEEPSEEK_REASONER,
            api_key=api_key,
            timeout=1000
        )

        local_agent = ChatAgent(model=model_process)

        # Extraction/verification stack, rebuilt per process.
        extractor = BaseExtractor([[BoxedStrategy()]])
        asyncio.run(extractor.setup())

        # Python verifier backs the few-shot generator.
        # NOTE(review): uses uv=True while the module-level setup uses
        # uv=False — confirm which sandbox mode is intended.
        python_verifier = PythonVerifier(required_packages=["sympy"])
        asyncio.run(python_verifier.setup(uv=True))

        # Math verifier scores the final boxed answer.
        math_verifier = MathVerifier(
            extractor=extractor,
            float_rounding=6,
            numeric_precision=15,
            enable_wrapping=True
        )
        asyncio.run(math_verifier.setup())

        # Seed dataset supplying few-shot examples for question generation.
        with open(seed_file_path, 'r') as f:
            seed_data = json.load(f)

        seed_dataset = StaticDataset(seed_data)

        # Cheap model that drives question generation.
        model_4o = ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=ModelType.GPT_4O_MINI,
            model_config_dict=ChatGPTConfig().as_dict(),
            timeout=1000
        )

        generator = FewShotGenerator(
            buffer=10,
            seed_dataset=seed_dataset,
            verifier=python_verifier,
            model=model_4o
        )

        # Environment: questions from `generator`, scoring by `math_verifier`.
        env = SingleStepEnv(generator, math_verifier)
        asyncio.run(env.setup())

        # Draw a fresh question from the environment.
        obs = asyncio.run(env.reset())
        question = obs.question

        # This is the bottleneck operation we're parallelizing.
        deepseek_response = local_agent.step(user_prompt + question).msgs[0].content

        # Split the response into <think>...</think> reasoning and the answer.
        reasoning_part = ""
        answer_part = deepseek_response

        if "<think>" in deepseek_response and "</think>" in deepseek_response:
            parts = deepseek_response.split("</think>")
            if len(parts) > 1:
                reasoning_part = parts[0].replace("<think>", "").strip()
                answer_part = parts[1].strip()

        # Score the full response; reward > 0 means the answer verified.
        next_obs, reward, done, info = asyncio.run(env.step(Action(index=0, llm_response=deepseek_response)))

        data_entry = {
            "question": question,
            "answer": info['state'].final_answer if 'state' in info else '',
            "response": answer_part,
            "long_cot": reasoning_part,
            "shots": obs.metadata.get('shots'),
            "verified": reward > 0
        }

        # Shared-state synchronization happens in the parent process.
        return data_entry, reward > 0, process_id

    except Exception as e:
        # logger.exception preserves the traceback; the previous
        # logger.error(str(e)) discarded it, which made worker failures
        # hard to debug.
        logger.exception(f"Error processing record {process_id}: {str(e)}")
        return None, False, process_id

# Create environment with MathVerifier for final comparison: `generator`
# supplies the questions, `math_verifier` scores the answers.
env = SingleStepEnv(generator, math_verifier)  # Use Math verifier here
asyncio.run(env.setup())
logger.info("Generator and environment initialized")
def main():
    """Run the batched distillation pipeline.

    Each batch fans out up to BATCH_SIZE worker processes (each worker builds
    its own models/verifiers and handles one record via ``process_record``),
    folds the results into the dataset, and persists it after every batch
    until the verified-entry target is reached.
    """
    start_time = time.time()

    # Fail fast if the required API keys are missing from the environment.
    if not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError("No OpenAI API key found")

    if not os.environ.get("DEEPSEEK_API_KEY"):
        raise RuntimeError("No DeepSeek API key found")

    # Setup file paths. The seed path can be overridden via the environment
    # so the script is not pinned to one developer's machine; the default
    # preserves the previous behavior.
    OUTPUT_FILE = "math_dataset.json"
    SEED_FILE_PATH = os.environ.get(
        "SEED_FILE_PATH",
        '/Users/enrei/Desktop/camel0209/camel/camel/verifiers/seed_dataset_first_20.json',
    )
    BATCH_SIZE = 50

    # Load existing dataset if it exists, so interrupted runs resume.
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r') as f:
            dataset = json.load(f)
        logger.info(f"Loaded existing dataset with {len(dataset)} examples")
    else:
        dataset = []
        logger.info("Starting new dataset")

    logger.info("Loading advanced math dataset...")
    # Load the seed dataset (only used here for logging; each worker
    # re-loads it from SEED_FILE_PATH).
    with open(SEED_FILE_PATH, 'r') as f:
        seed_data = json.load(f)

    logger.info(f"Loaded seed dataset with {len(seed_data)} examples")

    # Define the prompt for CoT generation.
    # NOTE(review): the middle of this prompt was collapsed in the diff view
    # ("Expand All") — restore any missing guideline lines from the original
    # file before shipping.
    USER_PROMPT = """You are an agent designed to answer mathematical questions with clarity and precision. Your task is to provide a step-by-step explanation for
any mathematical problem posed by the user, ensuring the response is easy to follow. Adhere to these guidelines:
Analyze the mathematical question carefully and break down the solution process into clear, logical steps.
Use natural language to explain each step, incorporating LaTeX notation (e.g., $x + 2$)

The question you should answer is: """

    num_rejected = 0
    target_size = 1000
    # Count already-verified entries so resumed runs don't over-generate.
    verified_count = sum(1 for entry in dataset if entry["verified"])

    logger.info("Starting generation and verification loop...")

    # Main processing loop: dispatch batches until the target is met.
    while verified_count < target_size:
        logger.info(f"Current verified count: {verified_count}/{target_size}")

        # Never spawn more workers than entries still needed.
        remaining = target_size - verified_count
        batch_size = min(BATCH_SIZE, remaining)

        # One argument tuple per worker process.
        process_args = [
            (i, USER_PROMPT, os.environ.get("DEEPSEEK_API_KEY"),
             SEED_FILE_PATH, OUTPUT_FILE)
            for i in range(batch_size)
        ]

        logger.info(f"Processing batch with {batch_size} processes...")
        # The with-block terminates the pool on exit, so the explicit
        # close()/join() calls the original made inside it were redundant.
        with Pool(processes=batch_size) as pool:
            results = pool.starmap(process_record, process_args)

        # Fold worker results into the dataset and update counters.
        newly_verified = 0
        for data_entry, is_verified, proc_id in results:
            if data_entry is None:
                # Worker failed; the error was already logged in
                # process_record.
                continue

            dataset.append(data_entry)

            if is_verified:
                newly_verified += 1
                verified_count += 1
                logger.info(f"Process {proc_id}: Verification successful - Added verified entry ({verified_count}/{target_size} verified)")
            else:
                num_rejected += 1
                logger.warning(f"Process {proc_id}: Verification failed - Added unverified entry ({verified_count}/{target_size} verified)")

        # Persist after each batch so progress survives interruption.
        with open(OUTPUT_FILE, 'w') as f:
            json.dump(dataset, f, indent=2)

        logger.info(f"Batch complete - Added {newly_verified} verified entries in this batch")

    # Final statistics.
    total_entries = len(dataset)
    verified_entries = sum(1 for entry in dataset if entry["verified"])
    logger.info(f"Generation complete. Total entries: {total_entries}")
    logger.info(f"Verified entries: {verified_entries}")
    logger.info(f"Rejected entries: {num_rejected}")

    elapsed_time = time.time() - start_time
    logger.info(f"Total elapsed time: {elapsed_time:.2f}s")
    # Guard against ZeroDivisionError when the run produced no entries.
    if total_entries:
        logger.info(f"Average time per record: {elapsed_time/total_entries:.2f}s")
    # Bug fix: the original referenced an undefined `num_processes` here,
    # raising NameError at the end of every run.
    logger.info(f"Processed using up to {BATCH_SIZE} worker processes.")

# Required for multiprocessing on macOS (and Windows): child processes are
# spawned and re-import this module, so the guard prevents them from
# re-running main().
if __name__ == '__main__':
    # freeze_support() is needed for frozen executables; it is a no-op
    # in a normal interpreter run.
    freeze_support()
    main()