@@ -23,21 +23,24 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
2323 preview_btn .click (get_preview , [dataset_dir , dataset ], [preview_count , preview_samples , preview_box ])
2424
2525 with gr .Row ():
26+ max_source_length = gr .Slider (value = 512 , minimum = 4 , maximum = 4096 , step = 1 )
27+ max_target_length = gr .Slider (value = 512 , minimum = 4 , maximum = 4096 , step = 1 )
2628 learning_rate = gr .Textbox (value = "5e-5" )
2729 num_train_epochs = gr .Textbox (value = "3.0" )
2830 max_samples = gr .Textbox (value = "100000" )
2931
3032 with gr .Row ():
31- batch_size = gr .Slider (value = 4 , minimum = 1 , maximum = 128 , step = 1 )
32- gradient_accumulation_steps = gr .Slider (value = 4 , minimum = 1 , maximum = 32 , step = 1 )
33+ batch_size = gr .Slider (value = 4 , minimum = 1 , maximum = 512 , step = 1 )
34+ gradient_accumulation_steps = gr .Slider (value = 4 , minimum = 1 , maximum = 512 , step = 1 )
3335 lr_scheduler_type = gr .Dropdown (
3436 value = "cosine" , choices = [scheduler .value for scheduler in SchedulerType ]
3537 )
38+ dev_ratio = gr .Slider (value = 0 , minimum = 0 , maximum = 1 , step = 0.001 )
3639 fp16 = gr .Checkbox (value = True )
3740
3841 with gr .Row ():
3942 logging_steps = gr .Slider (value = 5 , minimum = 5 , maximum = 1000 , step = 5 )
40- save_steps = gr .Slider (value = 100 , minimum = 10 , maximum = 2000 , step = 10 )
43+ save_steps = gr .Slider (value = 100 , minimum = 10 , maximum = 5000 , step = 10 )
4144
4245 with gr .Row ():
4346 start_btn = gr .Button ()
@@ -63,12 +66,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
6366 top_elems ["source_prefix" ],
6467 dataset_dir ,
6568 dataset ,
69+ max_source_length ,
70+ max_target_length ,
6671 learning_rate ,
6772 num_train_epochs ,
6873 max_samples ,
6974 batch_size ,
7075 gradient_accumulation_steps ,
7176 lr_scheduler_type ,
77+ dev_ratio ,
7278 fp16 ,
7379 logging_steps ,
7480 save_steps ,
@@ -89,12 +95,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
8995 preview_count = preview_count ,
9096 preview_samples = preview_samples ,
9197 close_btn = close_btn ,
98+ max_source_length = max_source_length ,
99+ max_target_length = max_target_length ,
92100 learning_rate = learning_rate ,
93101 num_train_epochs = num_train_epochs ,
94102 max_samples = max_samples ,
95103 batch_size = batch_size ,
96104 gradient_accumulation_steps = gradient_accumulation_steps ,
97105 lr_scheduler_type = lr_scheduler_type ,
106+ dev_ratio = dev_ratio ,
98107 fp16 = fp16 ,
99108 logging_steps = logging_steps ,
100109 save_steps = save_steps ,
0 commit comments