diff --git a/eureka_ml_insights/core/inference.py b/eureka_ml_insights/core/inference.py index cb61b99b..1374c5bb 100644 --- a/eureka_ml_insights/core/inference.py +++ b/eureka_ml_insights/core/inference.py @@ -11,12 +11,21 @@ from .pipeline import Component from .reserved_names import INFERENCE_RESERVED_NAMES + MINUTE = 60 class Inference(Component): - def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1): - + def __init__( + self, + model_config, + data_config, + output_dir, + resume_from=None, + new_columns=None, + requests_per_minute=None, + max_concurrent=1, + ): """ Initialize the Inference component. args: @@ -62,13 +71,13 @@ def fetch_previous_inference_results(self): # fetch previous results from the provided resume_from file logging.info(f"Resuming inference from {self.resume_from}") pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset() - + # add new columns listed by the user to the previous inference results if self.new_columns: for col in self.new_columns: if col not in pre_inf_results_df.columns: pre_inf_results_df[col] = None - + # validate the resume_from contents with self.data_loader as loader: _, sample_model_input = self.data_loader.get_sample_model_input() @@ -80,13 +89,17 @@ def fetch_previous_inference_results(self): # perform a sample inference call to get the model output keys and validate the resume_from contents sample_response_dict = self.model.generate(*sample_model_input) + if not sample_response_dict["is_valid"]: + raise ValueError( + "Sample inference call for resume_from returned invalid results, please check the model configuration." 
+ ) # check if the inference response dictionary contains the same keys as the resume_from file eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys) # in case of resuming from a file that was generated by an older version of the model, # we let the discrepancy in the reserved keys slide and later set the missing keys to None - match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES) - + match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES) + if set(eventual_keys) != match_keys: diff = set(eventual_keys) ^ set(match_keys) raise ValueError( @@ -139,6 +152,11 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df): prev_model_tokens, prev_model_time, ) + # add remaining pre_inf_results_df columns to the data point + for col in pre_inf_results_df.columns: + if col not in data: + data[col] = prev_results[col].values[0] + return data def run(self): diff --git a/main.py b/main.py index 8eed49c7..433dda38 100755 --- a/main.py +++ b/main.py @@ -21,10 +21,23 @@ parser.add_argument( "--resume_from", type=str, help="The path to the inference_result.jsonl to resume from.", default=None ) - args = parser.parse_args() + init_args = {} + + # catch any unknown arguments + args, unknown_args = parser.parse_known_args() + if unknown_args: + # if the unknown args pair up as "--key value" (even count, every key starts with "--"), parse them into a dict + if len(unknown_args) % 2 == 0 and all(arg.startswith("--") for arg in unknown_args[::2]): + init_args.update( + {arg[len("--") :]: unknown_args[i + 1] for i, arg in enumerate(unknown_args) if i % 2 == 0} + ) + logging.info(f"Unknown arguments: {init_args} will be sent to the experiment config class.") + # else (unpaired or non-flag args), pass the unknown args through as is, i.e.
as a list + else: + init_args["unknown_args"] = unknown_args + logging.info(f"Unknown arguments: {unknown_args} will be sent as is to the experiment config class.") experiment_config_class = args.exp_config - init_args = {} if args.model_config: try: init_args["model_config"] = getattr(model_configs, args.model_config)