Skip to content

Commit 4afc303

Browse files
committed
feat: SageMaker notebook for tire model training + evaluation with visualizations
1 parent e83febf commit 4afc303

File tree

1 file changed

+345
-0
lines changed

1 file changed

+345
-0
lines changed
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Tire Anomaly Detection — Model Training & Evaluation\n",
8+
"\n",
9+
"This notebook trains a SageMaker Random Cut Forest model for tire pressure anomaly detection.\n",
10+
"\n",
11+
"**What it does:**\n",
12+
"1. Loads the synthetic training dataset (721K records, 50 vehicles, 6 months)\n",
13+
"2. Prepares and normalizes features (pressure, temperature, delta_pressure, delta_temp)\n",
14+
"3. Trains an RCF model on normal data only (unsupervised anomaly detection)\n",
15+
"4. Evaluates on labeled test data (slow leaks, punctures, valve failures)\n",
16+
"5. Deploys a real-time inference endpoint\n",
17+
"\n",
18+
"**Prerequisites:**\n",
19+
"- Run `python3 scripts/generate_training_data.py` first to create the dataset\n",
20+
"- SageMaker execution role with S3 and SSM access\n",
21+
"- S3 bucket for training artifacts"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
# --- Imports & configuration -------------------------------------------------
# All imports live in this one cell so a Restart & Run All never hits a
# missing name further down the notebook.
import io
import json
import time
from datetime import datetime

import boto3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Deployment configuration — update the placeholders before running.
REGION = 'us-east-2'
BUCKET = 'cms-tire-prediction-ACCOUNT-REGION'  # Update with your bucket
ROLE_ARN = 'arn:aws:iam::ACCOUNT:role/cms-sagemaker-execution-role'  # Update
STAGE = 'prod'
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"metadata": {},
49+
"source": [
50+
"## 1. Load and Explore Training Data"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
# Load the synthetic telemetry dataset written by generate_training_data.py.
df = pd.read_parquet('../data/training/tire_telemetry_full.parquet')

print(f'Dataset: {len(df):,} records')
print(f"Vehicles: {df['vehicle_id'].nunique()}")
print(f"Date range: {df['timestamp'].min()} → {df['timestamp'].max()}")
print('\nLabel distribution:')
print(df['label'].value_counts())
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": null,
70+
"metadata": {},
71+
"outputs": [],
72+
"source": [
73+
# Visualize pressure distribution by label, plus one slow-leak time series.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: pressure histogram per label, overlapped with transparency.
for label in df.label.unique():
    subset = df[df.label == label]
    axes[0].hist(subset.pressure, bins=50, alpha=0.5, label=label)
# BUG FIX: draw the threshold line BEFORE legend() — the original drew it
# after, so the 'Alert threshold' entry never appeared in the legend.
axes[0].axvline(x=28, color='red', linestyle='--', label='Alert threshold')
axes[0].set_xlabel('Pressure (PSI)')
axes[0].set_ylabel('Count')
axes[0].set_title('Pressure Distribution by Label')
axes[0].legend()

# Right panel: the full pressure history of one front-left tire that has a
# slow leak, so the gradual decline toward the threshold is visible.
leak = df[(df.label == 'slow_leak') & (df.tire_id == 'FL')].sort_values('timestamp').head(500)
if len(leak) > 0:
    vid = leak.vehicle_id.iloc[0]
    vehicle_leak = df[(df.vehicle_id == vid) & (df.tire_id == 'FL')].sort_values('timestamp')
    axes[1].plot(range(len(vehicle_leak)), vehicle_leak.pressure.values, linewidth=0.5)
    axes[1].axhline(y=28, color='red', linestyle='--', label='Alert threshold')
    axes[1].set_xlabel('Reading #')
    axes[1].set_ylabel('Pressure (PSI)')
    axes[1].set_title(f'Slow Leak Example ({vid} FL)')
    axes[1].legend()

plt.tight_layout()
plt.show()
99+
]
100+
},
101+
{
102+
"cell_type": "markdown",
103+
"metadata": {},
104+
"source": [
105+
"## 2. Prepare Features"
106+
]
107+
},
108+
{
109+
"cell_type": "code",
110+
"execution_count": null,
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
# Feature set used for training AND inference. Order matters: the training
# CSV columns and the endpoint request payload must match this order.
features = ['pressure', 'temperature', 'delta_pressure', 'delta_temp']

# Train on normal data only — RCF is unsupervised, so the model learns the
# "normal" manifold and scores deviations from it.
normal = df[df.label == 'normal'][features].dropna()
test = df[features + ['label']].dropna()

# Z-score stats computed from NORMAL rows only, so anomalies don't skew the
# scale. ROBUSTNESS FIX: guard std == 0 (constant feature) to avoid a
# division by zero flooding the normalized columns with inf/NaN.
stats = {}
for col in features:
    std = float(normal[col].std())
    stats[col] = {'mean': float(normal[col].mean()), 'std': std if std > 0 else 1.0}

train_norm = normal.copy()
test_norm = test.copy()
for col in features:
    train_norm[col] = (train_norm[col] - stats[col]['mean']) / stats[col]['std']
    test_norm[col] = (test_norm[col] - stats[col]['mean']) / stats[col]['std']

print(f'Training: {len(train_norm):,} (normal only)')
print(f'Test: {len(test_norm):,} (all labels)')
print('\nNormalization stats:')
for k, v in stats.items():
    print(f'  {k}: mean={v["mean"]:.3f}, std={v["std"]:.3f}')
136+
]
137+
},
138+
{
139+
"cell_type": "markdown",
140+
"metadata": {},
141+
"source": [
142+
"## 3. Train Random Cut Forest Model"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"metadata": {},
149+
"outputs": [],
150+
"source": [
151+
# --- Train a Random Cut Forest model ------------------------------------------
sm = boto3.client('sagemaker', region_name=REGION)
s3 = boto3.client('s3', region_name=REGION)

train_array = train_norm[features].values.astype('float32')
job_name = f'tire-rcf-{datetime.now().strftime("%Y%m%d-%H%M%S")}'
prefix = f'tire-prediction/training/{job_name}'

# Upload training data as header-less CSV — the layout RCF File mode expects.
buf = io.StringIO()
pd.DataFrame(train_array).to_csv(buf, header=False, index=False)
s3.put_object(Bucket=BUCKET, Key=f'{prefix}/train/train.csv', Body=buf.getvalue())
print(f'Uploaded {len(train_array):,} training samples to s3://{BUCKET}/{prefix}/train/')

# First-party RCF algorithm image; the ECR account id varies by region.
acct_map = {'us-east-1': '382416733822', 'us-east-2': '404615174143', 'us-west-2': '174872318107'}
image = f'{acct_map.get(REGION, "404615174143")}.dkr.ecr.{REGION}.amazonaws.com/randomcutforest:latest'

sm.create_training_job(
    TrainingJobName=job_name,
    AlgorithmSpecification={'TrainingImage': image, 'TrainingInputMode': 'File'},
    RoleArn=ROLE_ARN,
    InputDataConfig=[{'ChannelName': 'train', 'DataSource': {'S3DataSource': {
        'S3DataType': 'S3Prefix', 'S3Uri': f's3://{BUCKET}/{prefix}/train',
        'S3DataDistributionType': 'ShardedByS3Key'}}, 'ContentType': 'text/csv;label_size=0'}],
    OutputDataConfig={'S3OutputPath': f's3://{BUCKET}/{prefix}/output'},
    ResourceConfig={'InstanceType': 'ml.m5.large', 'InstanceCount': 1, 'VolumeSizeInGB': 10},
    StoppingCondition={'MaxRuntimeInSeconds': 600},
    # CONSISTENCY FIX: derive feature_dim from the feature list instead of a
    # hard-coded '4', so the two cannot drift apart.
    HyperParameters={'num_samples_per_tree': '256', 'num_trees': '100',
                     'feature_dim': str(len(features))},
)
print(f'Training job started: {job_name}')

# Poll until the job reaches a terminal state.
while True:
    desc = sm.describe_training_job(TrainingJobName=job_name)
    status = desc['TrainingJobStatus']
    print(f'  {status}')
    if status in ('Completed', 'Failed', 'Stopped'):
        break
    time.sleep(30)

# BUG FIX: the original read ModelArtifacts unconditionally, so a Failed or
# Stopped job died with an opaque KeyError. Fail loudly with the reason.
if status != 'Completed':
    raise RuntimeError(f'Training job ended in {status}: {desc.get("FailureReason", "unknown")}')

model_data = desc['ModelArtifacts']['S3ModelArtifacts']
print(f'\n✅ Model: {model_data}')
191+
]
192+
},
193+
{
194+
"cell_type": "markdown",
195+
"metadata": {},
196+
"source": [
197+
"## 4. Deploy Endpoint"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"metadata": {},
204+
"outputs": [],
205+
"source": [
206+
# --- Deploy a real-time inference endpoint ------------------------------------
ts = datetime.now().strftime('%Y%m%d-%H%M%S')
# Endpoint name is date-only (stable across same-day retrains); model and
# config names carry the full timestamp so each deploy is unique.
endpoint_name = f'tire-anomaly-{datetime.now().strftime("%Y%m%d")}'
model_name = f'tire-rcf-{ts}'
config_name = f'tire-rcf-cfg-{ts}'

sm.create_model(ModelName=model_name, ExecutionRoleArn=ROLE_ARN,
                PrimaryContainer={'Image': image, 'ModelDataUrl': model_data})

sm.create_endpoint_config(EndpointConfigName=config_name, ProductionVariants=[{
    'VariantName': 'default', 'ModelName': model_name,
    'InstanceType': 'ml.m5.large', 'InitialInstanceCount': 1}])

sm.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
print(f'Creating endpoint: {endpoint_name}')

# Poll until terminal state.
while True:
    ep = sm.describe_endpoint(EndpointName=endpoint_name)
    status = ep['EndpointStatus']
    print(f'  {status}')
    if status in ('InService', 'Failed'):
        break
    time.sleep(30)

# BUG FIX: the original printed "✅ Endpoint ready" even when the endpoint
# ended in Failed. Surface the failure reason instead.
if status == 'Failed':
    raise RuntimeError(f'Endpoint creation failed: {ep.get("FailureReason", "unknown")}')

print(f'\n✅ Endpoint ready: {endpoint_name}')
228+
]
229+
},
230+
{
231+
"cell_type": "markdown",
232+
"metadata": {},
233+
"source": [
234+
"## 5. Evaluate Model"
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"metadata": {},
241+
"outputs": [],
242+
"source": [
243+
# --- Evaluate on labeled test data --------------------------------------------
sm_runtime = boto3.client('sagemaker-runtime', region_name=REGION)

test_array = test_norm[features].values.astype('float32')
labels = test_norm['label'].values

# Score a sample in batches — the full test set would be slow over HTTP.
# NOTE(review): this takes the FIRST rows, which are likely ordered by
# vehicle/time and may bias the sample; a seeded random sample would be
# fairer. Kept as-is for determinism — confirm with data owner.
EVAL_LIMIT = 10000   # was a magic number inline
n_eval = min(len(test_array), EVAL_LIMIT)
scores = []
batch_size = 500
for i in range(0, n_eval, batch_size):
    batch = test_array[i:i + batch_size]
    # RCF accepts text/csv rows; response is {'scores': [{'score': f}, ...]}.
    body = '\n'.join(','.join(str(v) for v in row) for row in batch)
    resp = sm_runtime.invoke_endpoint(
        EndpointName=endpoint_name, ContentType='text/csv', Body=body)
    result = json.loads(resp['Body'].read().decode())
    scores.extend([r['score'] for r in result['scores']])
    if i % 2000 == 0:
        print(f'  {i}/{n_eval}')

scores = np.array(scores)
sample_labels = labels[:len(scores)]

# Threshold = 95th percentile of the NORMAL score distribution, i.e. a ~5%
# false-positive rate on normal traffic by construction.
normal_scores = scores[sample_labels == 'normal']
anomaly_scores = scores[sample_labels != 'normal']
# ROBUSTNESS FIX: percentile of an empty array raises an obscure error;
# fail with an actionable message instead.
if len(normal_scores) == 0:
    raise ValueError('Evaluation sample contains no normal records; cannot set a threshold')
threshold = float(np.percentile(normal_scores, 95))

print(f'Threshold: {threshold:.4f}')
print(f'Normal scores: mean={normal_scores.mean():.4f}, p95={np.percentile(normal_scores, 95):.4f}')
# ROBUSTNESS FIX: .mean() on an empty slice warns and prints nan.
if len(anomaly_scores) > 0:
    print(f'Anomaly scores: mean={anomaly_scores.mean():.4f}, p95={np.percentile(anomaly_scores, 95):.4f}')

# Precision / recall / F1, treating every non-normal label as a positive.
predictions = scores > threshold
true_anomalies = sample_labels != 'normal'
tp = np.sum(predictions & true_anomalies)
fp = np.sum(predictions & ~true_anomalies)
fn = np.sum(~predictions & true_anomalies)
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print(f'\nPrecision: {precision:.3f}')
print(f'Recall: {recall:.3f}')
print(f'F1 Score: {f1:.3f}')
285+
]
286+
},
287+
{
288+
"cell_type": "code",
289+
"execution_count": null,
290+
"metadata": {},
291+
"outputs": [],
292+
"source": [
293+
# Compare score distributions for normal vs anomalous readings, with the
# chosen alert threshold overlaid — separation here is what makes the
# detector usable.
fig, ax = plt.subplots(figsize=(10, 5))
for data, name, colour in ((normal_scores, 'Normal', 'green'),
                           (anomaly_scores, 'Anomaly', 'red')):
    ax.hist(data, bins=50, alpha=0.5, label=name, color=colour)
ax.axvline(x=threshold, color='black', linestyle='--', label=f'Threshold ({threshold:.2f})')
ax.set_xlabel('Anomaly Score')
ax.set_ylabel('Count')
ax.set_title('RCF Anomaly Score Distribution')
ax.legend()
plt.show()
303+
]
304+
},
305+
{
306+
"cell_type": "markdown",
307+
"metadata": {},
308+
"source": [
309+
"## 6. Save Configuration to SSM"
310+
]
311+
},
312+
{
313+
"cell_type": "code",
314+
"execution_count": null,
315+
"metadata": {},
316+
"outputs": [],
317+
"source": [
318+
# --- Persist inference configuration to SSM Parameter Store -------------------
# Downstream Lambdas read these parameters, so normalization stats and the
# threshold must match exactly what this notebook trained with.
ssm = boto3.client('ssm', region_name=REGION)
# NAMING FIX: was `prefix`, which silently shadowed the S3 training prefix
# defined in the training cell — a hidden-state hazard on partial re-runs.
ssm_prefix = f'/tire-prediction/{STAGE}'

ssm.put_parameter(Name=f'{ssm_prefix}/normalization-stats', Value=json.dumps(stats), Type='String', Overwrite=True)
ssm.put_parameter(Name=f'{ssm_prefix}/anomaly-threshold', Value=json.dumps({'threshold': threshold}), Type='String', Overwrite=True)
ssm.put_parameter(Name=f'{ssm_prefix}/endpoint-name', Value=endpoint_name, Type='String', Overwrite=True)

print(f'✅ Config saved to SSM ({ssm_prefix}/*)')
print(f'  Normalization stats: {json.dumps(stats, indent=2)}')
print(f'  Threshold: {threshold}')
print(f'  Endpoint: {endpoint_name}')
329+
]
330+
}
331+
],
332+
"metadata": {
333+
"kernelspec": {
334+
"display_name": "Python 3",
335+
"language": "python",
336+
"name": "python3"
337+
},
338+
"language_info": {
339+
"name": "python",
340+
"version": "3.13.0"
341+
}
342+
},
343+
"nbformat": 4,
344+
"nbformat_minor": 4
345+
}

0 commit comments

Comments
 (0)