import re


def parse_context(context):
    sections = {
        'id': None,
        'problem_statement': None,
        'solution': None,
        'keywords': None,
        'reference': None
    }
    id_match = re.search(r'ID:\s*(\d+)', context)
    if id_match:
        sections['id'] = id_match.group(1)
    problem_match = re.search(r'Problem Statement\s*(.*?)\s*(?=(Solution|Keywords|Reference|$))', context, re.DOTALL)
    if problem_match:
        sections['problem_statement'] = problem_match.group(1).strip()
    solution_match = re.search(r'Solution\s*(.*?)\s*(?=(Problem Statement|Keywords|Reference|$))', context, re.DOTALL)
    if solution_match:
        sections['solution'] = solution_match.group(1).strip()
    keywords_match = re.search(r'Keywords\s*(.*?)\s*(?=(Problem Statement|Solution|Reference|$))', context, re.DOTALL)
    if keywords_match:
        sections['keywords'] = keywords_match.group(1).strip()
    reference_match = re.search(r'Reference\s*(.*?)\s*(?=(Problem Statement|Solution|Keywords|$))', context, re.DOTALL)
    if reference_match:
        sections['reference'] = reference_match.group(1).strip()
    return sections


def parse_file(file_path):
    with open(file_path, 'r') as file:
        content = file.read()
    contexts = content.split('--------------------------')
    parsed_contexts = []
    for context in contexts:
        context = context.strip()
        if context:
            parsed_context = parse_context(context)
            parsed_contexts.append(parsed_context)
    return parsed_contexts


# Example usage:
file_path = 'context.txt'
parsed_contexts = parse_file(file_path)
for context in parsed_contexts:
    print(context)
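
# Quick sanity check of parse_context on a hypothetical context block. The
# section headers match the regexes above; the content lines are made-up
# placeholders, not data from the real context.txt.
_sample_context = (
    "ID: 1\n"
    "Problem Statement\n"
    "The service times out under heavy load.\n"
    "Solution\n"
    "Increase the connection pool size and add retries.\n"
    "Keywords\n"
    "timeout, connection pool, retries\n"
    "Reference\n"
    "internal runbook\n"
)
print(parse_context(_sample_context))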
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import torch


# Define a custom dataset class over the parsed contexts
class MyCustomDataset(Dataset):
    def __init__(self, txt_file):
        # Load the text file and split it into individual contexts
        with open(txt_file, 'r') as file:
            content = file.read()
        contexts = content.split('--------------------------')
        parsed_contexts = []
        for context in contexts:
            context = context.strip()
            if context:
                parsed_context = parse_context(context)
                parsed_contexts.append(parsed_context)
        self.data = parsed_contexts

    def __len__(self):
        # Return the total number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Return a dictionary containing the problem statement, solution,
        # keywords, and reference for each sample
        sample = self.data[idx]
        return {
            'problem_statement': sample['problem_statement'],
            'solution': sample['solution'],
            'keywords': sample['keywords'],
            'reference': sample['reference'],
        }
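
# NOTE: the Hugging Face Trainer expects each training sample to provide token
# tensors (input_ids, attention_mask, labels) rather than raw strings. Below is
# a minimal sketch of one way to wrap the parsed contexts into a tokenized
# dataset for causal-LM fine-tuning; the prompt format and max_length here are
# assumptions, not requirements.
class MyTokenizedDataset(Dataset):
    def __init__(self, txt_file, tokenizer, max_length=512):
        self.samples = MyCustomDataset(txt_file)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # LLaMA tokenizers have no pad token by default; reuse EOS for padding
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Concatenate problem and solution into one training text (assumed format)
        text = f"Problem: {sample['problem_statement']}\nSolution: {sample['solution']}"
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].squeeze(0)
        # For causal-LM fine-tuning the labels are simply the input ids
        return {
            'input_ids': input_ids,
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': input_ids.clone(),
        }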

# Load the pre-trained LLaMA 3 model and tokenizer
# (model id assumed; substitute the exact Llama 3 checkpoint you have access to)
model_name = 'meta-llama/Meta-Llama-3-8B'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Initialize the Trainer with the custom dataset and hyperparameters
training_args = TrainingArguments(
    output_dir='.',                   # Output directory for the trained model
    max_steps=1000,                   # Maximum number of training steps
    per_device_train_batch_size=16,   # Batch size per device during training
)
trainer = Trainer(
    model=model,
    args=training_args,
    # Load the custom dataset (note: the Trainer needs tokenized features such
    # as input_ids and labels rather than the raw strings this class returns)
    train_dataset=MyCustomDataset(txt_file='your_custom_dataset.txt'),
    # Rough token-level accuracy over evaluation labels (illustrative only)
    compute_metrics=lambda pred: {
        'accuracy': float((pred.predictions.argmax(-1) == pred.label_ids).mean())
    },
)

# Fine-tune the model using the custom dataset
trainer.train()
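
# After training, the fine-tuned weights and tokenizer can be saved with the
# standard Trainer/tokenizer APIs (the output path below is just an example)
trainer.save_model('./fine_tuned_model')
tokenizer.save_pretrained('./fine_tuned_model')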