Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ In addition, integration modules are available for the following libraries, prov
* [Pruning with Tensorflow Integration Module](./tensorflow/tensorflow_estimator_integration.py)
* [Pruning with XGBoost Integration Module](./xgboost/xgboost_integration.py)
* [Pruning with XGBoost Integration Module (Cross Validation Version)](./xgboost/xgboost_cv_integration.py)
* [Tracking Optimization with Comet Integration Module](./comet/comet_callback.py)
</details>

<details open>
Expand Down
99 changes: 99 additions & 0 deletions comet/comet_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
This script integrates Comet ML and Optuna to optimize a Random Forest Classifier
on the scikit-learn Breast Cancer dataset. It performs the following steps:

1. Initializes a Comet ML experiment for logging.
2. Loads the Breast Cancer dataset and splits it into training and testing sets.
3. Defines an evaluation function using F1-score, precision, and recall.
4. Implements an Optuna objective function to optimize hyperparameters
(n_estimators and max_depth) for the Random Forest model.
5. Uses Optuna to run multiple trials and identify the best hyperparameters.
6. Trains the final Random Forest model using the best-found hyperparameters.
7. Logs training and testing metrics to Comet ML.

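Before running, set the COMET_API_KEY environment variable; the script raises
an error at startup if it is missing.

You can run this example as follows:
    $ python comet_callback.py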
"""

import os

import comet_ml
import optuna
from optuna_integration.comet import CometCallback

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.model_selection import train_test_split


# Fail fast with a clear message when no Comet API key is configured.
if not os.environ.get("COMET_API_KEY"):
    raise ValueError("COMET_API_KEY is missing! Please set it as an environment variable.")

# Create the Comet experiment that the rest of the script logs to, and give
# the experiment its display name.
experiment = comet_ml.start()
experiment.set_name("comet-optuna-example")

# Load the Breast Cancer dataset and make a stratified train/test split.
# The same seed is reused by the Random Forest models built further down.
random_state = 42
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data,
    cancer.target,
    stratify=cancer.target,
    random_state=random_state,
)


# Evaluation function
def evaluate(y_true, y_pred):
    """Compute F1, precision and recall for a set of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        A dict with the keys ``"f1"``, ``"precision"`` and ``"recall"``,
        each mapped to the corresponding scikit-learn score.
    """
    scorers = (
        ("f1", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
    )
    return {name: scorer(y_true, y_pred) for name, scorer in scorers}


def objective(trial):
    """Score one Random Forest configuration for the Optuna study.

    Samples ``n_estimators`` and ``max_depth`` from ``trial``, estimates the
    F1-score with 5-fold cross-validation on the training split, logs that
    score to the Comet experiment and returns it.

    The test split is intentionally not touched here. The script reports
    final "test" metrics on ``X_test``/``y_test`` after the search, so using
    that split to pick hyperparameters would leak it into model selection and
    make the reported test scores optimistic.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The mean cross-validated F1-score as a float (higher is better).
    """
    # Imported locally so this function is self-contained and does not rely
    # on changes to the file-level import block.
    from sklearn.model_selection import cross_val_score

    n_estimators = trial.suggest_int("n_estimators", 10, 200)
    max_depth = trial.suggest_int("max_depth", 2, 20)

    clf = RandomForestClassifier(
        n_estimators=n_estimators, max_depth=max_depth, random_state=random_state
    )

    # cross_val_score fits a fresh clone of clf per fold; only training data is used.
    fold_scores = cross_val_score(clf, X_train, y_train, cv=5, scoring="f1")
    score = float(fold_scores.mean())

    # Log the metric manually, one point per trial.
    experiment.log_metric("f1_score", score, step=trial.number)

    return score


# Run the hyperparameter search with the Comet callback attached to the study.
study = optuna.create_study(direction="maximize")
comet_callback = CometCallback(
    study,
    project_name="comet-optuna-sklearn-example",
    metric_names=["f1_score"],
)
study.optimize(objective, n_trials=20, callbacks=[comet_callback])

# Refit a single model on the training split using the best parameters found.
best_params = study.best_params
clf = RandomForestClassifier(random_state=random_state, **best_params)
clf.fit(X_train, y_train)

# Score the refitted model on both splits before logging anything.
train_metrics = evaluate(y_train, clf.predict(X_train))
test_metrics = evaluate(y_test, clf.predict(X_test))

# Log each metric set inside the matching Comet context (train / test).
with experiment.train():
    experiment.log_metrics(train_metrics)

with experiment.test():
    experiment.log_metrics(test_metrics)

experiment.end()
4 changes: 4 additions & 0 deletions comet/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
comet-ml
optuna
optuna-integration
scikit-learn