|
| 1 | +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
import argparse
import calendar
import json
import os
from datetime import datetime
from pathlib import Path

from datasets import load_dataset
from dateutil.relativedelta import relativedelta
| 23 | + |
| 24 | + |
class PromptConstants:
    """Instruction sentences appended to a problem statement when building a prompt.

    Reference:
    https://github.com/QwenLM/Qwen2.5-Coder/blob/main/qwencoder-eval/reasoning/livecode_bench_cot/lcb_runner_cq/prompts/code_generation.py#L31
    """

    # Instruction for problems that ship starter code.
    FORMATTING_MESSAGE_WITH_STARTER_CODE = (
        "You will use the following starter code to write the solution to the problem "
        "and enclose your code within delimiters."
    )
    # Instruction for problems without starter code (stdin/stdout programs).
    FORMATTING_WITHOUT_STARTER_CODE = (
        "Read the inputs from stdin solve the problem and write the answer to stdout "
        "(do not directly test on the sample inputs). Enclose your code within delimiters "
        "as follows. Ensure that when the python program runs, it reads the inputs, "
        "runs the algorithm and writes output to STDOUT."
    )
| 29 | + |
| 30 | + |
def parse_data(release_version='release_latest'):
    """Load the test split of the LiveCodeBench ``code_generation_lite`` dataset.

    Args:
        release_version: Passed to ``load_dataset`` as ``version_tag``
            (``prepare`` calls this with values such as ``"release_v5"``).

    Returns:
        Whatever ``load_dataset`` returns for the ``test`` split.
    """
    # Fields of each record, as noted by the original author:
    #   question_title: str
    #   question_content: str
    #   platform: Platform
    #   question_id: str
    #   contest_id: str
    #   contest_date: datetime
    #   starter_code: str
    #   difficulty: Difficulty
    #   public_test_cases: list[Test]
    #   private_test_cases: list[Test]
    #   metadata: dict
    #
    # NOTE: trust_remote_code=True allows the dataset's own loading script to
    # run on this machine; only use it with a dataset source you trust.
    return load_dataset(
        "livecodebench/code_generation_lite",
        split="test",
        version_tag=release_version,
        trust_remote_code=True,
    )
| 48 | + |
| 49 | + |
def get_first_last_day(year_month_str):
    """Return the first and last calendar day of a month given as ``YYYY-MM``.

    Args:
        year_month_str: Month in ``%Y-%m`` form, e.g. ``"2024-02"``.

    Returns:
        A ``(first_day, last_day)`` tuple of ``datetime.date`` objects.

    Raises:
        ValueError: If ``year_month_str`` does not match ``%Y-%m``. The
            underlying parsing error is attached as ``__cause__``.
    """
    # Only the parse can fail on bad input, so keep the try body to that call.
    try:
        # With no day in the format, strptime yields the 1st of the month.
        first_day = datetime.strptime(year_month_str, "%Y-%m").date()
    except ValueError as err:
        raise ValueError("Invalid date format. Please use '%Y-%m'.") from err

    # monthrange() returns (weekday of the 1st, number of days in the month).
    days_in_month = calendar.monthrange(first_day.year, first_day.month)[1]
    return first_day, first_day.replace(day=days_in_month)
| 58 | + |
| 59 | + |
def parse_month_range(start_date, end_date):
    """Turn two ``YYYY-MM`` strings into an inclusive ``(first, last)`` date pair.

    Args:
        start_date: Month whose first day opens the range.
        end_date: Month whose last day closes the range.

    Returns:
        ``(range_start, range_end)``: the first day of ``start_date`` and the
        last day of ``end_date``, as produced by ``get_first_last_day``.

    Raises:
        ValueError: Propagated unchanged from ``get_first_last_day`` when
            either argument is not a valid ``%Y-%m`` string.
    """
    # No try/except here: re-wrapping the ValueError in an identical
    # ValueError added nothing and only obscured the traceback.
    range_start, _ = get_first_last_day(start_date)
    _, range_end = get_first_last_day(end_date)
    return range_start, range_end
| 67 | + |
| 68 | + |
def clean_data(dataset):
    """Add ``question`` and ``task_id`` to every record and drop raw columns.

    Args:
        dataset: Object exposing ``map(fn, remove_columns=...)``. Each record
            must provide ``question_content``, ``starter_code`` and
            ``question_id``.

    Returns:
        The result of ``dataset.map`` with the columns listed in
        ``remove_columns`` removed.
    """

    def map_fn(data):
        # Problem statement first, then the instruction matching whether
        # starter code is available, then the code fence the model fills in.
        question = data["question_content"] + "\n\n"
        if data["starter_code"]:
            question += f"{PromptConstants.FORMATTING_MESSAGE_WITH_STARTER_CODE}\n"
            question += f"```python\n{data['starter_code']}\n```\n\n"
        else:
            question += f"{PromptConstants.FORMATTING_WITHOUT_STARTER_CODE}\n\n"
            question += "```python\n# YOUR CODE HERE\n```\n\n"

        data["task_id"] = data["question_id"]
        # Collapse each four-space indent level into a tab.
        # NOTE(review): the captured text showed a single-space pattern here,
        # which would turn every word gap into a tab. Whitespace runs were
        # collapsed in that capture, so the four-space pattern is restored --
        # confirm against the upstream file.
        data['question'] = question.replace('    ', '\t')
        return data

    remove_columns = [
        'question_title',
        'contest_id',
        'public_test_cases',
        'private_test_cases',
        'metadata',
        'question_content',
        'platform',
        'question_id',
        'starter_code',
    ]
    return dataset.map(map_fn, remove_columns=remove_columns)
| 96 | + |
| 97 | + |
def prepare(start_date, end_date, release_version, output_dir):
    """Write one JSONL file with the problems whose contest date is in range.

    The output file is
    ``<output_dir>/test_<release_version>_<start yymm>_<end yymm>.jsonl`` and
    holds one JSON object per line with the keys ``task_id``, ``question``,
    ``difficulty`` and ``subset_for_metrics`` (a copy of ``difficulty``).

    Args:
        start_date: First month of the range, ``YYYY-MM``.
        end_date: Last month of the range (inclusive), ``YYYY-MM``.
        release_version: One of ``"v1"`` ... ``"v6"``; loaded as
            ``release_<release_version>``.
        output_dir: Directory for the output file; created if missing.

    Raises:
        ValueError: If ``release_version`` is not supported or either date is
            not a valid ``%Y-%m`` string.
    """
    # Validate before any work. A real exception is used instead of `assert`
    # because asserts are stripped when Python runs with -O.
    supported_versions = ("v1", "v2", "v3", "v4", "v5", "v6")
    if release_version not in supported_versions:
        raise ValueError(
            f"Unsupported release_version {release_version!r}; "
            f"expected one of {', '.join(supported_versions)}."
        )

    start_date, end_date = parse_month_range(start_date, end_date)
    start_yymm = start_date.strftime("%y%m")
    end_yymm = end_date.strftime("%y%m")
    output_file_path = os.path.join(output_dir, f"test_{release_version}_{start_yymm}_{end_yymm}.jsonl")

    data = parse_data(release_version=f"release_{release_version}")
    data = clean_data(data)
    # This is the size of the whole release, before the date filter below.
    print("Len of data: ", len(data))

    print("Writing to file...")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    with open(output_file_path, 'w', encoding="utf-8") as f:
        for problem in data:
            contest_day = datetime.strptime(problem['contest_date'], '%Y-%m-%dT%H:%M:%S').date()
            if not (start_date <= contest_day <= end_date):
                continue
            record = {
                "task_id": problem["task_id"],
                "question": problem["question"],
                "difficulty": problem["difficulty"],
                "subset_for_metrics": problem["difficulty"],
            }
            f.write(json.dumps(record) + '\n')
| 128 | + |
| 129 | + |
# Splits prepared when the script is run without arguments.
# Each entry is (release_version, start month, end month); months are YYYY-MM
# and the end month is inclusive.
DEFAULT_SPLITS = [
    ('v5', '2024-08', '2025-02'),
    ('v5', '2024-10', '2025-02'),
    ('v5', '2024-10', '2025-04'),
    ('v6', '2024-08', '2025-02'),
    ('v6', '2024-10', '2025-02'),
    ('v6', '2024-10', '2025-04'),
]
| 138 | + |
| 139 | + |
if __name__ == '__main__':
    # Two modes:
    #   * no split arguments -> prepare every entry of DEFAULT_SPLITS;
    #   * --release_version, --start_date and --end_date all given -> prepare
    #     that single custom split.
    # Supplying only some of the three is rejected.
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default=str(Path(__file__).parent))
    parser.add_argument('--release_version', type=str, default='all')
    parser.add_argument('--start_date', type=str, default='all', help="Start date in YYYY-MM format")
    parser.add_argument('--end_date', type=str, default='all', help="End date in YYYY-MM format")

    args = parser.parse_args()

    # 'all' is the default of each option, i.e. "not specified".
    split_args = (args.release_version, args.start_date, args.end_date)
    if all(value == 'all' for value in split_args):
        # Prepare all splits
        for release_version, start_date, end_date in DEFAULT_SPLITS:
            print(f"Processing data for {release_version} from {start_date} to {end_date}")
            prepare(start_date, end_date, release_version, args.output_dir)
    elif 'all' in split_args:
        raise ValueError(
            "If preparing a custom split, you must specify all "
            "--release_version, --start_date, and --end_date arguments."
        )
    else:
        prepare(args.start_date, args.end_date, args.release_version, args.output_dir)

    # Sample counts noted by the original author for the default splits:
    # test_v5_2408_2502.jsonl: 279 samples
    # test_v5_2410_2502.jsonl: 166 samples
    # test_v5_2410_2504.jsonl: 166 samples
    # test_v6_2408_2502.jsonl: 374 samples
    # test_v6_2410_2502.jsonl: 261 samples
    # test_v6_2410_2504.jsonl: 341 samples
0 commit comments