Skip to content

Commit b17aefd

Browse files
author
Vidhisha Balachandran
committed
change NA to none, test fixes
1 parent b1cab65 commit b17aefd

File tree

8 files changed

+47
-76
lines changed

8 files changed

+47
-76
lines changed

eureka_ml_insights/configs/model_configs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,4 @@
199199
},
200200
"model_name": "Mistral-large-2407",
201201
},
202-
)
202+
)

eureka_ml_insights/metrics/ba_calendar_metrics.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def run_programmatic_tests(self, instance):
9898
for key, value in result.items():
9999
if value == 0:
100100
all_correct = 0
101-
if value != 'NA' and pd.notna(value) and isinstance(value, int):
101+
if value is not None and value != 'NA' and pd.notna(value) and isinstance(value, int):
102102
passed_constraints.append(value)
103103
result['all_correct'] = all_correct
104104
result['fraction_passed'] = np.mean(passed_constraints)
@@ -114,7 +114,8 @@ def is_formatted(self, solution):
114114

115115
def check_availability_programmatic(self, instance, solution):
116116
if not instance['constraints'].get('availability', True):
117-
result = {'availability_programmatic_check': 'NA'}
117+
# result = {'availability_programmatic_check': 'NA'}
118+
result = {'availability_programmatic_check': None}
118119
return result
119120

120121
if not self.is_formatted(solution):
@@ -144,7 +145,8 @@ def check_availability_programmatic(self, instance, solution):
144145

145146
def check_meeting_duration_programmatic(self, instance, solution):
146147
if not instance['constraints'].get('meeting_duration', True):
147-
result = {'meeting_duration_programmatic_check': 'NA'}
148+
# result = {'meeting_duration_programmatic_check': 'NA'}
149+
result = {'meeting_duration_programmatic_check': None}
148150
return result
149151

150152
if not self.is_formatted(solution):
@@ -162,7 +164,8 @@ def check_meeting_duration_programmatic(self, instance, solution):
162164
def check_buffer_time_programmatic(self, instance, solution):
163165
buffer_time = instance['constraints'].get('buffer_time_before_and_after_meeting', True)
164166
if buffer_time is None or not buffer_time:
165-
result = {'buffer_time_programmatic_check': 'NA'}
167+
# result = {'buffer_time_programmatic_check': 'NA'}
168+
result = {'buffer_time_programmatic_check': None}
166169
return result
167170

168171
if not self.is_formatted(solution):
@@ -195,7 +198,8 @@ def check_buffer_time_programmatic(self, instance, solution):
195198

196199
def check_no_weekends_programmatic(self, instance, solution):
197200
if not instance['constraints'].get('no_meetings_on_weekends', True):
198-
return {'no_weekends_programmatic_check': 'NA'}
201+
# return {'no_weekends_programmatic_check': 'NA'}
202+
return {'no_weekends_programmatic_check': None}
199203

200204
if not self.is_formatted(solution):
201205
return {'no_weekends_programmatic_check': 0}
@@ -207,7 +211,8 @@ def check_no_weekends_programmatic(self, instance, solution):
207211

208212
def check_time_restrictions_programmatic(self, instance, solution):
209213
if not instance['constraints'].get('no_meetings_before', True) and not instance['constraints'].get('no_meetings_after', True):
210-
return {'time_restrictions_programmatic_check': 'NA'}
214+
# return {'time_restrictions_programmatic_check': 'NA'}
215+
return {'time_restrictions_programmatic_check': None}
211216

212217
if not self.is_formatted(solution):
213218
return {'time_restrictions_programmatic_check': 0}
@@ -231,7 +236,8 @@ def check_time_restrictions_programmatic(self, instance, solution):
231236

232237
def check_priority_programmatic(self, instance, solution):
233238
if not instance['constraints'].get('high_priority_meeting', False):
234-
return {'priority_programmatic_check': 'NA'}
239+
# return {'priority_programmatic_check': 'NA'}
240+
return {'priority_programmatic_check': None}
235241

236242
if not self.is_formatted(solution):
237243
return {'priority_programmatic_check': 0}
@@ -269,7 +275,8 @@ def check_priority_programmatic(self, instance, solution):
269275

270276
def check_specific_times_programmatic(self, instance, solution):
271277
if not instance['constraints'].get('no_meetings_during_specific_times', True):
272-
return {'specific_times_programmatic_check': 'NA'}
278+
# return {'specific_times_programmatic_check': 'NA'}
279+
return {'specific_times_programmatic_check': None}
273280

274281
if not self.is_formatted(solution):
275282
return {'specific_times_programmatic_check': 0}

eureka_ml_insights/metrics/reports.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,13 @@ def _aggregate_grouped(self, data):
348348
divided_result = (gb[self.numerator_column_name].sum() / gb[self.denominator_column_name].sum()).to_dict()
349349
self.aggregated_result = {"ratio": divided_result}
350350

351-
class NAFilteredAggregator(Aggregator):
352-
def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs):
351+
class ValueFilteredAggregator(Aggregator):
352+
def __init__(self, agg_class, value, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs):
353353
"""
354-
Aggregator that filters out "NA" values before aggregating the data.
354+
Aggregator that filters out a particular value before aggregating the data.
355355
args:
356356
agg_class: Aggregator class to use for aggregation
357+
value: value to filter out
357358
column_names: column names to filter and aggregate
358359
output_dir: str. directory to save the report
359360
group_by: str. or list of str. column(s) to group by before aggregating
@@ -362,19 +363,19 @@ def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_no
362363
"""
363364

364365
self.base_aggregator = agg_class(column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs)
366+
self.value = value
365367
self.column_names = column_names
366368
self.group_by = group_by
367369
self.output_dir = output_dir
368370
self.aggregated_result = None
369371
self.ignore_non_numeric = ignore_non_numeric
370372
self.filename_base = filename_base
371-
# super().__init__(self.input_column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs)
372373

373374
def aggregate(self, data):
374375
agg_results = {}
375376
for col in self.column_names:
376377
# workaround to process one column at a time
377-
filtered_data = data[data[col] != "NA"].copy()
378+
filtered_data = data[data[col] != self.value].copy()
378379
self.base_aggregator.column_names = [col]
379380
self.base_aggregator.aggregate(filtered_data)
380381
agg_results.update(self.base_aggregator.aggregated_result)

eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ You are a scheduling assistant. Given the availability schedules of multiple par
22
Make sure you use the availability schedules to generate your response.
33
High priority meetings should be scheduled as early as possible.
44
Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15.
5-
Respond with "[day] [start_time]-[end_time]" or "No common time slot available"
6-
Do not respond with any additional information or comments.
5+
The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available".
6+
Provide your answer in the format:
7+
Final Answer: <answer>
78
{{prompt}}

eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Make sure you use the availability schedules to generate your response.
33
High priority meetings should be scheduled as early as possible.
44
Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15.
55
The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available".
6-
Think through and provide your your answer in the format:
6+
Think through and provide your answer in the format:
77
Reason: <reasoning>
88
Final Answer: <answer>
99
{{prompt}}

eureka_ml_insights/user_configs/ba_calendar.py

+13-47
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from eureka_ml_insights.metrics.reports import (
2020
AverageAggregator,
2121
BiLevelMaxAggregator,
22-
MaxAggregator,
23-
NAFilteredAggregator,
2422
)
2523

2624
from ..configs.config import (
@@ -99,17 +97,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
9997
{
10098
"column_names": [
10199
"BACalendarMetric_all_correct",
102-
"BACalendarMetric_fraction_passed"
103-
],
104-
"filename_base": "BaCal_OverallMetrics_SeparateRuns",
105-
"group_by": "data_repeat_id",
106-
},
107-
),
108-
AggregatorConfig(
109-
NAFilteredAggregator,
110-
{
111-
"agg_class": AverageAggregator,
112-
"column_names": [
100+
"BACalendarMetric_fraction_passed",
113101
"BACalendarMetric_availability_programmatic_check",
114102
"BACalendarMetric_meeting_duration_programmatic_check",
115103
"BACalendarMetric_buffer_time_programmatic_check",
@@ -118,7 +106,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
118106
"BACalendarMetric_specific_times_programmatic_check",
119107
"BACalendarMetric_priority_programmatic_check"
120108
],
121-
"filename_base": "BaCal_Constraint_Level_SeprateRuns",
109+
"filename_base": "BaCal_OverallMetrics_SeparateRuns",
122110
"group_by": "data_repeat_id",
123111
},
124112
),
@@ -142,18 +130,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
142130
{
143131
"column_names": [
144132
"BACalendarMetric_all_correct",
145-
"BACalendarMetric_fraction_passed"
146-
],
147-
"first_groupby": "data_point_id",
148-
"filename_base": "BaCal_BestOfN_Aggregated",
149-
"normalize": True,
150-
},
151-
),
152-
AggregatorConfig(
153-
NAFilteredAggregator,
154-
{
155-
"agg_class": MaxAggregator,
156-
"column_names": [
133+
"BACalendarMetric_fraction_passed",
157134
"BACalendarMetric_availability_programmatic_check",
158135
"BACalendarMetric_meeting_duration_programmatic_check",
159136
"BACalendarMetric_buffer_time_programmatic_check",
@@ -162,18 +139,17 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
162139
"BACalendarMetric_specific_times_programmatic_check",
163140
"BACalendarMetric_priority_programmatic_check"
164141
],
165-
"filename_base": "BaCal_Constraint_Level_BestOfN_Aggregated",
166-
"group_by": "data_repeat_id",
142+
"first_groupby": "data_point_id",
143+
"filename_base": "BaCal_BestOfN_Aggregated",
144+
"normalize": True,
167145
},
168146
),
169-
170-
171147
],
172148
output_dir=os.path.join(self.log_dir, "bestofn_eval_report"),
173149
)
174150

175151
# Aggregate the results by a majority vote
176-
self.data_post_processing_addmv = DataProcessingConfig(
152+
self.maj_vote_data_post_processing = DataProcessingConfig(
177153
component_type=DataProcessing,
178154
data_reader_config=DataSetConfig(
179155
DataReader,
@@ -197,12 +173,12 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
197173
output_dir=os.path.join(self.log_dir, "data_majvote_output"),
198174
)
199175
# Second, compute eaxct match
200-
self.postevalprocess_comp = EvalReportingConfig(
176+
self.majvote_evalreporting_comp = EvalReportingConfig(
201177
component_type=EvalReporting,
202178
data_reader_config=DataSetConfig(
203179
DataReader,
204180
{
205-
"path": os.path.join(self.data_post_processing_addmv.output_dir, "transformed_data.jsonl"),
181+
"path": os.path.join(self.maj_vote_data_post_processing.output_dir, "transformed_data.jsonl"),
206182
"format": ".jsonl",
207183
"transform": SequenceTransform(
208184
[
@@ -218,17 +194,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
218194
{
219195
"column_names": [
220196
"BACalendarMetric_all_correct",
221-
"BACalendarMetric_fraction_passed"
222-
],
223-
"filename_base": "BaCal_MajVote_OverallMetrics_Aggregated",
224-
"group_by": "data_repeat_id",
225-
},
226-
),
227-
AggregatorConfig(
228-
NAFilteredAggregator,
229-
{
230-
"agg_class": AverageAggregator,
231-
"column_names": [
197+
"BACalendarMetric_fraction_passed",
232198
"BACalendarMetric_availability_programmatic_check",
233199
"BACalendarMetric_meeting_duration_programmatic_check",
234200
"BACalendarMetric_buffer_time_programmatic_check",
@@ -237,7 +203,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
237203
"BACalendarMetric_specific_times_programmatic_check",
238204
"BACalendarMetric_priority_programmatic_check"
239205
],
240-
"filename_base": "BaCal_MajVote_Constraint_Level_Aggregated",
206+
"filename_base": "BaCal_MajVote_OverallMetrics_Aggregated",
241207
"group_by": "data_repeat_id",
242208
},
243209
),
@@ -252,8 +218,8 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
252218
self.inference_comp,
253219
self.evalreporting_comp,
254220
self.bon_evalreporting_comp,
255-
self.data_post_processing_addmv,
256-
self.postevalprocess_comp
221+
self.maj_vote_data_post_processing,
222+
self.majvote_evalreporting_comp
257223
],
258224
self.log_dir,
259225

tests/metric_utils_tests/aggregator_tests.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SumAggregator,
1616
TwoColumnSumAverageAggregator,
1717
)
18-
from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, NAFilteredAggregator
18+
from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, ValueFilteredAggregator
1919

2020
PRECISION = 3
2121

@@ -419,7 +419,7 @@ def test_average_aggregator_group_by_multiple_columns(self):
419419
self.assertTrue(os.path.exists(avg_agg.output_file))
420420

421421

422-
class NAFilteredAggregatorTestData:
422+
class ValueFilteredAggregatorTestData:
423423
def setUp(self):
424424
self.data = pd.DataFrame(
425425
{
@@ -430,18 +430,14 @@ def setUp(self):
430430
"col3": [5, 8, 'NA', 3, 'abc', 8, 3, 4, 5, 8, 4, 2],
431431
"categorical_metric": ["x", "y", "z", "z", "y", "y", "z", "y", "x", "y", "y", "x"],
432432
"group": ["a", "a", "b", "b", "a", "a", "b", "b", "a", "a", "b", "b"],
433-
# [5, 6, 8, 5, 8, ]
434-
# [2, 3, 3, 4, 2]
435-
# [5, 8, 6, 8, 5, 8, ]
436-
# [2, 3, 4, 2]
437433
}
438434
)
439435
self.output_dir = "output_dir"
440436
self.precision = PRECISION
441437

442-
class TestNAFilteredAggregator(NAFilteredAggregatorTestData, unittest.TestCase):
438+
class TestValueFilteredAggregator(ValueFilteredAggregatorTestData, unittest.TestCase):
443439
def test_average_aggregator(self):
444-
avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir)
440+
avg_agg = ValueFilteredAggregator(AverageAggregator, "NA", ["col1", "col2"], self.output_dir)
445441
avg_agg.aggregate(self.data)
446442
x = [a for a in self.data["col1"] if a != 'NA']
447443
y = [a for a in self.data["col2"] if a != 'NA']
@@ -451,12 +447,12 @@ def test_average_aggregator(self):
451447
)
452448

453449
def test_average_aggregator_input_validation(self):
454-
avg_agg = NAFilteredAggregator(AverageAggregator, ["col3"], self.output_dir)
450+
avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col3"], self.output_dir)
455451
self.assertRaises(ValueError, avg_agg.aggregate, self.data)
456452

457453
def test_average_aggregator_group_by(self):
458-
self.output_dir = create_logdir("NAFilteredAggregatorTests")
459-
avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir, group_by="group")
454+
self.output_dir = create_logdir("ValueFilteredAggregatorTests")
455+
avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col1", "col2"], self.output_dir, group_by="group")
460456
avg_agg.aggregate(self.data)
461457
self.assertEqual(avg_agg.aggregated_result, {"col1": {"a": 6.4, "b": 2.8}, "col2": {"a": 6.667, "b": 2.75}})
462458
avg_agg.write_results()

tests/pipeline_tests.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def get_config(self):
487487

488488
def setUp(self) -> None:
489489
super().setUp()
490-
self.eval_configs = [self.test_pipeline.evalreporting_comp]
490+
self.eval_configs = [self.test_pipeline.evalreporting_comp,self.test_pipeline.bon_evalreporting_comp, self.test_pipeline.majvote_evalreporting_comp]
491491

492492
def test_outputs_exist(self) -> None:
493493
logging.info("Running test_outputs_exist test in PipelineTest")

0 commit comments

Comments
 (0)