|
1 | | -from typing import List |
| 1 | +from typing import List, Union |
2 | 2 |
|
3 | 3 | from kubernetes.dynamic import DynamicClient |
4 | 4 | from ocp_resources.lm_eval_job import LMEvalJob |
@@ -36,30 +36,46 @@ def get_lmevaljob_pod(client: DynamicClient, lmevaljob: LMEvalJob, timeout: int |
36 | 36 | return lmeval_pod |
37 | 37 |
|
38 | 38 |
|
39 | | -def get_lmeval_tasks(min_downloads: int = 10000) -> List[str]: |
| 39 | +def get_lmeval_tasks(min_downloads: Union[int, float] = 10000, max_downloads: Union[int, float, None] = None) -> List[str]: |
40 | 40 | """ |
41 | 41 | Gets the list of supported LM-Eval tasks that have above a certain number of minimum downloads on HuggingFace. |
42 | 42 |
|
43 | 43 | Args: |
44 | | - min_downloads: The minimum number of downloads |
| 44 | + min_downloads: The minimum number of downloads or the percentile of downloads to use as a minimum |
| 45 | + max_downloads: The maximum number of downloads or the percentile of downloads to use as a maximum |
45 | 46 |
|
46 | 47 | Returns: |
47 | 48 | List of LM-Eval task names |
48 | 49 | """ |
49 | | - if min_downloads < 1: |
| 50 | + if min_downloads <= 0: |
50 | 51 | raise ValueError("Minimum downloads must be greater than 0") |
51 | 52 |
|
52 | 53 | lmeval_tasks = pd.read_csv(filepath_or_buffer="tests/model_explainability/lm_eval/data/new_task_list.csv") |
53 | 54 |
|
54 | | - # filter for tasks that either exceed (min_downloads OR exist on the OpenLLM leaderboard) |
55 | | - # AND exist on LMEval AND do not include image data |
| 55 | + if isinstance(min_downloads, float): |
| 56 | + if not 0 <= min_downloads <= 1: |
| 57 | + raise ValueError("Minimum downloads as a percentile must be between 0 and 1") |
| 58 | + min_downloads = lmeval_tasks["HF dataset downloads"].quantile(q=min_downloads) |
56 | 59 |
|
| 60 | + # filter for tasks that either exceed min_downloads OR exist on the OpenLLM leaderboard |
| 61 | + # AND exist on LMEval AND do not include image data |
57 | 62 | filtered_df = lmeval_tasks[ |
58 | 63 | lmeval_tasks["Exists"] |
59 | 64 | & (lmeval_tasks["Dataset"] != "MMMU/MMMU") |
60 | 65 | & ((lmeval_tasks["HF dataset downloads"] >= min_downloads) | (lmeval_tasks["OpenLLM leaderboard"])) |
61 | 66 | ] |
62 | 67 |
|
| 68 | + # if max_downloads is provided, filter for tasks that have less than |
| 69 | + # or equal to the maximum number of downloads |
| 70 | +    if max_downloads is not None: |
| 71 | +        if isinstance(max_downloads, float): |
| 72 | +            if not 0 <= max_downloads <= 1: |
| 73 | +                raise ValueError("Maximum downloads as a percentile must be between 0 and 1") |
| 74 | +            max_downloads = lmeval_tasks["HF dataset downloads"].quantile(q=max_downloads) |
| 75 | +        if max_downloads <= 0 or max_downloads > max(lmeval_tasks["HF dataset downloads"]): |
| 76 | +            raise ValueError("Maximum downloads must be greater than 0 and less than the maximum number of downloads") |
| 77 | +        filtered_df = filtered_df[filtered_df["HF dataset downloads"] <= max_downloads] |
| 78 | + |
63 | 79 | # group tasks by dataset and extract the task with shortest name in the group |
64 | 80 | unique_tasks = filtered_df.loc[filtered_df.groupby("Dataset")["Name"].apply(lambda x: x.str.len().idxmin())] |
65 | 81 |
|
|
0 commit comments