Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 189 additions & 67 deletions scripts/mongodbintegrationmvp.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,62 @@
# Generated by Honegumi (https://arxiv.org/abs/2502.06815)
# pip install ax-platform==0.4.3 numpy pymongo
import numpy as np
import json
import os
from datetime import datetime
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models
from pymongo import MongoClient # Added import for MongoDB
from pymongo import MongoClient, errors


obj1_name = "branin"
MAX_TRIALS = 19 # Configuration constant


def branin(x1, x2):
"""Branin function - a common benchmark for optimization."""
y = float(
(x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
+ 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
+ 10
)

return y


# Connect to MongoDB
tmongo_client = MongoClient("mongodb://localhost:27017/")
db = tmongo_client["ax_db"]
experiments_col = db["experiments"]
def create_generation_strategy(sobol_trials=5):
"""Create generation strategy with specified number of Sobol trials."""
return GenerationStrategy([
GenerationStep(
model=Models.SOBOL,
num_trials=sobol_trials,
min_trials_observed=1,
max_parallelism=5,
model_kwargs={"seed": 999}, # For reproducibility
),
GenerationStep(
model=Models.GPEI,
num_trials=-1,
max_parallelism=3,
model_kwargs={},
),
])


# Connect to MongoDB with error handling
try:
mongo_client = MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=5000)
# Test the connection
mongo_client.admin.command('ping')
db = mongo_client["ax_db"]
snapshots_col = db["ax_snapshots"] # Collection for storing JSON snapshots
print("Connected to MongoDB successfully")
except errors.ServerSelectionTimeoutError:
print("Failed to connect to MongoDB. Is MongoDB running?")
exit(1)
except Exception as e:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this explicit error handling? Seems like it would just bubble up naturally (unless you found that the error that bubbled up naturally was non-descript).

As a note for later, we'll set this up with a MongoDB Atlas cluster

print(f"MongoDB connection error: {e}")
exit(1)

# Experiment configuration
parameters = [
Expand All @@ -33,76 +65,166 @@ def branin(x1, x2):
]
objectives = {obj1_name: ObjectiveProperties(minimize=True)}

# Use Ax's default Sobol trials for 2D problems (aligns with GitHub comment)
# Use Ax's default Sobol trials for 2D problems
SOBOL_TRIALS = 5

# Load existing experiment state or initialize new
record = experiments_col.find_one({"experiment_name": obj1_name})
if record:
saved_trials = record.get("trials", [])
n_existing = len(saved_trials)

# Calculate remaining Sobol trials: max(target_sobol - existing, 0)
remaining_sobol = max(SOBOL_TRIALS - n_existing, 0)

if remaining_sobol > 0:
generation_strategy = GenerationStrategy([
{"model": Models.SOBOL, "num_trials": remaining_sobol},
{"model": Models.GPEI, "num_trials": -1}
])
print(f"Will run {remaining_sobol} more Sobol trials (have {n_existing} existing)")
else:
# Remove Sobol step entirely when remaining_sobol = 0
generation_strategy = GenerationStrategy([
{"model": Models.GPEI, "num_trials": -1}
])
print(f"Skipping Sobol (have {n_existing} trials), going to GP")

def save_ax_snapshot_to_mongodb(ax_client, experiment_name):
Comment thread
Gawthaman marked this conversation as resolved.
Outdated

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return the database ID of the snapshot?

"""Save Ax client snapshot to MongoDB with timestamp."""
try:
# Save to temporary JSON file first
temp_file = f"temp_{experiment_name}_snapshot.json"
ax_client.save_to_json_file(temp_file)

# Read the JSON content
with open(temp_file, 'r') as f:
snapshot_data = json.load(f)

# Create MongoDB document
snapshot_doc = {
"experiment_name": experiment_name,
"snapshot_data": snapshot_data,
"timestamp": datetime.now().isoformat(),
"trial_count": len(ax_client.get_trials_data_frame()) if ax_client.get_trials_data_frame() is not None else 0
}

# Upsert the snapshot (replace if exists, insert if not)
snapshots_col.replace_one(
{"experiment_name": experiment_name},
snapshot_doc,
upsert=True
)

# Clean up temp file
os.remove(temp_file)

print(f"Snapshot saved to MongoDB at {snapshot_doc['timestamp']}")
return True

except Exception as e:
print(f"Error saving snapshot: {e}")
return False


def load_ax_snapshot_from_mongodb(experiment_name):
"""Load most recent Ax client snapshot from MongoDB."""
try:
# Find the most recent snapshot
record = snapshots_col.find_one(
{"experiment_name": experiment_name},
sort=[("timestamp", -1)] # Most recent first
)

if record:
# Save snapshot data to temporary file
temp_file = f"temp_{experiment_name}_snapshot.json"
with open(temp_file, 'w') as f:
json.dump(record["snapshot_data"], f)

# Load AxClient from file
ax_client = AxClient.load_from_json_file(temp_file)

# Clean up temp file
os.remove(temp_file)

print(f"Loaded snapshot from {record['timestamp']} with {record['trial_count']} trials")
return ax_client
else:
print("No existing snapshot found")
return None

except Exception as e:
print(f"Error loading snapshot: {e}")
return None


# Load existing experiment or create new one
ax_client = load_ax_snapshot_from_mongodb(obj1_name)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably don't reuse obj1_name like this. Instead define a separate variable. Could incorporate obj1_name via f-string. Probably good to also add a hard-coded set of 4 characters (randomly generated externally) to it.


if ax_client is None:
# Create new experiment
generation_strategy = create_generation_strategy(SOBOL_TRIALS)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please unwrap this function

ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)
ax_client.create_experiment(
name=obj1_name,
parameters=parameters,
objectives=objectives
)
print(f"Created new experiment with {SOBOL_TRIALS} Sobol trials")

# Replay saved trials
for t in saved_trials:
ax_client.complete_trial(trial_index=t["trial_index"], raw_data=t["raw_data"])
start_i = len(saved_trials)
# Save initial snapshot
save_ax_snapshot_to_mongodb(ax_client, obj1_name)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about experiment name

else:
# Use the SAME custom generation strategy for new experiments
generation_strategy = GenerationStrategy([
{"model": Models.SOBOL, "num_trials": SOBOL_TRIALS},
{"model": Models.GPEI, "num_trials": -1}
])
ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)
start_i = 0
experiments_col.insert_one({"experiment_name": obj1_name, "trials": []})
print(f"Starting new experiment with {SOBOL_TRIALS} Sobol trials")
print(f"Resuming existing experiment")

for i in range(start_i, MAX_TRIALS):
# Get current trial count to determine how many more trials to run
current_trials = ax_client.get_trials_data_frame()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point about needing to handle max trials in the for loop. I suppose an alternative would be to have a budget variable stored in MongoDB that gets updated, but ignore that for now. More of a note to self/musing.

start_trial = len(current_trials) if current_trials is not None else 0

parameterization, trial_index = ax_client.get_next_trial()
print(f"Starting optimization: running trials {start_trial} to {MAX_TRIALS-1}")

# extract parameters
x1 = parameterization["x1"]
x2 = parameterization["x2"]
for i in range(start_trial, MAX_TRIALS):
try:
# Get next trial
parameterization, trial_index = ax_client.get_next_trial()

# Extract parameters
x1 = parameterization["x1"]
x2 = parameterization["x2"]

print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}")

# Save snapshot before running experiment (preserves pending trial)
save_ax_snapshot_to_mongodb(ax_client, obj1_name)

# Evaluate objective function
results = branin(x1, x2)

# Format raw_data as expected by AxClient
raw_data = {obj1_name: results}

# Complete trial
ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

# Save snapshot after completing trial
save_ax_snapshot_to_mongodb(ax_client, obj1_name)

# Get current best for progress tracking
try:
best_parameters, best_metrics = ax_client.get_best_parameters()
best_value = best_metrics[0][obj1_name]
print(f"Trial {trial_index}: result={results:.3f} | Best so far: {best_value:.3f}")
except Exception:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably remove these try-excepts

print(f"Trial {trial_index}: result={results:.3f}")

except Exception as e:
print(f"Error in trial {trial_index}: {e}")
continue

results = branin(x1, x2)
# Format raw_data as expected by AxClient (dict mapping objective name to value)
raw_data = {obj1_name: results}
print("\nOptimization completed!")
try:
best_parameters, best_metrics = ax_client.get_best_parameters()
print(f"Best parameters: {best_parameters}")
print(f"Best metrics: {best_metrics}")

ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)
# Save final snapshot
save_ax_snapshot_to_mongodb(ax_client, obj1_name)

# Save trial results to MongoDB with parameters for debugging
experiments_col.update_one(
{"experiment_name": obj1_name},
{"$push": {"trials": {
"trial_index": trial_index,
"raw_data": raw_data,
"parameters": parameterization
}}},
)
# Print experiment summary
trials_df = ax_client.get_trials_data_frame()
if trials_df is not None:
print(f"Total trials completed: {len(trials_df)}")
print(f"Best objective value: {trials_df[obj1_name].min():.6f}")

print(f"Trial {trial_index}: x1={x1:.3f}, x2={x2:.3f}, result={results:.3f}")
except Exception as e:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, no need for try except here

print(f"Error getting best parameters: {e}")

# Clean up MongoDB connection
mongo_client.close()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to close the client. I suppose a top-level try-except-finally could be implemented at some point if we find that lots of mongo connections are being leftover during restarts, but I think probably not an issue. Again, just a musing

print("MongoDB connection closed")

best_parameters, metrics = ax_client.get_best_parameters()
print(f"Best parameters: {best_parameters}")
print(f"Best metrics: {metrics}")
# Optional: Display trials data frame for debugging
try:
print("\nTrials Summary:")
print(ax_client.get_trials_data_frame())
except Exception as e:
print(f"Error displaying trials: {e}")
Loading