|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Tire Anomaly Detection — Model Training & Evaluation\n", |
| 8 | + "\n", |
| 9 | + "This notebook trains a SageMaker Random Cut Forest model for tire pressure anomaly detection.\n", |
| 10 | + "\n", |
| 11 | + "**What it does:**\n", |
| 12 | + "1. Loads the synthetic training dataset (721K records, 50 vehicles, 6 months)\n", |
| 13 | + "2. Prepares and normalizes features (pressure, temperature, delta_pressure, delta_temp)\n", |
| 14 | + "3. Trains an RCF model on normal data only (unsupervised anomaly detection)\n", |
| 15 | + "4. Evaluates on labeled test data (slow leaks, punctures, valve failures)\n", |
| 16 | + "5. Deploys a real-time inference endpoint\n", |
| 17 | + "\n", |
| 18 | + "**Prerequisites:**\n", |
| 19 | + "- Run `python3 scripts/generate_training_data.py` first to create the dataset\n", |
| 20 | + "- SageMaker execution role with S3 and SSM access\n", |
| 21 | + "- S3 bucket for training artifacts" |
| 22 | + ] |
| 23 | + }, |
| 24 | + { |
| 25 | + "cell_type": "code", |
| 26 | + "execution_count": null, |
| 27 | + "metadata": {}, |
| 28 | + "outputs": [], |
| 29 | + "source": [ |
| 30 | + "import boto3\n", |
| 31 | + "import json\n", |
| 32 | + "import io\n", |
| 33 | + "import time\n", |
| 34 | + "import numpy as np\n", |
| 35 | + "import pandas as pd\n", |
| 36 | + "import matplotlib.pyplot as plt\n", |
| 37 | + "from datetime import datetime\n", |
| 38 | + "\n", |
| 39 | + "# Configuration\n", |
| 40 | + "REGION = 'us-east-2'\n", |
| 41 | + "BUCKET = 'cms-tire-prediction-ACCOUNT-REGION' # Update with your bucket\n", |
| 42 | + "ROLE_ARN = 'arn:aws:iam::ACCOUNT:role/cms-sagemaker-execution-role' # Update\n", |
| 43 | + "STAGE = 'prod'" |
| 44 | + ] |
| 45 | + }, |
| 46 | + { |
| 47 | + "cell_type": "markdown", |
| 48 | + "metadata": {}, |
| 49 | + "source": [ |
| 50 | + "## 1. Load and Explore Training Data" |
| 51 | + ] |
| 52 | + }, |
| 53 | + { |
| 54 | + "cell_type": "code", |
| 55 | + "execution_count": null, |
| 56 | + "metadata": {}, |
| 57 | + "outputs": [], |
| 58 | + "source": [ |
| 59 | + "df = pd.read_parquet('../data/training/tire_telemetry_full.parquet')\n", |
| 60 | + "print(f'Dataset: {len(df):,} records')\n", |
| 61 | + "print(f'Vehicles: {df.vehicle_id.nunique()}')\n", |
| 62 | + "print(f'Date range: {df.timestamp.min()} → {df.timestamp.max()}')\n", |
| 63 | + "print(f'\\nLabel distribution:')\n", |
| 64 | + "print(df.label.value_counts())" |
| 65 | + ] |
| 66 | + }, |
| 67 | + { |
| 68 | + "cell_type": "code", |
| 69 | + "execution_count": null, |
| 70 | + "metadata": {}, |
| 71 | + "outputs": [], |
| 72 | + "source": [ |
| 73 | + "# Visualize pressure distribution by label\n", |
| 74 | + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", |
| 75 | + "\n", |
| 76 | + "for label in df.label.unique():\n", |
| 77 | + " subset = df[df.label == label]\n", |
| 78 | + " axes[0].hist(subset.pressure, bins=50, alpha=0.5, label=label)\n", |
| 79 | + "axes[0].set_xlabel('Pressure (PSI)')\n", |
| 80 | + "axes[0].set_ylabel('Count')\n", |
| 81 | + "axes[0].set_title('Pressure Distribution by Label')\n", |
| 82 | + "axes[0].legend()\n", |
| 83 | + "axes[0].axvline(x=28, color='red', linestyle='--', label='Alert threshold')\n", |
| 84 | + "\n", |
| 85 | + "# Show a slow leak example\n", |
| 86 | + "leak = df[(df.label == 'slow_leak') & (df.tire_id == 'FL')].sort_values('timestamp').head(500)\n", |
| 87 | + "if len(leak) > 0:\n", |
| 88 | + " vid = leak.vehicle_id.iloc[0]\n", |
| 89 | + " vehicle_leak = df[(df.vehicle_id == vid) & (df.tire_id == 'FL')].sort_values('timestamp')\n", |
| 90 | + " axes[1].plot(range(len(vehicle_leak)), vehicle_leak.pressure.values, linewidth=0.5)\n", |
| 91 | + " axes[1].axhline(y=28, color='red', linestyle='--', label='Alert threshold')\n", |
| 92 | + " axes[1].set_xlabel('Reading #')\n", |
| 93 | + " axes[1].set_ylabel('Pressure (PSI)')\n", |
| 94 | + " axes[1].set_title(f'Slow Leak Example ({vid} FL)')\n", |
| 95 | + " axes[1].legend()\n", |
| 96 | + "\n", |
| 97 | + "plt.tight_layout()\n", |
| 98 | + "plt.show()" |
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "markdown", |
| 103 | + "metadata": {}, |
| 104 | + "source": [ |
| 105 | + "## 2. Prepare Features" |
| 106 | + ] |
| 107 | + }, |
| 108 | + { |
| 109 | + "cell_type": "code", |
| 110 | + "execution_count": null, |
| 111 | + "metadata": {}, |
| 112 | + "outputs": [], |
| 113 | + "source": [ |
| 114 | + "features = ['pressure', 'temperature', 'delta_pressure', 'delta_temp']\n",
| 115 | + "\n",
| 116 | + "# The RCF model is trained unsupervised, so only normal readings go into the\n",
| 117 | + "# training split; the labeled frame is kept for threshold selection/evaluation.\n",
| 118 | + "normal_rows = df.loc[df.label == 'normal', features].dropna()\n",
| 119 | + "labeled_rows = df[features + ['label']].dropna()\n",
| 120 | + "\n",
| 121 | + "# Per-feature z-score statistics, fitted on the normal split only\n",
| 122 | + "stats = {c: {'mean': float(normal_rows[c].mean()), 'std': float(normal_rows[c].std())} for c in features}\n",
| 123 | + "\n",
| 124 | + "# Normalize both splits with the SAME (train-derived) statistics\n",
| 125 | + "train_norm = normal_rows.copy()\n",
| 126 | + "test_norm = labeled_rows.copy()\n",
| 127 | + "for c in features:\n",
| 128 | + "    mu, sigma = stats[c]['mean'], stats[c]['std']\n",
| 129 | + "    train_norm[c] = (train_norm[c] - mu) / sigma\n",
| 130 | + "    test_norm[c] = (test_norm[c] - mu) / sigma\n",
| 131 | + "\n",
| 132 | + "print(f'Training: {len(train_norm):,} (normal only)')\n",
| 133 | + "print(f'Test: {len(test_norm):,} (all labels)')\n",
| 134 | + "print(f'\\nNormalization stats:')\n",
| 135 | + "for k, v in stats.items():\n",
| 136 | + "    print(f'  {k}: mean={v[\"mean\"]:.3f}, std={v[\"std\"]:.3f}')"
| 136 | + ] |
| 137 | + }, |
| 138 | + { |
| 139 | + "cell_type": "markdown", |
| 140 | + "metadata": {}, |
| 141 | + "source": [ |
| 142 | + "## 3. Train Random Cut Forest Model" |
| 143 | + ] |
| 144 | + }, |
| 145 | + { |
| 146 | + "cell_type": "code", |
| 147 | + "execution_count": null, |
| 148 | + "metadata": {}, |
| 149 | + "outputs": [], |
| 150 | + "source": [ |
| 151 | + "sm = boto3.client('sagemaker', region_name=REGION)\n",
| 152 | + "s3 = boto3.client('s3', region_name=REGION)\n",
| 153 | + "\n",
| 154 | + "train_array = train_norm[features].values.astype('float32')\n",
| 155 | + "job_name = f'tire-rcf-{datetime.now().strftime(\"%Y%m%d-%H%M%S\")}'\n",
| 156 | + "prefix = f'tire-prediction/training/{job_name}'\n",
| 157 | + "\n",
| 158 | + "# Upload training CSV (RCF expects unlabeled CSV rows, hence label_size=0 below)\n",
| 159 | + "buf = io.StringIO()\n",
| 160 | + "pd.DataFrame(train_array).to_csv(buf, header=False, index=False)\n",
| 161 | + "s3.put_object(Bucket=BUCKET, Key=f'{prefix}/train/train.csv', Body=buf.getvalue())\n",
| 162 | + "print(f'Uploaded {len(train_array):,} training samples to s3://{BUCKET}/{prefix}/train/')\n",
| 163 | + "\n",
| 164 | + "# Built-in RCF container. AWS publishes the built-in algorithm images under the\n",
| 165 | + "# fixed tag '1' -- a ':latest' tag is not guaranteed to exist in these ECR repos.\n",
| 166 | + "acct_map = {'us-east-1': '382416733822', 'us-east-2': '404615174143', 'us-west-2': '174872318107'}\n",
| 167 | + "image = f'{acct_map.get(REGION, \"404615174143\")}.dkr.ecr.{REGION}.amazonaws.com/randomcutforest:1'\n",
| 168 | + "\n",
| 169 | + "sm.create_training_job(\n",
| 170 | + "    TrainingJobName=job_name,\n",
| 171 | + "    AlgorithmSpecification={'TrainingImage': image, 'TrainingInputMode': 'File'},\n",
| 172 | + "    RoleArn=ROLE_ARN,\n",
| 173 | + "    InputDataConfig=[{'ChannelName': 'train', 'DataSource': {'S3DataSource': {\n",
| 174 | + "        'S3DataType': 'S3Prefix', 'S3Uri': f's3://{BUCKET}/{prefix}/train',\n",
| 175 | + "        'S3DataDistributionType': 'ShardedByS3Key'}}, 'ContentType': 'text/csv;label_size=0'}],\n",
| 176 | + "    OutputDataConfig={'S3OutputPath': f's3://{BUCKET}/{prefix}/output'},\n",
| 177 | + "    ResourceConfig={'InstanceType': 'ml.m5.large', 'InstanceCount': 1, 'VolumeSizeInGB': 10},\n",
| 178 | + "    StoppingCondition={'MaxRuntimeInSeconds': 600},\n",
| 179 | + "    HyperParameters={'num_samples_per_tree': '256', 'num_trees': '100', 'feature_dim': '4'},\n",
| 180 | + ")\n",
| 181 | + "print(f'Training job started: {job_name}')\n",
| 182 | + "\n",
| 183 | + "# Poll until the job reaches a terminal state\n",
| 184 | + "while True:\n",
| 185 | + "    status = sm.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n",
| 186 | + "    print(f'  {status}')\n",
| 187 | + "    if status in ('Completed', 'Failed', 'Stopped'): break\n",
| 188 | + "    time.sleep(30)\n",
| 189 | + "\n",
| 190 | + "# Fail loudly on a non-Completed job instead of KeyError-ing on ModelArtifacts\n",
| 191 | + "final = sm.describe_training_job(TrainingJobName=job_name)\n",
| 192 | + "if final['TrainingJobStatus'] != 'Completed':\n",
| 193 | + "    raise RuntimeError(f'Training job {job_name}: ' + final.get('FailureReason', final['TrainingJobStatus']))\n",
| 194 | + "model_data = final['ModelArtifacts']['S3ModelArtifacts']\n",
| 195 | + "print(f'\\n✅ Model: {model_data}')"
| 191 | + ] |
| 192 | + }, |
| 193 | + { |
| 194 | + "cell_type": "markdown", |
| 195 | + "metadata": {}, |
| 196 | + "source": [ |
| 197 | + "## 4. Deploy Endpoint" |
| 198 | + ] |
| 199 | + }, |
| 200 | + { |
| 201 | + "cell_type": "code", |
| 202 | + "execution_count": null, |
| 203 | + "metadata": {}, |
| 204 | + "outputs": [], |
| 205 | + "source": [ |
| 206 | + "ts = datetime.now().strftime('%Y%m%d-%H%M%S')\n",
| 207 | + "# NOTE(review): the endpoint name only has day granularity, so re-running this\n",
| 208 | + "# cell on the same day fails with 'endpoint already exists'; delete it first.\n",
| 209 | + "endpoint_name = f'tire-anomaly-{datetime.now().strftime(\"%Y%m%d\")}'\n",
| 210 | + "model_name = f'tire-rcf-{ts}'\n",
| 211 | + "config_name = f'tire-rcf-cfg-{ts}'\n",
| 212 | + "\n",
| 213 | + "sm.create_model(ModelName=model_name, ExecutionRoleArn=ROLE_ARN,\n",
| 214 | + "    PrimaryContainer={'Image': image, 'ModelDataUrl': model_data})\n",
| 215 | + "\n",
| 216 | + "sm.create_endpoint_config(EndpointConfigName=config_name, ProductionVariants=[{\n",
| 217 | + "    'VariantName': 'default', 'ModelName': model_name,\n",
| 218 | + "    'InstanceType': 'ml.m5.large', 'InitialInstanceCount': 1}])\n",
| 219 | + "\n",
| 220 | + "sm.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)\n",
| 221 | + "print(f'Creating endpoint: {endpoint_name}')\n",
| 222 | + "\n",
| 223 | + "while True:\n",
| 224 | + "    status = sm.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']\n",
| 225 | + "    print(f'  {status}')\n",
| 226 | + "    if status in ('InService', 'Failed'): break\n",
| 227 | + "    time.sleep(30)\n",
| 228 | + "\n",
| 229 | + "# Don't print a success banner when the endpoint actually failed to come up\n",
| 230 | + "if status != 'InService':\n",
| 231 | + "    raise RuntimeError(f'Endpoint {endpoint_name} failed to deploy')\n",
| 232 | + "print(f'\\n✅ Endpoint ready: {endpoint_name}')"
| 228 | + ] |
| 229 | + }, |
| 230 | + { |
| 231 | + "cell_type": "markdown", |
| 232 | + "metadata": {}, |
| 233 | + "source": [ |
| 234 | + "## 5. Evaluate Model" |
| 235 | + ] |
| 236 | + }, |
| 237 | + { |
| 238 | + "cell_type": "code", |
| 239 | + "execution_count": null, |
| 240 | + "metadata": {}, |
| 241 | + "outputs": [], |
| 242 | + "source": [ |
| 243 | + "sm_runtime = boto3.client('sagemaker-runtime', region_name=REGION)\n",
| 244 | + "\n",
| 245 | + "test_array = test_norm[features].values.astype('float32')\n",
| 246 | + "labels = test_norm['label'].values\n",
| 247 | + "\n",
| 248 | + "# Get anomaly scores in batches.\n",
| 249 | + "# NOTE(review): this scores the FIRST 10K rows for speed, not a random sample --\n",
| 250 | + "# if the frame is ordered by vehicle/time the evaluation sample may be biased.\n",
| 251 | + "scores = []\n",
| 252 | + "batch_size = 500\n",
| 253 | + "for i in range(0, min(len(test_array), 10000), batch_size):\n",
| 254 | + "    batch = test_array[i:i+batch_size]\n",
| 255 | + "    body = '\\n'.join(','.join(str(v) for v in row) for row in batch)\n",
| 256 | + "    resp = sm_runtime.invoke_endpoint(\n",
| 257 | + "        EndpointName=endpoint_name, ContentType='text/csv', Body=body)\n",
| 258 | + "    result = json.loads(resp['Body'].read().decode())\n",
| 259 | + "    scores.extend([r['score'] for r in result['scores']])\n",
| 260 | + "    if i % 2000 == 0: print(f'  {i}/{min(len(test_array), 10000)}')\n",
| 261 | + "\n",
| 262 | + "scores = np.array(scores)\n",
| 263 | + "sample_labels = labels[:len(scores)]\n",
| 264 | + "\n",
| 265 | + "# Threshold at 95th percentile of normal scores\n",
| 266 | + "normal_scores = scores[sample_labels == 'normal']\n",
| 267 | + "anomaly_scores = scores[sample_labels != 'normal']\n",
| 268 | + "# Guard: np.percentile raises (and .mean returns nan) on an empty array below\n",
| 269 | + "assert len(normal_scores) > 0 and len(anomaly_scores) > 0, 'sample is missing a class; score more rows'\n",
| 270 | + "threshold = float(np.percentile(normal_scores, 95))\n",
| 271 | + "\n",
| 272 | + "print(f'Threshold: {threshold:.4f}')\n",
| 273 | + "print(f'Normal scores: mean={normal_scores.mean():.4f}, p95={np.percentile(normal_scores, 95):.4f}')\n",
| 274 | + "print(f'Anomaly scores: mean={anomaly_scores.mean():.4f}, p95={np.percentile(anomaly_scores, 95):.4f}')\n",
| 275 | + "\n",
| 276 | + "# Confusion-matrix style metrics: any non-normal label counts as a true anomaly\n",
| 277 | + "predictions = scores > threshold\n",
| 278 | + "true_anomalies = sample_labels != 'normal'\n",
| 279 | + "tp = np.sum(predictions & true_anomalies)\n",
| 280 | + "fp = np.sum(predictions & ~true_anomalies)\n",
| 281 | + "fn = np.sum(~predictions & true_anomalies)\n",
| 282 | + "precision = tp / (tp + fp) if (tp + fp) > 0 else 0\n",
| 283 | + "recall = tp / (tp + fn) if (tp + fn) > 0 else 0\n",
| 284 | + "f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n",
| 285 | + "\n",
| 286 | + "print(f'\\nPrecision: {precision:.3f}')\n",
| 287 | + "print(f'Recall: {recall:.3f}')\n",
| 288 | + "print(f'F1 Score: {f1:.3f}')"
| 285 | + ] |
| 286 | + }, |
| 287 | + { |
| 288 | + "cell_type": "code", |
| 289 | + "execution_count": null, |
| 290 | + "metadata": {}, |
| 291 | + "outputs": [], |
| 292 | + "source": [ |
| 293 | + "# Score separation between normal and anomalous readings, with the threshold\n",
| 294 | + "fig, ax = plt.subplots(figsize=(10, 5))\n",
| 295 | + "for values, name, colour in ((normal_scores, 'Normal', 'green'), (anomaly_scores, 'Anomaly', 'red')):\n",
| 296 | + "    ax.hist(values, bins=50, alpha=0.5, label=name, color=colour)\n",
| 297 | + "ax.axvline(x=threshold, color='black', linestyle='--', label=f'Threshold ({threshold:.2f})')\n",
| 298 | + "ax.set(xlabel='Anomaly Score', ylabel='Count', title='RCF Anomaly Score Distribution')\n",
| 299 | + "ax.legend()\n",
| 300 | + "plt.show()"
| 303 | + ] |
| 304 | + }, |
| 305 | + { |
| 306 | + "cell_type": "markdown", |
| 307 | + "metadata": {}, |
| 308 | + "source": [ |
| 309 | + "## 6. Save Configuration to SSM" |
| 310 | + ] |
| 311 | + }, |
| 312 | + { |
| 313 | + "cell_type": "code", |
| 314 | + "execution_count": null, |
| 315 | + "metadata": {}, |
| 316 | + "outputs": [], |
| 317 | + "source": [ |
| 318 | + "ssm = boto3.client('ssm', region_name=REGION)\n",
| 319 | + "# Distinct name on purpose: 'prefix' already holds the S3 training prefix from\n",
| 320 | + "# the training cell; reusing it here clobbers that value in kernel state.\n",
| 321 | + "param_prefix = f'/tire-prediction/{STAGE}'\n",
| 322 | + "\n",
| 323 | + "ssm.put_parameter(Name=f'{param_prefix}/normalization-stats', Value=json.dumps(stats), Type='String', Overwrite=True)\n",
| 324 | + "ssm.put_parameter(Name=f'{param_prefix}/anomaly-threshold', Value=json.dumps({'threshold': threshold}), Type='String', Overwrite=True)\n",
| 325 | + "ssm.put_parameter(Name=f'{param_prefix}/endpoint-name', Value=endpoint_name, Type='String', Overwrite=True)\n",
| 326 | + "\n",
| 327 | + "print(f'✅ Config saved to SSM ({param_prefix}/*)')\n",
| 328 | + "print(f'  Normalization stats: {json.dumps(stats, indent=2)}')\n",
| 329 | + "print(f'  Threshold: {threshold}')\n",
| 330 | + "print(f'  Endpoint: {endpoint_name}')"
| 329 | + ] |
| 330 | + } |
| 331 | + ], |
| 332 | + "metadata": { |
| 333 | + "kernelspec": { |
| 334 | + "display_name": "Python 3", |
| 335 | + "language": "python", |
| 336 | + "name": "python3" |
| 337 | + }, |
| 338 | + "language_info": { |
| 339 | + "name": "python", |
| 340 | + "version": "3.13.0" |
| 341 | + } |
| 342 | + }, |
| 343 | + "nbformat": 4, |
| 344 | + "nbformat_minor": 4 |
| 345 | +} |
0 commit comments