Skip to content

Commit 517087d

Browse files
authored
Merge pull request #305 from ParagEkbote/Add-Comet-Example
Add Example for Comet
2 parents 6ad6d23 + 4b7fb6c commit 517087d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

comet/comet_integration.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
This script integrates Comet ML and Optuna to optimize a Random Forest Classifier
3+
on the scikit-learn Breast Cancer dataset. It performs the following steps:
4+
5+
1. Initializes a Comet ML experiment for logging.
6+
2. Loads the Breast Cancer dataset and splits it into training and testing sets.
7+
3. Defines an evaluation function using accuracy.
8+
4. Implements an Optuna objective function to optimize hyperparameters
9+
(min_samples_leaf, max_depth, and min_samples_split) for the Random Forest model.
10+
5. Uses Optuna to run multiple trials and identify the best hyperparameters.
11+
6. Trains the final Random Forest model using the best-found hyperparameters.
12+
7. Logs training and testing metrics to Comet ML.
13+
14+
You can run this example as follows:
15+
$ python comet_integration.py
16+
"""
17+
18+
import optuna
19+
from optuna_integration.comet import CometCallback
20+
21+
from sklearn.datasets import load_iris
22+
from sklearn.ensemble import RandomForestClassifier
23+
from sklearn.metrics import accuracy_score
24+
from sklearn.model_selection import train_test_split
25+
26+
27+
def objective(trial):
28+
"""Objective function for optimizing a RandomForestClassifier using Optuna."""
29+
data = load_iris()
30+
x_train, x_valid, y_train, y_valid = train_test_split(
31+
data["data"], data["target"], random_state=42
32+
)
33+
params = {
34+
"min_samples_leaf": trial.suggest_int("min_samples_leaf", 2, 10),
35+
"max_depth": trial.suggest_int("max_depth", 5, 20),
36+
"min_samples_split": trial.suggest_int("min_samples_split", 2, 10),
37+
}
38+
39+
clf = RandomForestClassifier(**params, random_state=42)
40+
clf.fit(x_train, y_train)
41+
pred = clf.predict(x_valid)
42+
score = accuracy_score(y_valid, pred)
43+
44+
return score
45+
46+
47+
if __name__ == "__main__":
48+
study = optuna.create_study(direction="maximize")
49+
comet_callback = CometCallback(study, project_name="comet-optuna-sklearn-example")
50+
51+
study.optimize(objective, n_trials=20, callbacks=[comet_callback])
52+
53+
print(f"Number of finished trials: {len(study.trials)}\n")
54+
55+
print("Best trial:")
56+
trial = study.best_trial
57+
58+
print(f" Value: {trial.value}\n")
59+
print(" Params:")
60+
for key, value in trial.params.items():
61+
print(f" {key}: {value}")

comet/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
comet-ml
2+
optuna
3+
optuna-integration
4+
scikit-learn

0 commit comments

Comments
 (0)