forked from thinking-machines-lab/tinker-cookbook
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path sl_basic.py
More file actions
51 lines (45 loc) · 1.81 KB
/
sl_basic.py
File metadata and controls
51 lines (45 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import chz
import sys
from tinker_cookbook import cli_utils, model_info
from tinker_cookbook.recipes.chat_sl import chat_datasets
from tinker_cookbook.renderers import TrainOnWhat
from tinker_cookbook.supervised import train
from tinker_cookbook.supervised.data import FromConversationFileBuilder
from tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig
import asyncio
def build_config_blueprint(
    *,
    model_name: str = "meta-llama/Llama-3.1-8B",
    dataset_path: "str | None" = None,
    log_path: str = "/tmp/tinker-examples/sl_basic",
) -> chz.Blueprint[train.Config]:
    """Build the chz blueprint for a basic supervised fine-tuning run.

    All arguments are keyword-only and default to the values this example has
    always used, so ``build_config_blueprint()`` behaves as before.

    Args:
        model_name: Model to train. It also selects the tokenizer and the
            recommended renderer used to build the dataset.
        dataset_path: Optional path to a JSONL file of conversations in the same
            format as ``example-data/conversations.jsonl``. When ``None``, the
            ``NoRobotsBuilder`` chat dataset is used instead.
        log_path: Value stored under the ``log_path`` config key.

    Returns:
        A blueprint for ``train.Config`` with the defaults applied. Further
        overrides (e.g. from the command line) can be layered on top before
        calling ``make()``.
    """
    renderer_name = model_info.get_recommended_renderer_name(model_name)
    common_config = ChatDatasetBuilderCommonConfig(
        model_name_for_tokenizer=model_name,
        renderer_name=renderer_name,
        max_length=32768,
        batch_size=128,
        train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES,
    )

    # Choose the data source explicitly instead of toggling a dead `if 0:`
    # branch in the source. Only the builder that is actually used gets
    # constructed.
    if dataset_path is None:
        dataset = chat_datasets.NoRobotsBuilder(common_config=common_config)
    else:
        dataset = FromConversationFileBuilder(
            common_config=common_config, file_path=dataset_path
        )

    # NOTE(review): `dataset` is constructed here with `model_name` baked into
    # its tokenizer/renderer config. Overriding `model_name` on the command line
    # presumably changes only the top-level config key, not the dataset's
    # tokenizer. Pass `model_name=` to this function to keep them in sync.
    return chz.Blueprint(train.Config).apply(
        {
            "log_path": log_path,
            "model_name": model_name,
            "dataset_builder": dataset,
            "learning_rate": 2e-4,
            "lr_schedule": "linear",
            "num_epochs": 1,
            "eval_every": 8,
        }
    )
def main(config: train.Config):
    """Check the log directory, then run supervised training to completion.

    ``cli_utils.check_log_dir`` is called with ``behavior_if_exists="ask"``.
    Presumably it prompts the user rather than silently overwriting the logs of
    an earlier run. The async ``train.main`` coroutine is then driven on a fresh
    event loop via ``asyncio.run``.
    """
    log_dir = config.log_path
    # Don't clobber the log directory of a previous run without asking first.
    cli_utils.check_log_dir(log_dir, behavior_if_exists="ask")

    training_run = train.main(config)
    asyncio.run(training_run)
if __name__ == "__main__":
    # Script entry point: start from the default blueprint, layer the
    # command-line overrides on top, then build the config and train.
    blueprint = build_config_blueprint()
    cli_args = sys.argv[1:]
    # NOTE(review): this relies on `make_from_argv` applying the CLI overrides
    # to `blueprint` in place, since its return value is discarded. If it also
    # returns the built config, the `make()` below constructs a second copy.
    # Confirm against the chz docs before simplifying.
    blueprint.make_from_argv(cli_args)
    final_config = blueprint.make()
    main(final_config)