change NA to none, test fixes
Vidhisha Balachandran committed Jan 17, 2025
1 parent b1cab65 commit b17aefd
Showing 8 changed files with 47 additions and 76 deletions.
2 changes: 1 addition & 1 deletion eureka_ml_insights/configs/model_configs.py
@@ -199,4 +199,4 @@
},
"model_name": "Mistral-large-2407",
},
)
)
23 changes: 15 additions & 8 deletions eureka_ml_insights/metrics/ba_calendar_metrics.py
@@ -98,7 +98,7 @@ def run_programmatic_tests(self, instance):
for key, value in result.items():
if value == 0:
all_correct = 0
if value != 'NA' and pd.notna(value) and isinstance(value, int):
if value is not None and value != 'NA' and pd.notna(value) and isinstance(value, int):
passed_constraints.append(value)
result['all_correct'] = all_correct
result['fraction_passed'] = np.mean(passed_constraints)
@@ -114,7 +114,7 @@ def is_formatted(self, solution):

def check_availability_programmatic(self, instance, solution):
if not instance['constraints'].get('availability', True):
result = {'availability_programmatic_check': 'NA'}
# result = {'availability_programmatic_check': 'NA'}
result = {'availability_programmatic_check': None}
return result

if not self.is_formatted(solution):
@@ -144,7 +145,8 @@ def check_availability_programmatic(self, instance, solution):

def check_meeting_duration_programmatic(self, instance, solution):
if not instance['constraints'].get('meeting_duration', True):
result = {'meeting_duration_programmatic_check': 'NA'}
# result = {'meeting_duration_programmatic_check': 'NA'}
result = {'meeting_duration_programmatic_check': None}
return result

if not self.is_formatted(solution):
@@ -162,7 +164,8 @@ def check_meeting_duration_programmatic(self, instance, solution):
def check_buffer_time_programmatic(self, instance, solution):
buffer_time = instance['constraints'].get('buffer_time_before_and_after_meeting', True)
if buffer_time is None or not buffer_time:
result = {'buffer_time_programmatic_check': 'NA'}
# result = {'buffer_time_programmatic_check': 'NA'}
result = {'buffer_time_programmatic_check': None}
return result

if not self.is_formatted(solution):
@@ -195,7 +198,8 @@ def check_buffer_time_programmatic(self, instance, solution):

def check_no_weekends_programmatic(self, instance, solution):
if not instance['constraints'].get('no_meetings_on_weekends', True):
return {'no_weekends_programmatic_check': 'NA'}
# return {'no_weekends_programmatic_check': 'NA'}
return {'no_weekends_programmatic_check': None}

if not self.is_formatted(solution):
return {'no_weekends_programmatic_check': 0}
@@ -207,7 +211,8 @@ def check_no_weekends_programmatic(self, instance, solution):

def check_time_restrictions_programmatic(self, instance, solution):
if not instance['constraints'].get('no_meetings_before', True) and not instance['constraints'].get('no_meetings_after', True):
return {'time_restrictions_programmatic_check': 'NA'}
# return {'time_restrictions_programmatic_check': 'NA'}
return {'time_restrictions_programmatic_check': None}

if not self.is_formatted(solution):
return {'time_restrictions_programmatic_check': 0}
@@ -231,7 +236,8 @@ def check_time_restrictions_programmatic(self, instance, solution):

def check_priority_programmatic(self, instance, solution):
if not instance['constraints'].get('high_priority_meeting', False):
return {'priority_programmatic_check': 'NA'}
# return {'priority_programmatic_check': 'NA'}
return {'priority_programmatic_check': None}

if not self.is_formatted(solution):
return {'priority_programmatic_check': 0}
@@ -269,7 +275,8 @@ def check_priority_programmatic(self, instance, solution):

def check_specific_times_programmatic(self, instance, solution):
if not instance['constraints'].get('no_meetings_during_specific_times', True):
return {'specific_times_programmatic_check': 'NA'}
# return {'specific_times_programmatic_check': 'NA'}
return {'specific_times_programmatic_check': None}

if not self.is_formatted(solution):
return {'specific_times_programmatic_check': 0}
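The pattern above is easier to see outside the diff. Below is a minimal, illustrative sketch (not the repo's exact code) of how a constraint check now returns None when the constraint is absent, and how the filter in run_programmatic_tests counts only integer 0/1 outcomes toward the pass fraction, so inapplicable checks no longer drag the average down.

```python
import numpy as np
import pandas as pd

def check_priority(instance, solution):
    # Illustrative: when the constraint is not present, report None instead of the old 'NA' string.
    if not instance["constraints"].get("high_priority_meeting", False):
        return {"priority_programmatic_check": None}
    # ... a real check would validate the proposed slot here ...
    return {"priority_programmatic_check": 1}

def fraction_passed(result):
    # Mirrors the filter in run_programmatic_tests: only integer 0/1 outcomes are counted.
    passed = [
        v for v in result.values()
        if v is not None and v != "NA" and pd.notna(v) and isinstance(v, int)
    ]
    return np.mean(passed) if passed else np.nan

print(fraction_passed({"availability": 1, "duration": 0, "priority": None}))  # 0.5 -- the None check is skipped
```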
11 changes: 6 additions & 5 deletions eureka_ml_insights/metrics/reports.py
@@ -348,12 +348,13 @@ def _aggregate_grouped(self, data):
divided_result = (gb[self.numerator_column_name].sum() / gb[self.denominator_column_name].sum()).to_dict()
self.aggregated_result = {"ratio": divided_result}

class NAFilteredAggregator(Aggregator):
def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs):
class ValueFilteredAggregator(Aggregator):
def __init__(self, agg_class, value, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs):
"""
Aggregator that filters out "NA" values before aggregating the data.
Aggregator that filters out a particular value before aggregating the data.
args:
agg_class: Aggregator class to use for aggregation
value: value to filter out
column_names: column names to filter and aggregate
output_dir: str. directory to save the report
group_by: str. or list of str. column(s) to group by before aggregating
@@ -362,19 +363,19 @@ def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_no
"""

self.base_aggregator = agg_class(column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs)
self.value = value
self.column_names = column_names
self.group_by = group_by
self.output_dir = output_dir
self.aggregated_result = None
self.ignore_non_numeric = ignore_non_numeric
self.filename_base = filename_base
# super().__init__(self.input_column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs)

def aggregate(self, data):
agg_results = {}
for col in self.column_names:
# workaround to process one column at a time
filtered_data = data[data[col] != "NA"].copy()
filtered_data = data[data[col] != self.value].copy()
self.base_aggregator.column_names = [col]
self.base_aggregator.aggregate(filtered_data)
agg_results.update(self.base_aggregator.aggregated_result)
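Usage follows the constructor shown above and the updated tests: the wrapper drops rows equal to the given sentinel for each column, then delegates to the base aggregator. A small sketch under those assumptions:

```python
import pandas as pd
from eureka_ml_insights.metrics.reports import AverageAggregator, ValueFilteredAggregator

data = pd.DataFrame({
    "score": [1, 0, "NA", 1],      # sentinel string mixed in with numeric results
    "group": ["a", "a", "b", "b"],
})

# Filter out rows where "score" == "NA", then average what remains.
agg = ValueFilteredAggregator(AverageAggregator, "NA", ["score"], "output_dir")
agg.aggregate(data)
print(agg.aggregated_result)  # roughly {"score": 0.667}
```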
@@ -2,6 +2,7 @@ You are a scheduling assistant. Given the availability schedules of multiple par
Make sure you use the availability schedules to generate your response.
High priority meetings should be scheduled as early as possible.
Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15.
Respond with "[day] [start_time]-[end_time]" or "No common time slot available"
Do not respond with any additional information or comments.
The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available".
Provide your answer in the format:
Final Answer: <answer>
{{prompt}}
@@ -3,7 +3,7 @@ Make sure you use the availability schedules to generate your response.
High priority meetings should be scheduled as early as possible.
Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15.
The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available".
Think through and provide your your answer in the format:
Think through and provide your answer in the format:
Reason: <reasoning>
Final Answer: <answer>
{{prompt}}
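For illustration only, a hypothetical model response that follows the revised templates, and a minimal parse of the required "Final Answer" line (none of this is part of the commit):

```python
import re

response = """Reason: Monday 14:00-15:00 is the earliest slot that satisfies all constraints.
Final Answer: Monday 14:00-15:00"""

match = re.search(r"Final Answer:\s*(.+)", response)
answer = match.group(1).strip() if match else None
if answer and answer != "No common time slot available":
    day, times = answer.split(maxsplit=1)
    start, end = times.split("-")
    print(day, start, end)  # Monday 14:00 15:00
```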
60 changes: 13 additions & 47 deletions eureka_ml_insights/user_configs/ba_calendar.py
@@ -19,8 +19,6 @@
from eureka_ml_insights.metrics.reports import (
AverageAggregator,
BiLevelMaxAggregator,
MaxAggregator,
NAFilteredAggregator,
)

from ..configs.config import (
@@ -99,17 +97,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
{
"column_names": [
"BACalendarMetric_all_correct",
"BACalendarMetric_fraction_passed"
],
"filename_base": "BaCal_OverallMetrics_SeparateRuns",
"group_by": "data_repeat_id",
},
),
AggregatorConfig(
NAFilteredAggregator,
{
"agg_class": AverageAggregator,
"column_names": [
"BACalendarMetric_fraction_passed",
"BACalendarMetric_availability_programmatic_check",
"BACalendarMetric_meeting_duration_programmatic_check",
"BACalendarMetric_buffer_time_programmatic_check",
Expand All @@ -118,7 +106,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
"BACalendarMetric_specific_times_programmatic_check",
"BACalendarMetric_priority_programmatic_check"
],
"filename_base": "BaCal_Constraint_Level_SeprateRuns",
"filename_base": "BaCal_OverallMetrics_SeparateRuns",
"group_by": "data_repeat_id",
},
),
@@ -142,18 +130,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
{
"column_names": [
"BACalendarMetric_all_correct",
"BACalendarMetric_fraction_passed"
],
"first_groupby": "data_point_id",
"filename_base": "BaCal_BestOfN_Aggregated",
"normalize": True,
},
),
AggregatorConfig(
NAFilteredAggregator,
{
"agg_class": MaxAggregator,
"column_names": [
"BACalendarMetric_fraction_passed",
"BACalendarMetric_availability_programmatic_check",
"BACalendarMetric_meeting_duration_programmatic_check",
"BACalendarMetric_buffer_time_programmatic_check",
Expand All @@ -162,18 +139,17 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
"BACalendarMetric_specific_times_programmatic_check",
"BACalendarMetric_priority_programmatic_check"
],
"filename_base": "BaCal_Constraint_Level_BestOfN_Aggregated",
"group_by": "data_repeat_id",
"first_groupby": "data_point_id",
"filename_base": "BaCal_BestOfN_Aggregated",
"normalize": True,
},
),


],
output_dir=os.path.join(self.log_dir, "bestofn_eval_report"),
)

# Aggregate the results by a majority vote
self.data_post_processing_addmv = DataProcessingConfig(
self.maj_vote_data_post_processing = DataProcessingConfig(
component_type=DataProcessing,
data_reader_config=DataSetConfig(
DataReader,
@@ -197,12 +173,12 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
output_dir=os.path.join(self.log_dir, "data_majvote_output"),
)
# Second, compute exact match
self.postevalprocess_comp = EvalReportingConfig(
self.majvote_evalreporting_comp = EvalReportingConfig(
component_type=EvalReporting,
data_reader_config=DataSetConfig(
DataReader,
{
"path": os.path.join(self.data_post_processing_addmv.output_dir, "transformed_data.jsonl"),
"path": os.path.join(self.maj_vote_data_post_processing.output_dir, "transformed_data.jsonl"),
"format": ".jsonl",
"transform": SequenceTransform(
[
@@ -218,17 +194,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
{
"column_names": [
"BACalendarMetric_all_correct",
"BACalendarMetric_fraction_passed"
],
"filename_base": "BaCal_MajVote_OverallMetrics_Aggregated",
"group_by": "data_repeat_id",
},
),
AggregatorConfig(
NAFilteredAggregator,
{
"agg_class": AverageAggregator,
"column_names": [
"BACalendarMetric_fraction_passed",
"BACalendarMetric_availability_programmatic_check",
"BACalendarMetric_meeting_duration_programmatic_check",
"BACalendarMetric_buffer_time_programmatic_check",
Expand All @@ -237,7 +203,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
"BACalendarMetric_specific_times_programmatic_check",
"BACalendarMetric_priority_programmatic_check"
],
"filename_base": "BaCal_MajVote_Constraint_Level_Aggregated",
"filename_base": "BaCal_MajVote_OverallMetrics_Aggregated",
"group_by": "data_repeat_id",
},
),
@@ -252,8 +218,8 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
self.inference_comp,
self.evalreporting_comp,
self.bon_evalreporting_comp,
self.data_post_processing_addmv,
self.postevalprocess_comp
self.maj_vote_data_post_processing,
self.majvote_evalreporting_comp
],
self.log_dir,

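The consolidation above appears to rely on missing values being skipped during averaging: the programmatic checks now emit None instead of the string 'NA', and pandas treats None as missing by default, so the constraint-level columns can share one plain AverageAggregator config instead of a separate NA-filtering wrapper. A quick sketch of that behavior with illustrative values (independent of the repo's aggregator internals):

```python
import pandas as pd

df = pd.DataFrame({
    "BACalendarMetric_priority_programmatic_check": [1, 0, None, 1],
    "data_repeat_id": [0, 0, 1, 1],
})

col = "BACalendarMetric_priority_programmatic_check"
print(df[col].mean())                              # 0.666... -- the None row is ignored
print(df.groupby("data_repeat_id")[col].mean())    # per-run means, still skipping the None
```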
18 changes: 7 additions & 11 deletions tests/metric_utils_tests/aggregator_tests.py
Expand Up @@ -15,7 +15,7 @@
SumAggregator,
TwoColumnSumAverageAggregator,
)
from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, NAFilteredAggregator
from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, ValueFilteredAggregator

PRECISION = 3

@@ -419,7 +419,7 @@ def test_average_aggregator_group_by_multiple_columns(self):
self.assertTrue(os.path.exists(avg_agg.output_file))


class NAFilteredAggregatorTestData:
class ValueFilteredAggregatorTestData:
def setUp(self):
self.data = pd.DataFrame(
{
@@ -430,18 +430,14 @@ def setUp(self):
"col3": [5, 8, 'NA', 3, 'abc', 8, 3, 4, 5, 8, 4, 2],
"categorical_metric": ["x", "y", "z", "z", "y", "y", "z", "y", "x", "y", "y", "x"],
"group": ["a", "a", "b", "b", "a", "a", "b", "b", "a", "a", "b", "b"],
# [5, 6, 8, 5, 8, ]
# [2, 3, 3, 4, 2]
# [5, 8, 6, 8, 5, 8, ]
# [2, 3, 4, 2]
}
)
self.output_dir = "output_dir"
self.precision = PRECISION

class TestNAFilteredAggregator(NAFilteredAggregatorTestData, unittest.TestCase):
class TestValueFilteredAggregator(ValueFilteredAggregatorTestData, unittest.TestCase):
def test_average_aggregator(self):
avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir)
avg_agg = ValueFilteredAggregator(AverageAggregator, "NA", ["col1", "col2"], self.output_dir)
avg_agg.aggregate(self.data)
x = [a for a in self.data["col1"] if a != 'NA']
y = [a for a in self.data["col2"] if a != 'NA']
@@ -451,12 +447,12 @@ def test_average_aggregator(self):
)

def test_average_aggregator_input_validation(self):
avg_agg = NAFilteredAggregator(AverageAggregator, ["col3"], self.output_dir)
avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col3"], self.output_dir)
self.assertRaises(ValueError, avg_agg.aggregate, self.data)

def test_average_aggregator_group_by(self):
self.output_dir = create_logdir("NAFilteredAggregatorTests")
avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir, group_by="group")
self.output_dir = create_logdir("ValueFilteredAggregatorTests")
avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col1", "col2"], self.output_dir, group_by="group")
avg_agg.aggregate(self.data)
self.assertEqual(avg_agg.aggregated_result, {"col1": {"a": 6.4, "b": 2.8}, "col2": {"a": 6.667, "b": 2.75}})
avg_agg.write_results()
2 changes: 1 addition & 1 deletion tests/pipeline_tests.py
@@ -487,7 +487,7 @@ def get_config(self):

def setUp(self) -> None:
super().setUp()
self.eval_configs = [self.test_pipeline.evalreporting_comp]
self.eval_configs = [self.test_pipeline.evalreporting_comp,self.test_pipeline.bon_evalreporting_comp, self.test_pipeline.majvote_evalreporting_comp]

def test_outputs_exist(self) -> None:
logging.info("Running test_outputs_exist test in PipelineTest")
