diff --git a/eureka_ml_insights/data_utils/transform.py b/eureka_ml_insights/data_utils/transform.py
index 4d163630..6597b355 100644
--- a/eureka_ml_insights/data_utils/transform.py
+++ b/eureka_ml_insights/data_utils/transform.py
@@ -433,12 +433,14 @@ class ExtractUsageTransform:
     Extracts token usage completion numbers (except prompt input tokens) for all models.
     args:
         model_config: config used for the experiment.
-        usage_completion_output_col: str, default name of the column where completion numbers will be stored for all models
-        prepend_completion_read_col: str, prepend string to add to the name of the usage column from which to read. Useful for cases when the usage column might have been renamed earlier in the pipeline.
+        usage_completion_output_col: str, default name of the column where completion token numbers will be stored for the model
+        usage_column: str, default name of the column where usage information is stored for the model
+        n_tokens_column: str, default name of the column where the number of output tokens is stored for the model
     """
     model_config: ModelConfig
     usage_completion_output_col: str = "usage_completion"
-    prepend_completion_read_col: str = ""
+    usage_column: str = "usage"
+    n_tokens_column: str = "n_output_tokens"
 
     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
         """
@@ -448,7 +450,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
             df (pd.DataFrame): Input dataframe of inference results retrieved with the model_config.
 
         Returns:
-            pd.DataFrame: Transformed dataframe with completion token numbers in completion_usage_col.
+            pd.DataFrame: Transformed dataframe with completion token numbers in usage_completion_output_col.
         """
         usage_completion_read_col = None
         if (self.model_config.class_name is GeminiModel):
@@ -469,10 +471,35 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
             logging.warn(f"Model {self.model_config.class_name} is not recognized for extracting completion token usage.")
         # if the model is one for which the usage of completion tokens is known, use that corresponding column for the model
         # otherwise, use the default "n_output_tokens" which is computed with a universal tokenizer as shown in TokenCounterTransform()
+        self.validate(df, usage_completion_read_col)
         if usage_completion_read_col:
-            df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "usage"].apply(lambda x: x[usage_completion_read_col])
-        elif self.prepend_completion_read_col + "n_output_tokens" in df.columns:
-            df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "n_output_tokens"]
+            df[self.usage_completion_output_col] = df.apply(lambda x: self._extract_usage(x, usage_completion_read_col), axis=1)
+        elif self.n_tokens_column in df.columns:
+            df[self.usage_completion_output_col] = df[self.n_tokens_column]
         else:
             df[self.usage_completion_output_col] = np.nan
         return df
+
+    def validate(self, df: pd.DataFrame, usage_completion_read_col: str) -> None:
+        """Check that the usage_column or n_tokens_column is actually present in the data frame.
+        Args:
+            df (pd.DataFrame): Input dataframe of inference results to validate.
+            usage_completion_read_col (str): The key within the usage column used for token extraction.
+        """
+        if usage_completion_read_col and self.usage_column not in df.columns:
+            raise ValueError(f"The {self.usage_column} column is not present in the data frame.")
+        elif self.n_tokens_column not in df.columns:
+            raise ValueError(f"The {self.n_tokens_column} column is not present in the data frame.")
+
+    def _extract_usage(self, row, usage_completion_read_col):
+        """
+        Extracts the token usage for a given row if the usage column is present and contains the corresponding completion key.
+        Args:
+            row (pd.Series): A row of the dataframe.
+            usage_completion_read_col (str): The key to extract the token usage from.
+        Returns:
+            int or float: The completion token count for the row, or np.nan if it is not available.
+        """
+        if not pd.isna(row[self.usage_column]) and usage_completion_read_col in row[self.usage_column]:
+            return row[self.usage_column][usage_completion_read_col]
+        return np.nan
diff --git a/tests/test_utils.py b/tests/test_utils.py
index dc8ae164..824bbabb 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -51,7 +51,8 @@ def __init__(self):
 
     def generate(self, text_prompt, query_images=None):
         return {"model_output": random.choice(["Final Answer: A", "Final Answer: B", "Final Answer: C", "Final Answer: D"]),
-                "is_valid": random.choice([True, False])}
+                "is_valid": random.choice([True, False]),
+                "n_output_tokens": 3}
 
     def name(self):
         return self.name
@@ -111,7 +112,9 @@ def __init__(self, model_name="generic_test_model"):
         self.name = model_name
 
     def generate(self, text_prompt, *args, **kwargs):
-        return {"model_output": "Generic model output",
+        return {"model_output": "Generic model output",
+                "is_valid": random.choice([True, False]),
+                "n_output_tokens": 3}
 
 
 class DNAEvaluationInferenceTestModel:
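Illustration (not part of the patch): a minimal standalone sketch of the per-row extraction and fallback behavior this change introduces, using only pandas/numpy and the transform's default column names. The usage dict key "completion_tokens" is an assumed stand-in for whichever key the configured model class maps to.

# Sketch of the new behavior: read the completion count from the per-row usage
# dict when it is available, otherwise the row gets NaN; when no usage key is
# known for the model, the transform instead copies the "n_output_tokens" column.
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "usage": [{"completion_tokens": 12}, np.nan],  # per-row usage dicts, may be missing
        "n_output_tokens": [12, 7],                    # universal-tokenizer fallback counts
    }
)

usage_completion_read_col = "completion_tokens"  # assumed model-specific key

def extract_usage(row):
    # Mirrors _extract_usage: use the key from the usage dict when present.
    if not pd.isna(row["usage"]) and usage_completion_read_col in row["usage"]:
        return row["usage"][usage_completion_read_col]
    return np.nan

df["usage_completion"] = df.apply(extract_usage, axis=1)
print(df["usage_completion"].tolist())  # -> [12.0, nan]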