@@ -55,7 +55,7 @@ def __init__(
5555 data : List [Dict [str , str ]],
5656 tokenizer : Tokenizer ,
5757 prompt_style : Union [str , PromptStyle ],
58- max_seq_length : int = - 1 ,
58+ max_seq_length : Optional [ int ] = None ,
5959 mask_prompt : bool = True ,
6060 ignore_index : int = - 100 ,
6161 transform : Optional [Callable [[Dict [str , str ]], Dict [str , str ]]] = None ,
@@ -84,9 +84,10 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
8484 if self .transform is not None :
8585 example = self .transform (example )
8686 prompt = self .prompt_style .apply (prompt = example ["instruction" ], ** example )
87+ max_length = - 1 if self .max_seq_length is None else self .max_seq_length
8788 encoded_prompt = self .tokenizer .encode (
8889 prompt ,
89- max_length = self . max_seq_length ,
90+ max_length = max_length ,
9091 )
9192 targets = example ["output" ]
9293 if isinstance (targets , list ):
@@ -99,15 +100,14 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
99100 _targets ,
100101 bos = False ,
101102 eos = True ,
102- max_length = self . max_seq_length ,
103+ max_length = max_length ,
103104 )
104105 encoded_prompt_and_response = torch .cat (
105106 (encoded_prompt , encoded_response )
106107 ).type (torch .int64 )
107- msl = self .max_seq_length
108- if 0 < msl < len (encoded_prompt_and_response ):
109- encoded_prompt_and_response = encoded_prompt_and_response [:msl ]
110- encoded_prompt_and_response [msl - 1 ] = self .tokenizer .eos_id
108+ if 0 < max_length < len (encoded_prompt_and_response ):
109+ encoded_prompt_and_response = encoded_prompt_and_response [:max_length ]
110+ encoded_prompt_and_response [max_length - 1 ] = self .tokenizer .eos_id
111111
112112 # The labels are the full prompt with response, but with the prompt masked out
113113 labels = encoded_prompt_and_response .clone ()
0 commit comments