-
Notifications
You must be signed in to change notification settings - Fork 7
Integrating MongoDB MVP #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
cfae831
4353f4f
0d5ae9d
fe004b8
633f4b2
c5c7a97
8c499f9
932d37b
454d471
3d65800
89b7183
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,30 +1,62 @@ | ||
| # Generated by Honegumi (https://arxiv.org/abs/2502.06815) | ||
| # pip install ax-platform==0.4.3 numpy pymongo | ||
| import numpy as np | ||
| import json | ||
| import os | ||
| from datetime import datetime | ||
| from ax.service.ax_client import AxClient, ObjectiveProperties | ||
| from ax.modelbridge.generation_strategy import GenerationStrategy | ||
| from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep | ||
| from ax.modelbridge.registry import Models | ||
| from pymongo import MongoClient # Added import for MongoDB | ||
| from pymongo import MongoClient, errors | ||
|
|
||
|
|
||
| obj1_name = "branin" | ||
| MAX_TRIALS = 19 # Configuration constant | ||
|
|
||
|
|
||
| def branin(x1, x2): | ||
| """Branin function - a common benchmark for optimization.""" | ||
| y = float( | ||
| (x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2 | ||
| + 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1) | ||
| + 10 | ||
| ) | ||
|
|
||
| return y | ||
|
|
||
|
|
||
| # Connect to MongoDB | ||
| tmongo_client = MongoClient("mongodb://localhost:27017/") | ||
| db = tmongo_client["ax_db"] | ||
| experiments_col = db["experiments"] | ||
| def create_generation_strategy(sobol_trials=5): | ||
| """Create generation strategy with specified number of Sobol trials.""" | ||
| return GenerationStrategy([ | ||
| GenerationStep( | ||
| model=Models.SOBOL, | ||
| num_trials=sobol_trials, | ||
| min_trials_observed=1, | ||
| max_parallelism=5, | ||
| model_kwargs={"seed": 999}, # For reproducibility | ||
| ), | ||
| GenerationStep( | ||
| model=Models.GPEI, | ||
| num_trials=-1, | ||
| max_parallelism=3, | ||
| model_kwargs={}, | ||
| ), | ||
| ]) | ||
|
|
||
|
|
||
| # Connect to MongoDB with error handling | ||
| try: | ||
| mongo_client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000) | ||
| # Test the connection | ||
| mongo_client.admin.command('ping') | ||
| db = mongo_client["ax_db"] | ||
| snapshots_col = db["ax_snapshots"] # Collection for storing JSON snapshots | ||
| print("Connected to MongoDB successfully") | ||
| except errors.ServerSelectionTimeoutError: | ||
| print("Failed to connect to MongoDB. Is MongoDB running?") | ||
| exit(1) | ||
| except Exception as e: | ||
| print(f"MongoDB connection error: {e}") | ||
| exit(1) | ||
|
|
||
| # Experiment configuration | ||
| parameters = [ | ||
|
|
@@ -33,76 +65,166 @@ def branin(x1, x2): | |
| ] | ||
| objectives = {obj1_name: ObjectiveProperties(minimize=True)} | ||
|
|
||
| # Use Ax's default Sobol trials for 2D problems (aligns with GitHub comment) | ||
| # Use Ax's default Sobol trials for 2D problems | ||
| SOBOL_TRIALS = 5 | ||
|
|
||
| # Load existing experiment state or initialize new | ||
| record = experiments_col.find_one({"experiment_name": obj1_name}) | ||
| if record: | ||
| saved_trials = record.get("trials", []) | ||
| n_existing = len(saved_trials) | ||
|
|
||
| # Calculate remaining Sobol trials: max(target_sobol - existing, 0) | ||
| remaining_sobol = max(SOBOL_TRIALS - n_existing, 0) | ||
|
|
||
| if remaining_sobol > 0: | ||
| generation_strategy = GenerationStrategy([ | ||
| {"model": Models.SOBOL, "num_trials": remaining_sobol}, | ||
| {"model": Models.GPEI, "num_trials": -1} | ||
| ]) | ||
| print(f"Will run {remaining_sobol} more Sobol trials (have {n_existing} existing)") | ||
| else: | ||
| # Remove Sobol step entirely when remaining_sobol = 0 | ||
| generation_strategy = GenerationStrategy([ | ||
| {"model": Models.GPEI, "num_trials": -1} | ||
| ]) | ||
| print(f"Skipping Sobol (have {n_existing} trials), going to GP") | ||
|
|
||
| def save_ax_snapshot_to_mongodb(ax_client, experiment_name): | ||
|
Gawthaman marked this conversation as resolved.
Outdated
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this return the database ID of the snapshot? |
||
| """Save Ax client snapshot to MongoDB with timestamp.""" | ||
| try: | ||
| # Save to temporary JSON file first | ||
| temp_file = f"temp_{experiment_name}_snapshot.json" | ||
| ax_client.save_to_json_file(temp_file) | ||
|
|
||
| # Read the JSON content | ||
| with open(temp_file, 'r') as f: | ||
| snapshot_data = json.load(f) | ||
|
|
||
| # Create MongoDB document | ||
| snapshot_doc = { | ||
| "experiment_name": experiment_name, | ||
| "snapshot_data": snapshot_data, | ||
| "timestamp": datetime.now().isoformat(), | ||
| "trial_count": len(ax_client.get_trials_data_frame()) if ax_client.get_trials_data_frame() is not None else 0 | ||
| } | ||
|
|
||
| # Upsert the snapshot (replace if exists, insert if not) | ||
| snapshots_col.replace_one( | ||
| {"experiment_name": experiment_name}, | ||
| snapshot_doc, | ||
| upsert=True | ||
| ) | ||
|
|
||
| # Clean up temp file | ||
| os.remove(temp_file) | ||
|
|
||
| print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']}") | ||
| return True | ||
|
|
||
| except Exception as e: | ||
| print(f"Error saving snapshot: {e}") | ||
| return False | ||
|
|
||
|
|
||
| def load_ax_snapshot_from_mongodb(experiment_name): | ||
| """Load most recent Ax client snapshot from MongoDB.""" | ||
| try: | ||
| # Find the most recent snapshot | ||
| record = snapshots_col.find_one( | ||
| {"experiment_name": experiment_name}, | ||
| sort=[("timestamp", -1)] # Most recent first | ||
| ) | ||
|
|
||
| if record: | ||
| # Save snapshot data to temporary file | ||
| temp_file = f"temp_{experiment_name}_snapshot.json" | ||
| with open(temp_file, 'w') as f: | ||
| json.dump(record["snapshot_data"], f) | ||
|
|
||
| # Load AxClient from file | ||
| ax_client = AxClient.load_from_json_file(temp_file) | ||
|
|
||
| # Clean up temp file | ||
| os.remove(temp_file) | ||
|
|
||
| print(f"Loaded snapshot from {record['timestamp']} with {record['trial_count']} trials") | ||
| return ax_client | ||
| else: | ||
| print("No existing snapshot found") | ||
| return None | ||
|
|
||
| except Exception as e: | ||
| print(f"Error loading snapshot: {e}") | ||
| return None | ||
|
|
||
|
|
||
| # Load existing experiment or create new one | ||
| ax_client = load_ax_snapshot_from_mongodb(obj1_name) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably don't reuse obj1_name like this. Instead define a separate variable. Could incorporate obj1_name via f-string. Probably good to also add a hard-coded set of 4 characters (randomly generated externally) to it. |
||
|
|
||
| if ax_client is None: | ||
| # Create new experiment | ||
| generation_strategy = create_generation_strategy(SOBOL_TRIALS) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please unwrap this function |
||
| ax_client = AxClient(generation_strategy=generation_strategy) | ||
| ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives) | ||
| ax_client.create_experiment( | ||
| name=obj1_name, | ||
| parameters=parameters, | ||
| objectives=objectives | ||
| ) | ||
| print(f"Created new experiment with {SOBOL_TRIALS} Sobol trials") | ||
|
|
||
| # Replay saved trials | ||
| for t in saved_trials: | ||
| ax_client.complete_trial(trial_index=t["trial_index"], raw_data=t["raw_data"]) | ||
| start_i = len(saved_trials) | ||
| # Save initial snapshot | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment about experiment name |
||
| else: | ||
| # Use the SAME custom generation strategy for new experiments | ||
| generation_strategy = GenerationStrategy([ | ||
| {"model": Models.SOBOL, "num_trials": SOBOL_TRIALS}, | ||
| {"model": Models.GPEI, "num_trials": -1} | ||
| ]) | ||
| ax_client = AxClient(generation_strategy=generation_strategy) | ||
| ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives) | ||
| start_i = 0 | ||
| experiments_col.insert_one({"experiment_name": obj1_name, "trials": []}) | ||
| print(f"Starting new experiment with {SOBOL_TRIALS} Sobol trials") | ||
| print(f"Resuming existing experiment") | ||
|
|
||
| for i in range(start_i, MAX_TRIALS): | ||
| # Get current trial count to determine how many more trials to run | ||
| current_trials = ax_client.get_trials_data_frame() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, good point about needing to handle max trials in the for loop. I suppose an alternative would be to have a budget variable stored in MongoDB that gets updated, but ignore that for now. More of a note to self/musing. |
||
| start_trial = len(current_trials) if current_trials is not None else 0 | ||
|
|
||
| parameterization, trial_index = ax_client.get_next_trial() | ||
| print(f"Starting optimization: running trials {start_trial} to {MAX_TRIALS-1}") | ||
|
|
||
| # extract parameters | ||
| x1 = parameterization["x1"] | ||
| x2 = parameterization["x2"] | ||
| for i in range(start_trial, MAX_TRIALS): | ||
| try: | ||
| # Get next trial | ||
| parameterization, trial_index = ax_client.get_next_trial() | ||
|
|
||
| # Extract parameters | ||
| x1 = parameterization["x1"] | ||
| x2 = parameterization["x2"] | ||
|
|
||
| print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}") | ||
|
|
||
| # Save snapshot before running experiment (preserves pending trial) | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Evaluate objective function | ||
| results = branin(x1, x2) | ||
|
|
||
| # Format raw_data as expected by AxClient | ||
| raw_data = {obj1_name: results} | ||
|
|
||
| # Complete trial | ||
| ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data) | ||
|
|
||
| # Save snapshot after completing trial | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Get current best for progress tracking | ||
| try: | ||
| best_parameters, best_metrics = ax_client.get_best_parameters() | ||
| best_value = best_metrics[0][obj1_name] | ||
| print(f"Trial {trial_index}: result={results:.3f} | Best so far: {best_value:.3f}") | ||
| except Exception: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably remove these try-excepts |
||
| print(f"Trial {trial_index}: result={results:.3f}") | ||
|
|
||
| except Exception as e: | ||
| print(f"Error in trial {trial_index}: {e}") | ||
| continue | ||
|
|
||
| results = branin(x1, x2) | ||
| # Format raw_data as expected by AxClient (dict mapping objective name to value) | ||
| raw_data = {obj1_name: results} | ||
| print("\nOptimization completed!") | ||
| try: | ||
| best_parameters, best_metrics = ax_client.get_best_parameters() | ||
| print(f"Best parameters: {best_parameters}") | ||
| print(f"Best metrics: {best_metrics}") | ||
|
|
||
| ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data) | ||
| # Save final snapshot | ||
| save_ax_snapshot_to_mongodb(ax_client, obj1_name) | ||
|
|
||
| # Save trial results to MongoDB with parameters for debugging | ||
| experiments_col.update_one( | ||
| {"experiment_name": obj1_name}, | ||
| {"$push": {"trials": { | ||
| "trial_index": trial_index, | ||
| "raw_data": raw_data, | ||
| "parameters": parameterization | ||
| }}}, | ||
| ) | ||
| # Print experiment summary | ||
| trials_df = ax_client.get_trials_data_frame() | ||
| if trials_df is not None: | ||
| print(f"Total trials completed: {len(trials_df)}") | ||
| print(f"Best objective value: {trials_df[obj1_name].min():.6f}") | ||
|
|
||
| print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}, result={results:.3f}") | ||
| except Exception as e: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, no need for try except here |
||
| print(f"Error getting best parameters: {e}") | ||
|
|
||
| # Clean up MongoDB connection | ||
| mongo_client.close() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good to close the client. I suppose a top-level try-except-finally could be implemented at some point if we find that lots of mongo connections are being leftover during restarts, but I think probably not an issue. Again, just a musing |
||
| print("MongoDB connection closed") | ||
|
|
||
| best_parameters, metrics = ax_client.get_best_parameters() | ||
| print(f"Best parameters: {best_parameters}") | ||
| print(f"Best metrics: {metrics}") | ||
| # Optional: Display trials data frame for debugging | ||
| try: | ||
| print("\nTrials Summary:") | ||
| print(ax_client.get_trials_data_frame()) | ||
| except Exception as e: | ||
| print(f"Error displaying trials: {e}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this explicit error handling? Seems like it would just bubble up naturally (unless you found that the error that bubbled up naturally was non-descript).
As a note for later, we'll set this up with a MongoDB Atlas cluster