Skip to content

Commit dcbca4c

Browse files
committed
Integrate new classifier ML model
Signed-off-by: Vellaisamy, Sathyendran <sathyendran.vellaisamy@intel.com>
1 parent 4fcb33e commit dcbca4c

File tree

4 files changed

+180
-43
lines changed

4 files changed

+180
-43
lines changed

manufacturing-ai-suite/industrial-edge-insights-multimodal/configs/telegraf/config/Telegraf.conf

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3948,6 +3948,10 @@
39483948
data_format = "json"
39493949
json_time_key = "time"
39503950
json_time_format = "RFC3339"
3951+
3952+
# Include string fields explicitly
3953+
json_string_fields = ["defect_type"]
3954+
39513955
#
39523956
# # if true, messages that can't be delivered while the subscriber is offline
39533957
# # will be delivered when it comes back (such as on service restart).
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
catboost==1.2.10

manufacturing-ai-suite/industrial-edge-insights-multimodal/configs/time-series-analytics-microservice/udfs/weld_anomaly_detector.py

Lines changed: 172 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,48 @@
55
#
66

77
""" Custom user defined function for anomaly detection in weld sensor data. """
8-
8+
import json
99
import os
1010
import logging
1111
import time
1212
import warnings
13+
14+
log_level = os.getenv('KAPACITOR_LOGGING_LEVEL', 'INFO').upper()
15+
enable_benchmarking = os.getenv('ENABLE_BENCHMARKING', 'false').upper() == 'TRUE'
16+
total_no_pts = int(os.getenv('BENCHMARK_TOTAL_PTS', "0"))
17+
logging_level = getattr(logging, log_level, logging.INFO)
18+
19+
# Configure logging before importing sklearnex so basicConfig takes effect
20+
logging.basicConfig(
21+
level=logging_level,
22+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23+
)
24+
logging.getLogger("sklearnex").setLevel(logging.INFO)
25+
1326
from kapacitor.udf.agent import Agent, Handler
1427
from kapacitor.udf import udf_pb2
15-
import catboost as cb
16-
import pandas as pd
1728
import numpy as np
18-
19-
29+
import joblib
30+
from sklearnex import patch_sklearn, config_context, set_config
31+
patch_sklearn()
2032

2133
warnings.filterwarnings(
2234
"ignore",
2335
message=".*Threading.*parallel backend is not supported by Extension for Scikit-learn.*"
2436
)
2537

26-
27-
log_level = os.getenv('KAPACITOR_LOGGING_LEVEL', 'INFO').upper()
28-
enable_benchmarking = os.getenv('ENABLE_BENCHMARKING', 'false').upper() == 'TRUE'
29-
total_no_pts = int(os.getenv('BENCHMARK_TOTAL_PTS', "0"))
30-
logging_level = getattr(logging, log_level, logging.INFO)
31-
3238
# Primary weld current threshold
3339
WELD_CURRENT_THRESHOLD = 50
34-
35-
# Configure logging
36-
logging.basicConfig(
37-
level=logging_level, # Set the log level to DEBUG
38-
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', # Log format
39-
)
40-
40+
GOOD_WELD_LABEL = "Good Weld"
41+
NO_WELD_LABEL = "No Weld"
42+
FEATURES = [
43+
"Pressure",
44+
"CO2 Weld Flow",
45+
"Feed",
46+
"Primary Weld Current",
47+
"Secondary Weld Voltage",
48+
]
49+
MODEL_WITH_EXPLANATION = True
4150
logger = logging.getLogger()
4251

4352
# Anomaly detection on the weld sensor data
@@ -48,22 +57,33 @@ class AnomalyDetectorHandler(Handler):
4857
def __init__(self, agent):
4958
self._agent = agent
5059
# Need to enable after model training
51-
model_name = (os.path.basename(__file__)).replace('.py', '.cb')
60+
self.info_data = {}
61+
model_name = (os.path.basename(__file__)).replace('.py', '.pkl')
62+
label_name = (os.path.basename(__file__)).replace('.py', '_labels.pkl')
5263
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
5364
"../models/" + model_name)
5465
model_path = os.path.abspath(model_path)
66+
label_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
67+
"../models/" + label_name)
68+
label_path = os.path.abspath(label_path)
69+
self.pipeline = joblib.load(model_path)
70+
self.le = joblib.load(label_path)
71+
self.device = os.getenv('DEVICE', 'gpu').strip().lower() or 'gpu'
72+
logger.info(f"on device: {self.device}")
73+
global MODEL_WITH_EXPLANATION
74+
if MODEL_WITH_EXPLANATION:
75+
logger.info("Model explanations are enabled for this UDF.")
76+
model_json_info = (os.path.basename(__file__)).replace('.py', '.json')
77+
78+
info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
79+
"../models/" + model_json_info)
80+
info_path = os.path.abspath(info_path)
5581

56-
# Initialize a CatBoostClassifier model for anomaly detection
57-
self.model = cb.CatBoostClassifier(
58-
depth=10, # Set the depth of each tree to 10
59-
iterations=2000, # Number of boosting iterations (trees)
60-
learning_rate=0.1, # Step size for each iteration
61-
task_type="CPU", # Specify to use CPU for training/inference
62-
devices="1:2", # Specify device IDs (not used for CPU, but kept for config compatibility)
63-
random_seed=40, # Set random seed for reproducibility
64-
)
65-
66-
self.model.load_model(model_path)
82+
with open(info_path, "r", encoding="utf-8") as f:
83+
self.info_data = json.load(f)
84+
logger.info(f"Model : {self.info_data.get('algorithm', 'unknown')}")
85+
logger.info(f"Classes : {len(self.info_data.get('classes', []))}")
86+
logger.info(f"Trained w/ Intel: {self.info_data.get('intel_patched', 'unknown')}")
6787

6888
self.points_received = {}
6989
global total_no_pts
@@ -103,6 +123,63 @@ def begin_batch(self, begin_req):
103123
""" A batch has begun.
104124
"""
105125
raise Exception("not supported")
126+
127+
def _build_explanation(self, input_row: dict, predicted_category: str, prob_map: dict, model_info: dict) -> dict:
    """Create a human-readable reason block for why a row was classified as a category.

    Compares each feature value against the per-class mean/std profiles stored
    in the model metadata and reports which features pull the point toward the
    predicted class rather than the "Good Weld" baseline.

    Args:
        input_row: Mapping of feature name -> numeric value for the classified point.
        predicted_category: Label the model assigned to the point.
        prob_map: Mapping of class label -> predicted probability.
        model_info: Model metadata dict; per-class feature statistics are read
            from its "class_feature_stats" key when present (may be empty/None).

    Returns:
        dict with keys:
            "reason"              - one-sentence summary string,
            "top_probabilities"   - up to 3 highest-probability classes,
            "top_signal_features" - up to 3 features whose value matches the
                                    predicted-class profile better than the
                                    Good Weld profile.
    """
    stats = model_info.get("class_feature_stats", {}) if model_info else {}
    pred_stats = stats.get(predicted_category, {})
    good_stats = stats.get(GOOD_WELD_LABEL, {})

    # Sort probabilities and include top alternatives for context.
    ranked = sorted(prob_map.items(), key=lambda kv: kv[1], reverse=True)
    top_probs = [{"category": k, "probability": round(float(v), 6)} for k, v in ranked[:3]]

    signal_features = []
    for feat in FEATURES:
        # Robustness fix: also require the feature to be present in the input
        # row itself — the original only checked the stats dicts, so a
        # partially-populated row raised KeyError on input_row[feat].
        if feat not in pred_stats or feat not in good_stats or feat not in input_row:
            continue
        value = float(input_row[feat])
        pred_mean = float(pred_stats[feat].get("mean", 0.0))
        # Floor std at 1e-6 to avoid division by zero for constant features.
        pred_std = max(float(pred_stats[feat].get("std", 0.0)), 1e-6)
        good_mean = float(good_stats[feat].get("mean", 0.0))
        good_std = max(float(good_stats[feat].get("std", 0.0)), 1e-6)

        # Positive score means closer to predicted class profile than Good Weld profile.
        z_to_pred = abs(value - pred_mean) / pred_std
        z_to_good = abs(value - good_mean) / good_std
        evidence = z_to_good - z_to_pred

        signal_features.append(
            {
                "feature": feat,
                "value": round(value, 6),
                "predicted_mean": round(pred_mean, 6),
                "good_weld_mean": round(good_mean, 6),
                "evidence_score": round(float(evidence), 6),
            }
        )

    # Strongest evidence first; keep only the top three for readability.
    signal_features.sort(key=lambda x: x["evidence_score"], reverse=True)
    top_signals = signal_features[:3]

    if top_signals:
        reason = (
            f"Classified as {predicted_category} because key signals "
            f"({', '.join(s['feature'] for s in top_signals)}) align more with "
            f"{predicted_category} profile than Good Weld profile."
        )
    else:
        reason = (
            f"Classified as {predicted_category} based on model probability ranking; "
            "class profile statistics were not available."
        )

    return {
        "reason": reason,
        "top_probabilities": top_probs,
        "top_signal_features": top_signals,
    }
182+
106183

107184
def point(self, point):
108185
""" A point has arrived.
@@ -128,17 +205,71 @@ def point(self, point):
128205

129206
for key, value in point.fieldsInt.items():
130207
fields[key] = value
131-
132-
point_series = pd.Series(fields)
133-
if "Primary Weld Current" in point_series and point_series["Primary Weld Current"] > WELD_CURRENT_THRESHOLD:
134-
defect_likelihood_main = self.model.predict_proba(point_series)
135-
bad_defect = defect_likelihood_main[0]*100
136-
good_defect = defect_likelihood_main[1]*100
137-
if bad_defect > 50:
138-
point.fieldsDouble["anomaly_status"] = 1.0
139-
logger.info(f"Good Weld: {good_defect:.2f}%, Defective Weld: {bad_defect:.2f}%")
208+
for key, value in point.fieldsString.items():
209+
fields[key] = value
210+
211+
if "Primary Weld Current" in fields and fields["Primary Weld Current"] < WELD_CURRENT_THRESHOLD:
212+
point.fieldsString["predicted_category"] = NO_WELD_LABEL
213+
point.fieldsDouble["Good Weld"] = 0.0
214+
point.fieldsDouble["Defective Weld"] = 0.0
215+
point.fieldsDouble["anomaly_status"] = 0.0
216+
logger.debug(
217+
"Primary Weld Current below threshold (%d). Classified as %s.",
218+
WELD_CURRENT_THRESHOLD,
219+
NO_WELD_LABEL,
220+
)
221+
elif "Primary Weld Current" in fields and fields["Primary Weld Current"] >= WELD_CURRENT_THRESHOLD:
222+
missing_features = [f for f in FEATURES if f not in fields]
223+
if missing_features:
224+
logger.warning("Missing required features for inference: %s", missing_features)
225+
else:
226+
x = np.array(
227+
[[
228+
fields["Pressure"],
229+
fields["CO2 Weld Flow"],
230+
fields["Feed"],
231+
fields["Primary Weld Current"],
232+
fields["Secondary Weld Voltage"],
233+
]],
234+
dtype=np.float32,
235+
)
236+
237+
with config_context(target_offload=self.device, allow_fallback_to_host=True):
238+
pred_idx = self.pipeline.predict(x)[0]
239+
pred_proba = self.pipeline.predict_proba(x)[0]
240+
classes = list(self.le.classes_)
241+
prob_map = {cls: float(p) for cls, p in zip(classes, pred_proba)}
242+
243+
predicted_category = self.le.inverse_transform([pred_idx])[0]
244+
point.fieldsString["predicted_category"] = str(predicted_category)
245+
good_weld_prob = prob_map.get(GOOD_WELD_LABEL, 0.0)
246+
good_defect = good_weld_prob * 100.0
247+
bad_defect = (1.0 - good_weld_prob) * 100.0
248+
confidence = round(float(np.max(pred_proba)), 6)
249+
if MODEL_WITH_EXPLANATION:
250+
explanation = self._build_explanation(fields, predicted_category, prob_map, self.info_data)
251+
252+
data_prediction = {
253+
"predicted_category": predicted_category,
254+
"is_defect": predicted_category != GOOD_WELD_LABEL,
255+
"defect_probability": round(1.0 - good_weld_prob, 6),
256+
"good_weld_probability": round(good_weld_prob, 6),
257+
"confidence": confidence,
258+
"probabilities": prob_map,
259+
"explanation": explanation if MODEL_WITH_EXPLANATION else "N/A",
260+
}
261+
logger.debug(
262+
"Prediction details: %s",
263+
data_prediction,
264+
)
265+
266+
point.fieldsString["prediction_details"] = str(data_prediction)
267+
268+
if bad_defect > 50:
269+
point.fieldsDouble["anomaly_status"] = 1.0
270+
logger.info("Good Weld: %.2f%%, Defective Weld: %.2f%%", good_defect, bad_defect)
140271
else:
141-
logger.info("Primary Weld Current below threshold (%d). Skipping anomaly detection.", WELD_CURRENT_THRESHOLD)
272+
logger.debug("Primary Weld Current below threshold (%d). Skipping anomaly detection.", WELD_CURRENT_THRESHOLD)
142273

143274
point.fieldsDouble["Good Weld"] = round(good_defect, 2) if "good_defect" in locals() else 0.0
144275
point.fieldsDouble["Defective Weld"] = round(bad_defect, 2) if "bad_defect" in locals() else 0.0
@@ -148,6 +279,7 @@ def point(self, point):
148279
end_end_time = time_now - point.time
149280
point.fieldsDouble["processing_time"] = processing_time
150281
point.fieldsDouble["end_end_time"] = end_end_time
282+
151283

152284
logger.info("Processing point %s %s for source %s", point.time, time.time(), stream_src)
153285

manufacturing-ai-suite/industrial-edge-insights-multimodal/weld-data-simulator/publisher.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ def stream_video_and_csv(base_filename: str, simulation_data_dir: str = "/simula
289289

290290
# Write frame bytes to ffmpeg stdin
291291
ffmpeg_proc.stdin.write(frame.tobytes())
292-
292+
csv_row["defect_type"] = base_filename.replace("-", "_") # Add defect type from filename for easier analysis
293+
csv_row["Frame ID"] = 13
293294
if "Date" in csv_row:
294295
del csv_row["Date"]
295296
if "Time" in csv_row:
@@ -301,9 +302,10 @@ def stream_video_and_csv(base_filename: str, simulation_data_dir: str = "/simula
301302
now_ns = time.time_ns()
302303
seconds = now_ns // 1_000_000_000
303304
nanoseconds = now_ns % 1_000_000_000
304-
csv_row["time"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(seconds)) + f".{nanoseconds:09d}Z"
305+
csv_row["time"] = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(seconds)) + f".{nanoseconds:09d}Z"
305306
# csv_row["frame_id"] = frame_id
306307
csv_row = json.dumps(csv_row)
308+
307309
# Publish each CSV row only once
308310

309311
# global published_data

0 commit comments

Comments
 (0)