-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathguardrails.py
More file actions
326 lines (256 loc) · 10.6 KB
/
guardrails.py
File metadata and controls
326 lines (256 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""
Guardrail validation module for product matching system.
This module implements a validation system using a mapped function architecture
similar to normalization, where each field has corresponding guardrail functions
to enforce data quality constraints and filter invalid candidates.
"""
import copy
from typing import Dict, List, Tuple, Optional

import pandas as pd
# ============================================================================
# Configuration Constants
# ============================================================================
# Default guardrail configuration: per-field relative tolerances, plus the
# names of the guardrails that are switched on.
DEFAULT_GUARDRAIL_CONFIG = dict(
    tolerance=dict(
        net_weight_ml=0.1,          # 10% tolerance for weight/volume
        price_normalized_usd=0.5,   # 50% tolerance for price
    ),
    enabled=[
        'is_set',
        'net_weight_ml',
        'price_normalized_usd',
        'perfume_concentration',
    ],
)
# ============================================================================
# Guardrail Functions
# ============================================================================
def check_concentration_mismatch(
    client_product: Dict,
    competitor_product: Dict,
    tolerance: float = 0.0
) -> Tuple[bool, str]:
    """
    Reject a pair whose perfume concentrations are both known and differ.

    A missing or falsy ``perfume_concentration`` on either side counts as
    "unknown" and never causes a rejection.

    Args:
        client_product: Normalized client product.
        competitor_product: Normalized competitor product.
        tolerance: Ignored; concentrations must match exactly.

    Returns:
        ``(passed, reason)`` where ``reason`` is ``""`` on success.

    Examples:
        - client="edp", competitor="edp" → (True, "")
        - client="edp", competitor="edt" → (False, "concentration_mismatch")
    """
    ours = client_product.get('perfume_concentration', '')
    theirs = competitor_product.get('perfume_concentration', '')

    # Only two known values can disagree; unknown on either side passes.
    both_known = bool(ours) and bool(theirs)
    if both_known and ours != theirs:
        return False, "concentration_mismatch"
    return True, ""
def check_weight_mismatch(
    client_product: Dict,
    competitor_product: Dict,
    tolerance: float = 0.1
) -> Tuple[bool, str]:
    """
    Check for weight/volume mismatch.

    If both products carry a positive ``net_weight_ml``, the absolute
    difference divided by the *smaller* of the two values must not exceed
    ``tolerance``. This makes the check symmetric and makes the percentages
    in the examples literal (50 -> 55 ml is a 10% difference). Missing,
    ``None``, zero, negative or NaN weights count as "unknown" and never
    reject the pair.

    Args:
        client_product: Normalized client product.
        competitor_product: Normalized competitor product.
        tolerance: Allowed relative difference (default 10%).

    Returns:
        ``(passed, reason)`` where ``reason`` is ``""`` on success.

    Examples:
        - client=50ml, competitor=50ml → (True, "")
        - client=50ml, competitor=55ml (10% diff) → (True, "")
        - client=50ml, competitor=60ml (20% diff) → (False, "weight_mismatch")
    """
    # `or 0.0` also maps an explicit None (key present, value empty) to
    # "unknown"; a plain .get() default only covers a missing key, and
    # `None > 0` would raise TypeError.
    client_weight = client_product.get('net_weight_ml') or 0.0
    competitor_weight = competitor_product.get('net_weight_ml') or 0.0

    # NaN fails `> 0`, so NaN weights are skipped like missing ones.
    if client_weight > 0 and competitor_weight > 0:
        smaller = min(client_weight, competitor_weight)
        relative_diff = abs(client_weight - competitor_weight) / smaller
        if relative_diff > tolerance:
            return False, "weight_mismatch"
    return True, ""
def check_set_bundle_mismatch(
    client_product: Dict,
    competitor_product: Dict,
    tolerance: float = 0.0,
    volume_tolerance_ml: float = 5.0,
) -> Tuple[bool, str]:
    """
    Check for set/bundle mismatch.

    If one product is a set and the other is not, reject. If both are sets
    and both list their component volumes (``set_volumes_ml``), the two
    volume lists must have the same length and be matchable one-to-one
    (order-independent), with every matched pair differing by strictly less
    than ``volume_tolerance_ml``.

    Args:
        client_product: Normalized client product.
        competitor_product: Normalized competitor product.
        tolerance: Not used for the set/bundle check (kept so every
            guardrail shares the same positional signature).
        volume_tolerance_ml: Per-component volume tolerance in ml. Defaults
            to 5.0, the value previously hard-coded.

    Returns:
        ``(passed, reason)`` where ``reason`` is ``""`` on success.

    Examples:
        - client is_set=True, competitor is_set=True → (True, "")
        - client is_set=True, competitor is_set=False → (False, "set_bundle_mismatch")
        - volumes [50, 50, 10] vs [50, 10, 10] → (False, "set_bundle_mismatch")
    """
    # bool() so that None / 0 / False all mean "not a set"; comparing the
    # raw values would reject e.g. is_set=None against is_set=False.
    client_is_set = bool(client_product.get('is_set', False))
    competitor_is_set = bool(competitor_product.get('is_set', False))

    if client_is_set != competitor_is_set:
        return False, "set_bundle_mismatch"
    if not client_is_set:
        # Neither product is a set: nothing further to compare.
        return True, ""

    client_volumes = client_product.get('set_volumes_ml') or []
    competitor_volumes = competitor_product.get('set_volumes_ml') or []
    if not client_volumes or not competitor_volumes:
        # Volume breakdown unknown on at least one side: accept.
        return True, ""
    if len(client_volumes) != len(competitor_volumes):
        return False, "set_bundle_mismatch"

    # One-to-one matching. In one dimension, pairing both sorted lists
    # element-wise minimises the largest pairwise gap. So if any bijection
    # keeps every pair within tolerance, this one does. (Matching each client
    # volume against *any* competitor volume would let one competitor volume
    # be reused several times.)
    for client_vol, competitor_vol in zip(sorted(client_volumes), sorted(competitor_volumes)):
        if abs(client_vol - competitor_vol) >= volume_tolerance_ml:
            return False, "set_bundle_mismatch"
    return True, ""
def check_extreme_price_difference(
    client_product: Dict,
    competitor_product: Dict,
    tolerance: float = 0.5
) -> Tuple[bool, str]:
    """
    Check for extreme price difference.

    If both products carry a positive ``price_normalized_usd``, the absolute
    difference divided by the *smaller* price must not exceed ``tolerance``.
    With the default 0.5 the pricier item may cost at most 1.5x the cheaper
    one. Dividing by the larger price instead would cap the ratio below 1 and
    let a 2x gap through at tolerance 0.5, contradicting the examples below.
    Missing, ``None``, zero, negative or NaN prices count as "unknown" and
    never reject the pair.

    This can be used as a guardrail or as part of the scoring logic.

    Args:
        client_product: Normalized client product.
        competitor_product: Normalized competitor product.
        tolerance: Allowed relative difference (default 50%).

    Returns:
        ``(passed, reason)`` where ``reason`` is ``""`` on success.

    Examples:
        - client=$100, competitor=$100 → (True, "")
        - client=$100, competitor=$150 (50% diff) → (True, "")
        - client=$100, competitor=$200 (100% diff) → (False, "extreme_price_difference")
    """
    # `or 0.0` also maps an explicit None to "unknown"; `None > 0` would raise.
    client_price = client_product.get('price_normalized_usd') or 0.0
    competitor_price = competitor_product.get('price_normalized_usd') or 0.0

    # NaN fails `> 0`, so NaN prices are skipped like missing ones.
    if client_price > 0 and competitor_price > 0:
        cheaper = min(client_price, competitor_price)
        relative_diff = abs(client_price - competitor_price) / cheaper
        if relative_diff > tolerance:
            return False, "extreme_price_difference"
    return True, ""
# ============================================================================
# Field Mapping Dictionary
# ============================================================================
# Registry of guardrails: each key is the field a guardrail inspects and the
# name used to enable it / configure its tolerance; each value is the check
# function, called as func(client_product, competitor_product, tolerance).
GUARDRAIL_FUNCTIONS = dict(
    is_set=check_set_bundle_mismatch,
    net_weight_ml=check_weight_mismatch,
    price_normalized_usd=check_extreme_price_difference,
    perfume_concentration=check_concentration_mismatch,
)
# ============================================================================
# Main Guardrail Function
# ============================================================================
def apply_guardrails(
    client_product: Dict,
    competitor_product: Dict,
    guardrail_config: Optional[Dict] = None
) -> Tuple[bool, List[str]]:
    """
    Apply all enabled guardrails to a product pair.

    Args:
        client_product: Normalized client product
        competitor_product: Normalized competitor product
        guardrail_config: Optional configuration for guardrails; defaults to
            DEFAULT_GUARDRAIL_CONFIG. Shape:
            {
                'tolerance': {
                    'net_weight_ml': 0.1,
                    'price_normalized_usd': 0.5,
                },
                'enabled': ['is_set', 'net_weight_ml', 'price_normalized_usd', 'perfume_concentration']
            }
            A missing 'enabled' key enables every guardrail. A field with no
            entry under 'tolerance' uses that guardrail's own default
            tolerance, not 0.0.

    Returns:
        Tuple of (passed, reasons)
        - passed: True if all guardrails pass
        - reasons: List of reasons for failed guardrails, in registry order
    """
    if guardrail_config is None:
        # Read-only use, so the shared default need not be copied.
        guardrail_config = DEFAULT_GUARDRAIL_CONFIG

    enabled_guardrails = guardrail_config.get('enabled', list(GUARDRAIL_FUNCTIONS))
    # `or {}` also tolerates an explicit `'tolerance': None`.
    tolerance_config = guardrail_config.get('tolerance') or {}

    reasons: List[str] = []
    for field, guardrail_func in GUARDRAIL_FUNCTIONS.items():
        if field not in enabled_guardrails:
            continue

        # Only override the tolerance when the config provides one. Falling
        # back to 0.0 (as before) silently demanded exact weight/price
        # equality whenever a partial config omitted a field.
        if field in tolerance_config:
            guardrail_passed, reason = guardrail_func(
                client_product, competitor_product, tolerance_config[field]
            )
        else:
            guardrail_passed, reason = guardrail_func(client_product, competitor_product)

        if not guardrail_passed:
            reasons.append(reason)

    # Each failing guardrail appends exactly one reason.
    return not reasons, reasons
def apply_guardrails_batch(
    client_products: List[Dict],
    competitor_products: List[Dict],
    guardrail_config: Optional[Dict] = None
) -> List[Tuple[Dict, Dict, bool, List[str]]]:
    """
    Run apply_guardrails over every (client, competitor) combination.

    Results are client-major: all competitors for the first client come
    first, then all competitors for the second client, and so on.

    Args:
        client_products: Normalized client products.
        competitor_products: Normalized competitor products.
        guardrail_config: Optional configuration forwarded unchanged to
            apply_guardrails for every pair.

    Returns:
        One ``(client_product, competitor_product, passed, reasons)`` tuple
        per pair.
    """
    def _evaluate(client, competitor):
        ok, why = apply_guardrails(client, competitor, guardrail_config)
        return client, competitor, ok, why

    return [
        _evaluate(client, competitor)
        for client in client_products
        for competitor in competitor_products
    ]
# ============================================================================
# Utility Functions
# ============================================================================
def get_guardrail_schema() -> List[str]:
    """
    Return the names of all registered guardrails.

    Returns:
        A fresh list of the keys of GUARDRAIL_FUNCTIONS, in registry order.
        Mutating it does not affect the registry.
    """
    return [*GUARDRAIL_FUNCTIONS]
def get_default_guardrail_config() -> Dict:
    """
    Return an independent copy of the default guardrail configuration.

    A deep copy is used because the config nests a 'tolerance' dict and an
    'enabled' list. A shallow ``dict.copy()`` would share them, so a caller
    tweaking ``cfg['tolerance'][...]`` or ``cfg['enabled']`` would silently
    mutate the module-level DEFAULT_GUARDRAIL_CONFIG.

    Returns:
        Default configuration dictionary, safe to modify.
    """
    return copy.deepcopy(DEFAULT_GUARDRAIL_CONFIG)