55#
66
77""" Custom user defined function for anomaly detection in weld sensor data. """
8-
8+ import json
99import os
1010import logging
1111import time
1212import warnings
13+
# Runtime configuration is read from the environment once, at import time.
log_level = os.getenv('KAPACITOR_LOGGING_LEVEL', 'INFO').upper()
enable_benchmarking = os.getenv('ENABLE_BENCHMARKING', 'false').upper() == 'TRUE'

# getattr() returns *any* attribute of the logging module with a matching
# name (for example the BASIC_FORMAT string), so accept only integer levels
# and fall back to INFO for everything else.
logging_level = getattr(logging, log_level, logging.INFO)
if not isinstance(logging_level, int):
    logging_level = logging.INFO

# Configure logging before importing sklearnex so basicConfig takes effect
logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logging.getLogger("sklearnex").setLevel(logging.INFO)

# Parsed after logging is configured so a malformed value is reported through
# the configured format instead of aborting the module import.
try:
    total_no_pts = int(os.getenv('BENCHMARK_TOTAL_PTS', "0"))
except ValueError:
    logging.getLogger().warning(
        "Invalid BENCHMARK_TOTAL_PTS value %r; defaulting to 0.",
        os.getenv('BENCHMARK_TOTAL_PTS'),
    )
    total_no_pts = 0
25+
1326from kapacitor .udf .agent import Agent , Handler
1427from kapacitor .udf import udf_pb2
15- import catboost as cb
16- import pandas as pd
1728import numpy as np
18-
19-
29+ import joblib
30+ from sklearnex import patch_sklearn , config_context , set_config
31+ patch_sklearn ()
2032
2133warnings .filterwarnings (
2234 "ignore" ,
2335 message = ".*Threading.*parallel backend is not supported by Extension for Scikit-learn.*"
2436)
2537
26-
27- log_level = os .getenv ('KAPACITOR_LOGGING_LEVEL' , 'INFO' ).upper ()
28- enable_benchmarking = os .getenv ('ENABLE_BENCHMARKING' , 'false' ).upper () == 'TRUE'
29- total_no_pts = int (os .getenv ('BENCHMARK_TOTAL_PTS' , "0" ))
30- logging_level = getattr (logging , log_level , logging .INFO )
31-
3238# Primary weld current threshold
3339WELD_CURRENT_THRESHOLD = 50
34-
35- # Configure logging
36- logging .basicConfig (
37- level = logging_level , # Set the log level to DEBUG
38- format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' , # Log format
39- )
40-
40+ GOOD_WELD_LABEL = "Good Weld"
41+ NO_WELD_LABEL = "No Weld"
42+ FEATURES = [
43+ "Pressure" ,
44+ "CO2 Weld Flow" ,
45+ "Feed" ,
46+ "Primary Weld Current" ,
47+ "Secondary Weld Voltage" ,
48+ ]
49+ MODEL_WITH_EXPLANATION = True
4150logger = logging .getLogger ()
4251
4352# Anomaly detection on the weld sensor data
@@ -48,22 +57,33 @@ class AnomalyDetectorHandler(Handler):
4857 def __init__ (self , agent ):
4958 self ._agent = agent
5059 # Need to enable after model training
51- model_name = (os .path .basename (__file__ )).replace ('.py' , '.cb' )
60+ self .info_data = {}
61+ model_name = (os .path .basename (__file__ )).replace ('.py' , '.pkl' )
62+ label_name = (os .path .basename (__file__ )).replace ('.py' , '_labels.pkl' )
5263 model_path = os .path .join (os .path .dirname (os .path .abspath (__file__ )),
5364 "../models/" + model_name )
5465 model_path = os .path .abspath (model_path )
66+ label_path = os .path .join (os .path .dirname (os .path .abspath (__file__ )),
67+ "../models/" + label_name )
68+ label_path = os .path .abspath (label_path )
69+ self .pipeline = joblib .load (model_path )
70+ self .le = joblib .load (label_path )
71+ self .device = os .getenv ('DEVICE' , 'gpu' ).strip ().lower () or 'gpu'
72+ logger .info (f"on device: { self .device } " )
73+ global MODEL_WITH_EXPLANATION
74+ if MODEL_WITH_EXPLANATION :
75+ logger .info ("Model explanations are enabled for this UDF." )
76+ model_json_info = (os .path .basename (__file__ )).replace ('.py' , '.json' )
77+
78+ info_path = os .path .join (os .path .dirname (os .path .abspath (__file__ )),
79+ "../models/" + model_json_info )
80+ info_path = os .path .abspath (info_path )
5581
56- # Initialize a CatBoostClassifier model for anomaly detection
57- self .model = cb .CatBoostClassifier (
58- depth = 10 , # Set the depth of each tree to 10
59- iterations = 2000 , # Number of boosting iterations (trees)
60- learning_rate = 0.1 , # Step size for each iteration
61- task_type = "CPU" , # Specify to use CPU for training/inference
62- devices = "1:2" , # Specify device IDs (not used for CPU, but kept for config compatibility)
63- random_seed = 40 , # Set random seed for reproducibility
64- )
65-
66- self .model .load_model (model_path )
82+ with open (info_path , "r" , encoding = "utf-8" ) as f :
83+ self .info_data = json .load (f )
84+ logger .info (f"Model : { self .info_data .get ('algorithm' , 'unknown' )} " )
85+ logger .info (f"Classes : { len (self .info_data .get ('classes' , []))} " )
86+ logger .info (f"Trained w/ Intel: { self .info_data .get ('intel_patched' , 'unknown' )} " )
6787
6888 self .points_received = {}
6989 global total_no_pts
@@ -103,6 +123,63 @@ def begin_batch(self, begin_req):
103123 """ A batch has begun.
104124 """
105125 raise Exception ("not supported" )
126+
def _build_explanation(self, input_row: dict, predicted_category: str, prob_map: dict, model_info: dict) -> dict:
    """Create a human-readable reason block for why a row was classified as a category.

    Args:
        input_row: Mapping of field name to value for the point being explained.
        predicted_category: Label the model assigned to the row.
        prob_map: Mapping of class label to predicted probability.
        model_info: Model metadata; per-class feature statistics are read from
            its "class_feature_stats" entry. May be empty or None.

    Returns:
        A dict with three keys: "reason" (sentence summarising the decision),
        "top_probabilities" (up to three highest-probability classes) and
        "top_signal_features" (up to three features with the highest
        evidence score).
    """
    stats = model_info.get("class_feature_stats", {}) if model_info else {}
    pred_stats = stats.get(predicted_category, {})
    good_stats = stats.get(GOOD_WELD_LABEL, {})

    # Sort probabilities and include top alternatives for context.
    ranked = sorted(prob_map.items(), key=lambda kv: kv[1], reverse=True)
    top_probs = [{"category": k, "probability": round(float(v), 6)} for k, v in ranked[:3]]

    signal_features = []
    for feat in FEATURES:
        if feat not in pred_stats or feat not in good_stats:
            continue
        # The explanation is auxiliary output: a missing or non-numeric value
        # drops that feature from the explanation instead of raising.
        try:
            value = float(input_row[feat])
        except (KeyError, TypeError, ValueError):
            continue
        pred_mean = float(pred_stats[feat].get("mean", 0.0))
        # Floor the standard deviations so the z-scores never divide by zero.
        pred_std = max(float(pred_stats[feat].get("std", 0.0)), 1e-6)
        good_mean = float(good_stats[feat].get("mean", 0.0))
        good_std = max(float(good_stats[feat].get("std", 0.0)), 1e-6)

        # Positive score means closer to predicted class profile than Good Weld profile.
        z_to_pred = abs(value - pred_mean) / pred_std
        z_to_good = abs(value - good_mean) / good_std
        evidence = z_to_good - z_to_pred

        signal_features.append(
            {
                "feature": feat,
                "value": round(value, 6),
                "predicted_mean": round(pred_mean, 6),
                "good_weld_mean": round(good_mean, 6),
                "evidence_score": round(float(evidence), 6),
            }
        )

    signal_features.sort(key=lambda x: x["evidence_score"], reverse=True)
    top_signals = signal_features[:3]

    # Only features with a strictly positive score actually favour the
    # predicted class, so only those may be cited as the reason. When the
    # predicted class is the Good Weld class itself both profiles are the
    # same and every score is zero.
    supporting = [s for s in top_signals if s["evidence_score"] > 0]

    if supporting:
        reason = (
            f"Classified as {predicted_category} because key signals "
            f"({', '.join(s['feature'] for s in supporting)}) align more with "
            f"{predicted_category} profile than Good Weld profile."
        )
    elif top_signals:
        reason = (
            f"Classified as {predicted_category} based on model probability ranking; "
            "no individual feature was closer to the predicted class profile "
            "than to the Good Weld profile."
        )
    else:
        reason = (
            f"Classified as {predicted_category} based on model probability ranking; "
            "class profile statistics were not available."
        )

    return {
        "reason": reason,
        "top_probabilities": top_probs,
        "top_signal_features": top_signals,
    }
182+
106183
107184 def point (self , point ):
108185 """ A point has arrived.
@@ -128,17 +205,71 @@ def point(self, point):
128205
129206 for key , value in point .fieldsInt .items ():
130207 fields [key ] = value
131-
132- point_series = pd .Series (fields )
133- if "Primary Weld Current" in point_series and point_series ["Primary Weld Current" ] > WELD_CURRENT_THRESHOLD :
134- defect_likelihood_main = self .model .predict_proba (point_series )
135- bad_defect = defect_likelihood_main [0 ]* 100
136- good_defect = defect_likelihood_main [1 ]* 100
137- if bad_defect > 50 :
138- point .fieldsDouble ["anomaly_status" ] = 1.0
139- logger .info (f"Good Weld: { good_defect :.2f} %, Defective Weld: { bad_defect :.2f} %" )
208+ for key , value in point .fieldsString .items ():
209+ fields [key ] = value
210+
211+ if "Primary Weld Current" in fields and fields ["Primary Weld Current" ] < WELD_CURRENT_THRESHOLD :
212+ point .fieldsString ["predicted_category" ] = NO_WELD_LABEL
213+ point .fieldsDouble ["Good Weld" ] = 0.0
214+ point .fieldsDouble ["Defective Weld" ] = 0.0
215+ point .fieldsDouble ["anomaly_status" ] = 0.0
216+ logger .debug (
217+ "Primary Weld Current below threshold (%d). Classified as %s." ,
218+ WELD_CURRENT_THRESHOLD ,
219+ NO_WELD_LABEL ,
220+ )
221+ elif "Primary Weld Current" in fields and fields ["Primary Weld Current" ] >= WELD_CURRENT_THRESHOLD :
222+ missing_features = [f for f in FEATURES if f not in fields ]
223+ if missing_features :
224+ logger .warning ("Missing required features for inference: %s" , missing_features )
225+ else :
226+ x = np .array (
227+ [[
228+ fields ["Pressure" ],
229+ fields ["CO2 Weld Flow" ],
230+ fields ["Feed" ],
231+ fields ["Primary Weld Current" ],
232+ fields ["Secondary Weld Voltage" ],
233+ ]],
234+ dtype = np .float32 ,
235+ )
236+
237+ with config_context (target_offload = self .device , allow_fallback_to_host = True ):
238+ pred_idx = self .pipeline .predict (x )[0 ]
239+ pred_proba = self .pipeline .predict_proba (x )[0 ]
240+ classes = list (self .le .classes_ )
241+ prob_map = {cls : float (p ) for cls , p in zip (classes , pred_proba )}
242+
243+ predicted_category = self .le .inverse_transform ([pred_idx ])[0 ]
244+ point .fieldsString ["predicted_category" ] = str (predicted_category )
245+ good_weld_prob = prob_map .get (GOOD_WELD_LABEL , 0.0 )
246+ good_defect = good_weld_prob * 100.0
247+ bad_defect = (1.0 - good_weld_prob ) * 100.0
248+ confidence = round (float (np .max (pred_proba )), 6 )
249+ if MODEL_WITH_EXPLANATION :
250+ explanation = self ._build_explanation (fields , predicted_category , prob_map , self .info_data )
251+
252+ data_prediction = {
253+ "predicted_category" : predicted_category ,
254+ "is_defect" : predicted_category != GOOD_WELD_LABEL ,
255+ "defect_probability" : round (1.0 - good_weld_prob , 6 ),
256+ "good_weld_probability" : round (good_weld_prob , 6 ),
257+ "confidence" : confidence ,
258+ "probabilities" : prob_map ,
259+ "explanation" : explanation if MODEL_WITH_EXPLANATION else "N/A" ,
260+ }
261+ logger .debug (
262+ "Prediction details: %s" ,
263+ data_prediction ,
264+ )
265+
266+ point .fieldsString ["prediction_details" ] = str (data_prediction )
267+
268+ if bad_defect > 50 :
269+ point .fieldsDouble ["anomaly_status" ] = 1.0
270+ logger .info ("Good Weld: %.2f%%, Defective Weld: %.2f%%" , good_defect , bad_defect )
140271 else :
141- logger .info ("Primary Weld Current below threshold (%d). Skipping anomaly detection." , WELD_CURRENT_THRESHOLD )
272+ logger .debug ("Primary Weld Current below threshold (%d). Skipping anomaly detection." , WELD_CURRENT_THRESHOLD )
142273
143274 point .fieldsDouble ["Good Weld" ] = round (good_defect , 2 ) if "good_defect" in locals () else 0.0
144275 point .fieldsDouble ["Defective Weld" ] = round (bad_defect , 2 ) if "bad_defect" in locals () else 0.0
@@ -148,6 +279,7 @@ def point(self, point):
148279 end_end_time = time_now - point .time
149280 point .fieldsDouble ["processing_time" ] = processing_time
150281 point .fieldsDouble ["end_end_time" ] = end_end_time
282+
151283
152284 logger .info ("Processing point %s %s for source %s" , point .time , time .time (), stream_src )
153285
0 commit comments