Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate usage columns for extract usage #110

Merged
merged 16 commits into from
Mar 16, 2025
41 changes: 34 additions & 7 deletions eureka_ml_insights/data_utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,14 @@ class ExtractUsageTransform:
Extracts token usage completion numbers (except prompt input tokens) for all models.
args:
model_config: config used for the experiment.
usage_completion_output_col: str, default name of the column where completion numbers will be stored for all models
prepend_completion_read_col: str, prepend string to add to the name of the usage column from which to read. Useful for cases when the usage column might have been renamed earlier in the pipeline.
usage_completion_output_col: str, default name of the column where completion numbers will be stored for model
usage_column: str, default name of the column where usage information is stored for model
n_tokens_column: str, default name of the column where number of tokens is stored for model
"""
model_config: ModelConfig
usage_completion_output_col: str = "usage_completion"
prepend_completion_read_col: str = ""
usage_column: str = "usage"
n_tokens_column: str = "n_output_tokens"

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Expand All @@ -448,7 +450,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
df (pd.DataFrame): Input dataframe of inference results retrieved with the model_config.

Returns:
pd.DataFrame: Transformed dataframe with completion token numbers in completion_usage_col.
pd.DataFrame: Transformed dataframe with completion token numbers in usage_completion_output_col.
"""
usage_completion_read_col = None
if (self.model_config.class_name is GeminiModel):
Expand All @@ -469,10 +471,35 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
logging.warn(f"Model {self.model_config.class_name} is not recognized for extracting completion token usage.")
# if the model is one for which the usage of completion tokens is known, use that corresponding column for the model
# otherwise, use the default "n_output_tokens" which is computed with a universal tokenizer as shown in TokenCounterTransform()
self.validate(df, usage_completion_read_col)
if usage_completion_read_col:
df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "usage"].apply(lambda x: x[usage_completion_read_col])
elif self.prepend_completion_read_col + "n_output_tokens" in df.columns:
df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "n_output_tokens"]
df[self.usage_completion_output_col] = df.apply(lambda x: self._extract_usage(x, usage_completion_read_col), axis=1)
elif self.n_tokens_column in df.columns:
df[self.usage_completion_output_col] = df[self.n_tokens_column]
else:
df[self.usage_completion_output_col] = np.nan
return df

def validate(self, df: pd.DataFrame, usage_completion_read_col: str) -> pd.DataFrame:
"""Check that usage_columns or n_tokens_columns are present actually in the data frame.
Args:
df (pd.DataFrame): Input dataframe containing model_output_col and id_col.
usage_completion_read_col (str): The column name for token extraction.
"""
if usage_completion_read_col and self.usage_column not in df.columns:
raise ValueError(f"The {self.usage_column} column is not present in the data frame.")
elif self.n_tokens_column not in df.columns:
raise ValueError(f"The {self.n_tokens_column} column is not present in the data frame.")

def _extract_usage(self, row, usage_completion_read_col):
"""
Extracts the token usage for a given row if usage column and corresponding completion column exists.
Args:
row (pd.Series): A row of the dataframe.
usage_completion_read_col (str): The column name to extract the token usage from.
Returns:
int: The token usage for the row.
"""
if not pd.isna(row[self.usage_column]) and usage_completion_read_col in row[self.usage_column]:
return row[self.usage_column][usage_completion_read_col]
return np.nan
7 changes: 5 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(self):

def generate(self, text_prompt, query_images=None):
return {"model_output": random.choice(["Final Answer: A", "Final Answer: B", "Final Answer: C", "Final Answer: D"]),
"is_valid": random.choice([True, False])}
"is_valid": random.choice([True, False]),
"n_output_tokens": 3}

def name(self):
return self.name
Expand Down Expand Up @@ -111,7 +112,9 @@ def __init__(self, model_name="generic_test_model"):
self.name = model_name

def generate(self, text_prompt, *args, **kwargs):
return {"model_output": "Generic model output", "is_valid": random.choice([True, False])}
return {"model_output": "Generic model output",
"is_valid": random.choice([True, False]),
"n_output_tokens": 3}


class DNAEvaluationInferenceTestModel:
Expand Down
Loading