import json
import random
from typing import Any, Dict

from capymoa.base import Classifier
from capymoa.stream import Schema
from capymoa.evaluation import ClassificationEvaluator
from capymoa.automl._utils import (
    generate_parameter_combinations,
    create_capymoa_classifier,
)


class EpsilonGreedy:
    """Epsilon-Greedy bandit policy for model selection.

    This policy selects the best model with probability ``1 - epsilon`` and
    explores other models with probability ``epsilon``. During the burn-in
    period, it always explores to gather initial information about all models.

    >>> from capymoa.automl import EpsilonGreedy
    >>> policy = EpsilonGreedy(epsilon=0.1, burn_in=50)
    >>> policy.epsilon
    0.1
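
    During burn-in every arm is selected, so a pull returns all candidate arms
    (a minimal illustration using only the methods defined below):

    >>> policy.initialize(n_arms=3)
    >>> policy.pull([0, 1, 2])
    [0, 1, 2]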

    .. seealso::

        :class:`~capymoa.automl.BanditClassifier`
    """

    def __init__(self, epsilon: float = 0.1, burn_in: int = 100):
        """Construct a new Epsilon-Greedy policy.

        :param epsilon: Probability of exploring a random model (default: ``0.1``).
        :param burn_in: Number of initial rounds dedicated to exploration
            (default: ``100``).
        """
        self.epsilon = epsilon
        """Probability of exploring a random model."""

        self.burn_in = burn_in
        """Number of initial rounds where all models are explored to collect
        initial statistics."""

        self.n_arms = 0
        """Number of available models (arms)."""

        self.arm_rewards = []
        """Cumulative reward values for each model (arm)."""

        self.arm_counts = []
        """Number of times each arm has been pulled."""

        self.total_pulls = 0
        """Total number of model selections performed."""

    def initialize(self, n_arms):
        """Initialize the policy with a given number of arms.
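
        A minimal illustration of the reset behavior:

        >>> policy = EpsilonGreedy()
        >>> policy.initialize(3)
        >>> policy.arm_counts
        [0, 0, 0]
        """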
        self.n_arms = n_arms
        self.arm_rewards = [0.0] * n_arms
        self.arm_counts = [0] * n_arms
        self.total_pulls = 0

    def get_best_arm_idx(self, available_arms):
        """Return the arm with the highest mean reward among ``available_arms``.
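
        A minimal illustration (arm 1 has the higher mean reward):

        >>> policy = EpsilonGreedy()
        >>> policy.initialize(2)
        >>> policy.update(1, 1.0)
        >>> policy.get_best_arm_idx([0, 1])
        1
        """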
        return max(
            available_arms,
            key=lambda arm: self.arm_rewards[arm] / max(1, self.arm_counts[arm]),
        )

    def pull(self, available_arms):
        """Select which arms to pull based on the epsilon-greedy policy.
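
        With ``epsilon=0`` and the burn-in over, the pull is purely greedy
        (deterministic here because exploration is disabled):

        >>> policy = EpsilonGreedy(epsilon=0.0, burn_in=0)
        >>> policy.initialize(2)
        >>> policy.update(1, 1.0)
        >>> policy.pull([0, 1])
        [1]
        """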
        if self.total_pulls < self.burn_in:
            # During burn-in, explore all available arms
            return available_arms

        # With probability epsilon, explore a random arm
        if random.random() < self.epsilon:
            return [random.choice(available_arms)]

        # Otherwise, exploit the arm with the highest mean reward
        return [self.get_best_arm_idx(available_arms)]

    def update(self, arm, reward):
        """Update the policy with the observed reward for the pulled arm.
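
        A minimal illustration of the bookkeeping:

        >>> policy = EpsilonGreedy()
        >>> policy.initialize(2)
        >>> policy.update(0, 0.5)
        >>> policy.arm_rewards
        [0.5, 0.0]
        >>> policy.total_pulls
        1
        """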
        self.arm_rewards[arm] += reward
        self.arm_counts[arm] += 1
        self.total_pulls += 1

    def get_arm_stats(self):
        """Get statistics about each arm's performance.
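
        The mean reward of each arm is its cumulative reward divided by its
        pull count:

        >>> policy = EpsilonGreedy()
        >>> policy.initialize(2)
        >>> policy.update(0, 1.0)
        >>> policy.get_arm_stats()["means"]
        [1.0, 0.0]
        """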
        return {
            "rewards": self.arm_rewards.copy(),
            "counts": self.arm_counts.copy(),
            "means": [r / max(1, c) for r, c in zip(self.arm_rewards, self.arm_counts)],
        }


class BanditClassifier(Classifier):
    """Bandit-based model selection for streaming classification.

    Each base classifier is associated with an arm of a multi-armed bandit.
    At each training step, the bandit policy selects which model to train
    (i.e., which arm to pull). The reward corresponds to the model's
    performance on the current instance. The best-performing model is then
    used for prediction [#robbins1952]_.

    >>> from capymoa.datasets import ElectricityTiny
    >>> from capymoa.classifier import HoeffdingTree
    >>> from capymoa.automl import BanditClassifier, EpsilonGreedy
    >>> stream = ElectricityTiny()
    >>> schema = stream.get_schema()
    >>> learner = BanditClassifier(
    ...     schema=schema,
    ...     base_classifiers=[HoeffdingTree],
    ...     policy=EpsilonGreedy(epsilon=0.1, burn_in=100)
    ... )
    >>> instance = next(stream)
    >>> learner.train(instance)
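
    Once trained, the learner predicts with its best-performing model:

    >>> prediction = learner.predict(instance)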

    .. seealso::

        :class:`~capymoa.automl.EpsilonGreedy`

    .. [#robbins1952] Robbins, H. (1952). *Some aspects of the sequential
        design of experiments.* Bulletin of the American Mathematical Society,
        58(5), 527–535.
    """

    def __init__(
        self,
        schema: Schema = None,
        random_seed: int = 1,
        base_classifiers: list = None,
        config_file: str = None,
        metric: str = "accuracy",
        policy: EpsilonGreedy = None,
        verbose: bool = False,
    ):
        """Construct a bandit-based model selector.

        :param schema: The schema of the stream.
        :param random_seed: Random seed used for initialization.
        :param base_classifiers: List of base classifier classes to consider.
        :param config_file: Path to a JSON configuration file with model
            hyperparameters.
        :param metric: The metric used to evaluate model performance.
            Defaults to ``"accuracy"``.
        :param policy: The bandit policy used to choose which model to train
            (e.g., :class:`~capymoa.automl.EpsilonGreedy`).
        :param verbose: If True, print progress information during training.
        """
        super().__init__(schema=schema, random_seed=random_seed)

        self.config_file = config_file
        self.base_classifiers = base_classifiers
        self.metric = metric
        self.policy = policy
        self.verbose = verbose
        self.log_cnt = 0
        self.log_point = 5000

        # Initialize a default policy if none was provided
        if self.policy is None:
            self.policy = EpsilonGreedy(epsilon=0.1, burn_in=100)

        # Initialize models based on configuration
        self._initialize_models()

        # Track the best model
        self._best_model_idx = 0

    def _initialize_models(self):
        """Initialize models based on configuration."""
        # Validate that we have at least one source of models
        if self.base_classifiers is None and self.config_file is None:
            raise ValueError("Either base_classifiers or config_file must be provided")

        # Initialize state variables
        self.active_models = []  # List of active model instances
        self.metrics = []  # List of evaluation metrics for each model

        # If using a config file, load and process it
        if self.config_file is not None:
            if self.verbose:
                print(f"Loading model configurations from {self.config_file}")
            self._load_model_configurations()
        else:
            # Use the provided base classifiers directly
            if self.verbose:
                print(f"Using {len(self.base_classifiers)} provided base classifiers")
            for model in self.base_classifiers:
                # Check whether the model is already instantiated or is a class
                if isinstance(model, Classifier):
                    # Model is already instantiated, use it directly
                    self.active_models.append(model)
                else:
                    # Model is a class, instantiate it
                    clf_instance = model(schema=self.schema)
                    self.active_models.append(clf_instance)

                # Create an evaluator for this model
                self.metrics.append(ClassificationEvaluator(schema=self.schema))

        # Initialize the policy with the number of arms
        self.policy.initialize(len(self.active_models))

    def _load_model_configurations(self):
        """Load model configurations from a JSON file.
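
        The file must contain an ``algorithms`` list, where each entry names an
        algorithm and its parameter grid. A sketch of the expected shape; the
        algorithm name and parameter below are illustrative, the
        ``parameter``/``value`` keys are inferred from the logging code, and the
        exact grid format is defined by ``generate_parameter_combinations``:

        .. code-block:: json

            {
                "algorithms": [
                    {
                        "algorithm": "HoeffdingTree",
                        "parameters": [
                            {"parameter": "grace_period", "value": 200}
                        ]
                    }
                ]
            }
        """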
        try:
            with open(self.config_file, "r") as f:
                config = json.load(f)

            # Process the algorithms section of the config
            algorithms = config.get("algorithms", [])

            # If there are no algorithms defined, raise an error
            if not algorithms:
                raise ValueError("No algorithms defined in the configuration file")

            # Process each algorithm and its parameter configurations
            for algo_config in algorithms:
                algorithm_name = algo_config.get("algorithm")
                parameters = algo_config.get("parameters", [])

                # Generate all parameter combinations
                param_combinations = generate_parameter_combinations(parameters)

                # Create a classifier for each parameter combination
                for params in param_combinations:
                    try:
                        # Create the classifier instance
                        clf = create_capymoa_classifier(
                            algorithm_name, params, self.schema
                        )

                        if clf is not None:
                            self.active_models.append(clf)
                            self.metrics.append(
                                ClassificationEvaluator(schema=self.schema)
                            )

                            if self.verbose:
                                param_str = ", ".join(
                                    f"{p['parameter']}={p['value']}" for p in params
                                )
                                print(
                                    f"Added model: {algorithm_name} "
                                    f"with parameters: {param_str}"
                                )
                    except Exception as e:
                        print(
                            f"Warning: Failed to create model {algorithm_name} "
                            f"with parameters {params}: {e}"
                        )

        except (json.JSONDecodeError, FileNotFoundError) as e:
            raise ValueError(f"Error loading configuration file: {e}") from e

    def train(self, instance):
        """Train the selected model(s) on the given instance."""
        # Get the arm(s) to pull from the policy
        arm_ids = self.policy.pull(range(len(self.active_models)))

        # Train and evaluate each selected model
        for arm_id in arm_ids:
            model = self.active_models[arm_id]
            metric = self.metrics[arm_id]

            # Make a prediction for evaluation *before* training (test-then-train)
            y_pred = model.predict(instance)
            # Update the metric with the prediction
            metric.update(instance.y_index, y_pred)
            # Train the model
            model.train(instance)

            # Update the policy with the reward (metric value)
            reward = metric.accuracy() if self.metric == "accuracy" else metric.get()
            self.policy.update(arm_id, reward)

            # Check if this model is better than the current best
            if metric.accuracy() > self.metrics[self._best_model_idx].accuracy():
                self._best_model_idx = arm_id

        # Periodic verbose logging
        self.log_cnt += 1
        if self.verbose and self.log_cnt >= self.log_point:
            self.log_cnt = 0
            current_accuracy = metric.accuracy()
            model_performances = [(i, self.metrics[i].accuracy()) for i in arm_ids]
            top_models = sorted(model_performances, key=lambda x: x[1], reverse=True)[:3]

            print(f"\nChosen model: {model}")
            print(f"Current accuracy: {current_accuracy:.4f}")

            # Print the top three models when there are at least three
            if len(model_performances) >= 3:
                print("\nTop models:")
                for i, (model_idx, acc) in enumerate(top_models):
                    model_name = str(self.active_models[model_idx])
                    print(f"  {i + 1}. {model_name} - Accuracy: {acc:.4f}")

    def predict(self, instance):
        """Predict the class label for the given instance using the best model."""
        if not self.active_models:
            raise ValueError(
                "No active models available. Please train the classifier first."
            )
        # Use the model with the best mean reward for predictions
        idx = self.policy.get_best_arm_idx(range(len(self.active_models)))
        return self.active_models[idx].predict(instance)

    def predict_proba(self, instance):
        """Predict class probabilities for the given instance using the best model."""
        if not self.active_models:
            raise ValueError(
                "No active models available. Please train the classifier first."
            )
        # Use the model with the best mean reward for predictions
        idx = self.policy.get_best_arm_idx(range(len(self.active_models)))
        return self.active_models[idx].predict_proba(instance)

    def __str__(self):
        """Return a string representation of the model."""
        return "BanditClassifier"

    @property
    def best_model(self):
        """Return the current best model according to the policy."""
        idx = self.policy.get_best_arm_idx(range(len(self.active_models)))
        return self.active_models[idx]

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the current state of the classifier.

        :return: Dictionary containing classifier information.
        """
        # Get performance metrics for all models, sorted from best to worst
        model_performances = {
            str(self.active_models[i]): self.metrics[i].accuracy()
            for i in range(len(self.active_models))
        }
        sorted_dict = dict(
            sorted(model_performances.items(), key=lambda item: item[1], reverse=True)
        )
        # Keep the top-performing models (at most five)
        max_models = min(5, len(self.active_models))
        top_models = [
            {"model": key, "accuracy": value}
            for key, value in list(sorted_dict.items())[:max_models]
        ]

        return {
            "total_models": len(self.active_models),
            "best_model_index": self._best_model_idx,
            "model_performances": sorted_dict,
            "best_model_accuracy": (
                self.metrics[self._best_model_idx].accuracy() if self.metrics else None
            ),
            "top_models": top_models,
        }