
Commit 45e802e

vidhishanair (Vidhisha Balachandran) and Vidhisha Balachandran authored
validate usage columns for extract usage (#110)
Co-authored-by: Vidhisha Balachandran <[email protected]>
1 parent 9bfdcd3 commit 45e802e

File tree

2 files changed: +39 / -9 lines changed


eureka_ml_insights/data_utils/transform.py

+34 / -7
@@ -433,12 +433,14 @@ class ExtractUsageTransform:
     Extracts token usage completion numbers (except prompt input tokens) for all models.
     args:
         model_config: config used for the experiment.
-        usage_completion_output_col: str, default name of the column where completion numbers will be stored for all models
-        prepend_completion_read_col: str, prepend string to add to the name of the usage column from which to read. Useful for cases when the usage column might have been renamed earlier in the pipeline.
+        usage_completion_output_col: str, default name of the column where completion numbers will be stored for model
+        usage_column: str, default name of the column where usage information is stored for model
+        n_tokens_column: str, default name of the column where number of tokens is stored for model
     """
     model_config: ModelConfig
     usage_completion_output_col: str = "usage_completion"
-    prepend_completion_read_col: str = ""
+    usage_column: str = "usage"
+    n_tokens_column: str = "n_output_tokens"
 
     def transform(self, df: pd.DataFrame) -> pd.DataFrame:
         """
@@ -448,7 +450,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
             df (pd.DataFrame): Input dataframe of inference results retrieved with the model_config.
 
         Returns:
-            pd.DataFrame: Transformed dataframe with completion token numbers in completion_usage_col.
+            pd.DataFrame: Transformed dataframe with completion token numbers in usage_completion_output_col.
         """
         usage_completion_read_col = None
         if (self.model_config.class_name is GeminiModel):
@@ -469,10 +471,35 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
             logging.warn(f"Model {self.model_config.class_name} is not recognized for extracting completion token usage.")
         # if the model is one for which the usage of completion tokens is known, use that corresponding column for the model
         # otherwise, use the default "n_output_tokens" which is computed with a universal tokenizer as shown in TokenCounterTransform()
+        self.validate(df, usage_completion_read_col)
         if usage_completion_read_col:
-            df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "usage"].apply(lambda x: x[usage_completion_read_col])
-        elif self.prepend_completion_read_col + "n_output_tokens" in df.columns:
-            df[self.usage_completion_output_col] = df[self.prepend_completion_read_col + "n_output_tokens"]
+            df[self.usage_completion_output_col] = df.apply(lambda x: self._extract_usage(x, usage_completion_read_col), axis=1)
+        elif self.n_tokens_column in df.columns:
+            df[self.usage_completion_output_col] = df[self.n_tokens_column]
         else:
             df[self.usage_completion_output_col] = np.nan
         return df
+
+    def validate(self, df: pd.DataFrame, usage_completion_read_col: str) -> pd.DataFrame:
+        """Check that usage_columns or n_tokens_columns are present actually in the data frame.
+        Args:
+            df (pd.DataFrame): Input dataframe containing model_output_col and id_col.
+            usage_completion_read_col (str): The column name for token extraction.
+        """
+        if usage_completion_read_col and self.usage_column not in df.columns:
+            raise ValueError(f"The {self.usage_column} column is not present in the data frame.")
+        elif self.n_tokens_column not in df.columns:
+            raise ValueError(f"The {self.n_tokens_column} column is not present in the data frame.")
+
+    def _extract_usage(self, row, usage_completion_read_col):
+        """
+        Extracts the token usage for a given row if usage column and corresponding completion column exists.
+        Args:
+            row (pd.Series): A row of the dataframe.
+            usage_completion_read_col (str): The column name to extract the token usage from.
+        Returns:
+            int: The token usage for the row.
+        """
+        if not pd.isna(row[self.usage_column]) and usage_completion_read_col in row[self.usage_column]:
+            return row[self.usage_column][usage_completion_read_col]
+        return np.nan
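For reference, a minimal standalone sketch of the read path this change introduces, using the default column names defined above. This is not the package's API: the "completion_tokens" key is a hypothetical per-model usage key (the real key depends on the model class branch in transform()).

import numpy as np
import pandas as pd

usage_column = "usage"                      # per-model usage dict, default introduced in this commit
n_tokens_column = "n_output_tokens"         # universal-tokenizer fallback count
usage_completion_output_col = "usage_completion"
usage_completion_read_col = "completion_tokens"  # hypothetical key, for illustration only

df = pd.DataFrame({
    "usage": [{"completion_tokens": 42}, np.nan],
    "n_output_tokens": [40, 38],
})

# Mirrors the first check in validate(): the usage column must exist before per-row extraction.
if usage_completion_read_col and usage_column not in df.columns:
    raise ValueError(f"The {usage_column} column is not present in the data frame.")

def extract_usage(row):
    # Mirrors _extract_usage(): read the key only if the usage dict exists and contains it.
    if not pd.isna(row[usage_column]) and usage_completion_read_col in row[usage_column]:
        return row[usage_column][usage_completion_read_col]
    return np.nan

df[usage_completion_output_col] = df.apply(extract_usage, axis=1)
# First row yields 42 (read from the usage dict); second row yields NaN (usage missing).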

tests/test_utils.py

+5 / -2
@@ -51,7 +51,8 @@ def __init__(self):
 
     def generate(self, text_prompt, query_images=None):
         return {"model_output": random.choice(["Final Answer: A", "Final Answer: B", "Final Answer: C", "Final Answer: D"]),
-                "is_valid": random.choice([True, False])}
+                "is_valid": random.choice([True, False]),
+                "n_output_tokens": 3}
 
     def name(self):
         return self.name
@@ -111,7 +112,9 @@ def __init__(self, model_name="generic_test_model"):
         self.name = model_name
 
     def generate(self, text_prompt, *args, **kwargs):
-        return {"model_output": "Generic model output", "is_valid": random.choice([True, False])}
+        return {"model_output": "Generic model output",
+                "is_valid": random.choice([True, False]),
+                "n_output_tokens": 3}
 
 
 class DNAEvaluationInferenceTestModel:
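The test models above now return a constant "n_output_tokens" so that ExtractUsageTransform has a column to fall back on when no per-model usage key is recognized. Roughly, assuming the default column names introduced in this commit:

import numpy as np
import pandas as pd

# Toy inference result from a test model with no recognized per-model usage key.
df = pd.DataFrame({"model_output": ["Generic model output"], "n_output_tokens": [3]})

n_tokens_column = "n_output_tokens"
usage_completion_output_col = "usage_completion"

# Fallback branch of transform(): copy the tokenizer count, or NaN if neither source exists.
if n_tokens_column in df.columns:
    df[usage_completion_output_col] = df[n_tokens_column]  # -> 3 for every test row
else:
    df[usage_completion_output_col] = np.nan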
