From d60e6825f11368b613fc34376f196424110f622a Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:02:58 -0700 Subject: [PATCH 01/12] Reformat everything properly Join 2022 with the rest of us. Fixed via: `fd -e py -X black` ``` $ black --version black, 22.6.0 (compiled: yes) Python (CPython) 3.10.2 ``` --- mingpt/bpe.py | 128 ++++++++++------- mingpt/model.py | 232 +++++++++++++++++++++---------- mingpt/trainer.py | 16 ++- mingpt/utils.py | 44 +++--- projects/adder/adder.py | 115 +++++++++------ projects/chargpt/chargpt.py | 36 +++-- tests/test_huggingface_import.py | 24 ++-- 7 files changed, 380 insertions(+), 215 deletions(-) diff --git a/mingpt/bpe.py b/mingpt/bpe.py index b8468ef9..a0c92a66 100644 --- a/mingpt/bpe.py +++ b/mingpt/bpe.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- + def bytes_to_unicode(): """ Every possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode @@ -33,8 +34,12 @@ def bytes_to_unicode(): like 'Ā', or 'Ġ', etc. """ # the 188 integers that render fine in their original form and need no shifting - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict # now get the representations of the other 68 integers that do need shifting # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop n = 0 @@ -42,12 +47,13 @@ def bytes_to_unicode(): if b not in bs: # if this byte is "ugly" then map it to the next available "nice" character bs.append(b) - cs.append(2**8+n) + cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] d = dict(zip(bs, cs)) return d + def get_pairs(word): """ Return all bigrams as a set of tuples, of consecutive elements in the iterable word. @@ -59,15 +65,15 @@ def get_pairs(word): prev_char = char return pairs -class Encoder: +class Encoder: def __init__(self, encoder, bpe_merges): # byte encoder/decoder self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} # bpe token encoder/decoder self.encoder = encoder - self.decoder = {v:k for k,v in self.encoder.items()} + self.decoder = {v: k for k, v in self.encoder.items()} # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) # the splitting pattern used for pre-tokenization @@ -89,7 +95,9 @@ def __init__(self, encoder, bpe_merges): - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces """ - self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.pat = re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) self.cache = {} def bpe(self, token): @@ -104,8 +112,8 @@ def bpe(self, token): if token in self.cache: return self.cache[token] - word = tuple(token) # individual characters that make up the token, in a tuple - pairs = get_pairs(word) # get all bigrams + word = tuple(token) # individual characters that make up the token, in a tuple + pairs = get_pairs(word) # get all bigrams if not pairs: return token @@ -113,9 +121,9 @@ def bpe(self, token): while True: # find the next lowest rank bigram that can be merged - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: - break # no more bigrams are eligible to be merged + break # no more bigrams are eligible to be merged first, second = bigram # we will now replace all occurences of (first, second) in the list of current @@ -134,8 +142,8 @@ def bpe(self, token): break # if this occurence is also followed by second, then merge them into one - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) i += 2 else: new_word.append(word[i]) @@ -152,25 +160,25 @@ def bpe(self, token): # concat all words into a string, and use ' ' as the separator. Note that # by now all characters have been byte encoded, guaranteeing that ' ' is # not used in the actual data and is a 'special' delimiter character - word = ' '.join(word) + word = " ".join(word) # cache the result and return self.cache[token] = word return word def encode(self, text): - """ string goes in, list of integers comes out """ + """string goes in, list of integers comes out""" bpe_idx = [] # pre-tokenize the input text into string tokens (words, roughly speaking) tokens = re.findall(self.pat, text) # process each token into BPE integers for token in tokens: # encode the token as a bytes (b'') object - token_bytes = token.encode('utf-8') + token_bytes = token.encode("utf-8") # translate all bytes to their unicode string representation and flatten - token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) + token_translated = "".join(self.byte_encoder[b] for b in token_bytes) # perform all the applicable bpe merges according to self.bpe_ranks - token_merged = self.bpe(token_translated).split(' ') + token_merged = self.bpe(token_translated).split(" ") # translate all bpe tokens to integers token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] # extend our running list of all output integers @@ -178,91 +186,103 @@ def encode(self, text): return bpe_idx def encode_and_show_work(self, text): - """ debugging function, same as encode but returns all intermediate work """ + """debugging function, same as encode but returns all intermediate work""" bpe_idx = [] parts = [] tokens = re.findall(self.pat, text) for token in tokens: - token_bytes = token.encode('utf-8') - token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) - token_merged = self.bpe(token_translated).split(' ') + token_bytes = token.encode("utf-8") + token_translated = "".join(self.byte_encoder[b] for b in token_bytes) + token_merged = self.bpe(token_translated).split(" ") token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] bpe_idx.extend(token_ix) - parts.append({ - 'token': token, - 'token_bytes': token_bytes, - 'token_translated': token_translated, - 'token_merged': token_merged, - 'token_ix': token_ix, - }) + parts.append( + { + "token": token, + "token_bytes": token_bytes, + "token_translated": token_translated, + "token_merged": token_merged, + "token_ix": token_ix, + } + ) out = { - 'bpe_idx': bpe_idx, # the actual output sequence - 'tokens': tokens, # result of pre-tokenization - 'parts': parts, # intermediates for each token part + "bpe_idx": bpe_idx, # the actual output sequence + "tokens": tokens, # result of pre-tokenization + "parts": parts, # intermediates for each token part } return out def decode(self, bpe_idx): - """ list of integers comes in, string comes out """ + """list of integers comes in, string comes out""" # inverse map the integers to get the tokens tokens_merged = [self.decoder[token] for token in bpe_idx] # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes - tokens_flat = ''.join(tokens_merged) + tokens_flat = "".join(tokens_merged) tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) # recover the full utf-8 string - text = tokens_bytes.decode('utf-8', errors='replace') + text = tokens_bytes.decode("utf-8", errors="replace") return text + def get_file(local_file, remote_file): - """ downloads remote_file to local_file if necessary """ + """downloads remote_file to local_file if necessary""" if not os.path.isfile(local_file): print(f"downloading {remote_file} to {local_file}") response = requests.get(remote_file) open(local_file, "wb").write(response.content) + def get_encoder(): """ Returns an instance of the GPT BPE Encoder/Decoder and handles caching of "database" files. """ - home_dir = os.path.expanduser('~') - cache_dir = os.path.join(home_dir, '.cache', 'mingpt') + home_dir = os.path.expanduser("~") + cache_dir = os.path.join(home_dir, ".cache", "mingpt") os.makedirs(cache_dir, exist_ok=True) # load encoder.json that has the raw mappings from token -> bpe index - encoder_local_file = os.path.join(cache_dir, 'encoder.json') - encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json' + encoder_local_file = os.path.join(cache_dir, "encoder.json") + encoder_remote_file = ( + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json" + ) get_file(encoder_local_file, encoder_remote_file) - with open(encoder_local_file, 'r') as f: + with open(encoder_local_file, "r") as f: encoder = json.load(f) - assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + assert ( + len(encoder) == 50257 + ) # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab - vocab_local_file = os.path.join(cache_dir, 'vocab.bpe') - vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe' + vocab_local_file = os.path.join(cache_dir, "vocab.bpe") + vocab_remote_file = ( + "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe" + ) get_file(vocab_local_file, vocab_remote_file) - with open(vocab_local_file, 'r', encoding="utf-8") as f: + with open(vocab_local_file, "r", encoding="utf-8") as f: bpe_data = f.read() # light postprocessing: strip the version on first line and the last line is a blank - bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] - assert len(bpe_merges) == 50000 # 50,000 merged tokens + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + assert len(bpe_merges) == 50000 # 50,000 merged tokens # construct the Encoder object and return enc = Encoder(encoder, bpe_merges) return enc + # ----------------------------------------------------------------------------- + class BPETokenizer: - """ PyTorch-aware class that wraps the Encoder above """ + """PyTorch-aware class that wraps the Encoder above""" def __init__(self): self.encoder = get_encoder() - def __call__(self, text, return_tensors='pt'): + def __call__(self, text, return_tensors="pt"): # PyTorch only; here because we want to match huggingface/transformers interface - assert return_tensors == 'pt' + assert return_tensors == "pt" # single string input for now, in the future potentially a list of strings assert isinstance(text, str) # encode and create a "batch dimension" of 1 @@ -279,7 +299,7 @@ def decode(self, idx): return text -if __name__ == '__main__': +if __name__ == "__main__": # here is an encoding example text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" @@ -289,10 +309,10 @@ def decode(self, idx): print("Original text is:") print(text) print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") - print(r['tokens']) + print(r["tokens"]) # ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗'] print("Then we iterate over each chunk and process them in turn...") - for part in r['parts']: + for part in r["parts"]: print(part) # {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]} # {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]} @@ -313,7 +333,7 @@ def decode(self, idx): # {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'Ġð٤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]} # (refer to the code inside Encoder.encode for what these intermediates are) print("and the final outcome is concatenating and flattening all the token_ix:") - print(r['bpe_idx']) + print(r["bpe_idx"]) # [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245] # this would then become the integer input sequence to the transformer print("ready to feed into a Transformer!") diff --git a/mingpt/model.py b/mingpt/model.py index dfa95bf6..b40a3f5e 100644 --- a/mingpt/model.py +++ b/mingpt/model.py @@ -18,13 +18,25 @@ # ----------------------------------------------------------------------------- + class NewGELU(nn.Module): """ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 """ + def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + class CausalSelfAttention(nn.Module): """ @@ -44,65 +56,85 @@ def __init__(self, config): self.attn_dropout = nn.Dropout(config.attn_pdrop) self.resid_dropout = nn.Dropout(config.resid_pdrop) # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) self.n_head = config.n_head self.n_embd = config.n_embd def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) return y + class Block(nn.Module): - """ an unassuming Transformer block """ + """an unassuming Transformer block""" def __init__(self, config): super().__init__() self.ln_1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln_2 = nn.LayerNorm(config.n_embd) - self.mlp = nn.ModuleDict(dict( - c_fc = nn.Linear(config.n_embd, 4 * config.n_embd), - c_proj = nn.Linear(4 * config.n_embd, config.n_embd), - act = NewGELU(), - dropout = nn.Dropout(config.resid_pdrop), - )) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(config.n_embd, 4 * config.n_embd), + c_proj=nn.Linear(4 * config.n_embd, config.n_embd), + act=NewGELU(), + dropout=nn.Dropout(config.resid_pdrop), + ) + ) m = self.mlp - self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward def forward(self, x): x = x + self.attn(self.ln_1(x)) x = x + self.mlpf(self.ln_2(x)) return x + class GPT(nn.Module): - """ GPT Language Model """ + """GPT Language Model""" @staticmethod def get_default_config(): C = CN() # either model_type or (n_layer, n_head, n_embd) must be given in the config - C.model_type = 'gpt' + C.model_type = "gpt" C.n_layer = None C.n_head = None - C.n_embd = None + C.n_embd = None # these options must be filled in externally C.vocab_size = None C.block_size = None @@ -119,46 +151,66 @@ def __init__(self, config): self.block_size = config.block_size type_given = config.model_type is not None - params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None]) - assert (type_given and not params_given) or (not type_given and params_given) # exactly one of these + params_given = all( + [ + config.n_layer is not None, + config.n_head is not None, + config.n_embd is not None, + ] + ) + assert (type_given and not params_given) or ( + not type_given and params_given + ) # exactly one of these if type_given: # translate from model_type to detailed configuration - config.merge_from_dict({ - # names follow the huggingface naming conventions - # GPT-1 - 'openai-gpt': dict(n_layer=12, n_head=12, n_embd=768), # 117M params - # GPT-2 configs - 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params - 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params - 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params - 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params - # Gophers - 'gopher-44m': dict(n_layer=8, n_head=16, n_embd=512), - # (there are a number more...) - # I made these tiny models up - 'gpt-mini': dict(n_layer=6, n_head=6, n_embd=192), - 'gpt-micro': dict(n_layer=4, n_head=4, n_embd=128), - 'gpt-nano': dict(n_layer=3, n_head=3, n_embd=48), - }[config.model_type]) - - self.transformer = nn.ModuleDict(dict( - wte = nn.Embedding(config.vocab_size, config.n_embd), - wpe = nn.Embedding(config.block_size, config.n_embd), - drop = nn.Dropout(config.embd_pdrop), - h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = nn.LayerNorm(config.n_embd), - )) + config.merge_from_dict( + { + # names follow the huggingface naming conventions + # GPT-1 + "openai-gpt": dict( + n_layer=12, n_head=12, n_embd=768 + ), # 117M params + # GPT-2 configs + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params + "gpt2-medium": dict( + n_layer=24, n_head=16, n_embd=1024 + ), # 350M params + "gpt2-large": dict( + n_layer=36, n_head=20, n_embd=1280 + ), # 774M params + "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + # Gophers + "gopher-44m": dict(n_layer=8, n_head=16, n_embd=512), + # (there are a number more...) + # I made these tiny models up + "gpt-mini": dict(n_layer=6, n_head=6, n_embd=192), + "gpt-micro": dict(n_layer=4, n_head=4, n_embd=128), + "gpt-nano": dict(n_layer=3, n_head=3, n_embd=48), + }[config.model_type] + ) + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.embd_pdrop), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embd), + ) + ) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper self.apply(self._init_weights) for pn, p in self.named_parameters(): - if pn.endswith('c_proj.weight'): - torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) # report number of parameters (note we don't count the decoder parameters in lm_head) n_params = sum(p.numel() for p in self.transformer.parameters()) - print("number of parameters: %.2fM" % (n_params/1e6,)) + print("number of parameters: %.2fM" % (n_params / 1e6,)) def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -177,13 +229,13 @@ def from_pretrained(cls, model_type): Initialize a pretrained GPT model by copying over the weights from a huggingface/transformers checkpoint. """ - assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} from transformers import GPT2LMHeadModel # create a from-scratch initialized minGPT model config = cls.get_default_config() config.model_type = model_type - config.vocab_size = 50257 # openai's model vocabulary + config.vocab_size = 50257 # openai's model vocabulary config.block_size = 1024 # openai's model block_size model = GPT(config) sd = model.state_dict() @@ -193,8 +245,13 @@ def from_pretrained(cls, model_type): sd_hf = model_hf.state_dict() # copy while ensuring all of the parameters are aligned and match in names and shapes - keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these - transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] + keys = [k for k in sd_hf if not k.endswith("attn.masked_bias")] # ignore these + transposed = [ + "attn.c_attn.weight", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_proj.weight", + ] # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. # this means that we have to transpose these weights when we import them assert len(keys) == len(sd) @@ -223,21 +280,21 @@ def configure_optimizers(self, train_config): # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() - whitelist_weight_modules = (torch.nn.Linear, ) + whitelist_weight_modules = (torch.nn.Linear,) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name # random note: because named_modules and named_parameters are recursive # we will see the same tensors p many many times. but doing it this way # allows us to know which parent module any tensor p belongs to... - if pn.endswith('bias'): + if pn.endswith("bias"): # all biases will not be decayed no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) @@ -245,27 +302,46 @@ def configure_optimizers(self, train_config): param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) - assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ - % (str(param_dict.keys() - union_params), ) + assert ( + len(inter_params) == 0 + ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters %s were not separated into either decay/no_decay set!" % ( + str(param_dict.keys() - union_params), + ) # create the pytorch optimizer object optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": train_config.weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, ] - optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + optimizer = torch.optim.AdamW( + optim_groups, lr=train_config.learning_rate, betas=train_config.betas + ) return optimizer def forward(self, idx, targets=None): device = idx.device b, t = idx.size() - assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + assert ( + t <= self.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) # forward the GPT model itself - tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -275,12 +351,16 @@ def forward(self, idx, targets=None): # if we are given some desired targets also calculate the loss loss = None if targets is not None: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) return logits, loss @torch.no_grad() - def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): + def generate( + self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None + ): """ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete the sequence max_new_tokens times, feeding the predictions back into the model each time. @@ -288,7 +368,9 @@ def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k= """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at block_size - idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] + idx_cond = ( + idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :] + ) # forward the model to get the logits for the index in the sequence logits, _ = self(idx_cond) # pluck the logits at the final step and scale by desired temperature @@ -296,7 +378,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k= # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, top_k) - logits[logits < v[:, [-1]]] = -float('Inf') + logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) # either sample from the distribution or take the most likely element diff --git a/mingpt/trainer.py b/mingpt/trainer.py index ee6bc0b8..1d8f493c 100644 --- a/mingpt/trainer.py +++ b/mingpt/trainer.py @@ -10,13 +10,13 @@ from torch.utils.data.dataloader import DataLoader from mingpt.utils import CfgNode as CN -class Trainer: +class Trainer: @staticmethod def get_default_config(): C = CN() # device to train on - C.device = 'auto' + C.device = "auto" # dataloder parameters C.num_workers = 4 # optimizer parameters @@ -24,7 +24,7 @@ def get_default_config(): C.batch_size = 64 C.learning_rate = 3e-4 C.betas = (0.9, 0.95) - C.weight_decay = 0.1 # only applied on matmul weights + C.weight_decay = 0.1 # only applied on matmul weights C.grad_norm_clip = 1.0 return C @@ -35,8 +35,8 @@ def __init__(self, config, model, train_dataset): self.callbacks = defaultdict(list) # determine the device we'll train on - if config.device == 'auto': - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if config.device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = config.device self.model = self.model.to(self.device) @@ -66,7 +66,9 @@ def run(self): # setup the dataloader train_loader = DataLoader( self.train_dataset, - sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)), + sampler=torch.utils.data.RandomSampler( + self.train_dataset, replacement=True, num_samples=int(1e10) + ), shuffle=False, pin_memory=True, batch_size=config.batch_size, @@ -97,7 +99,7 @@ def run(self): torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) optimizer.step() - self.trigger_callbacks('on_batch_end') + self.trigger_callbacks("on_batch_end") self.iter_num += 1 tnow = time.time() self.iter_dt = tnow - self.iter_time diff --git a/mingpt/utils.py b/mingpt/utils.py index af864ecb..f24436aa 100644 --- a/mingpt/utils.py +++ b/mingpt/utils.py @@ -1,4 +1,3 @@ - import os import sys import json @@ -10,26 +9,30 @@ # ----------------------------------------------------------------------------- + def set_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + def setup_logging(config): - """ monotonous bookkeeping """ + """monotonous bookkeeping""" work_dir = config.system.work_dir # create the work directory if it doesn't already exist os.makedirs(work_dir, exist_ok=True) # log the args (if any) - with open(os.path.join(work_dir, 'args.txt'), 'w') as f: - f.write(' '.join(sys.argv)) + with open(os.path.join(work_dir, "args.txt"), "w") as f: + f.write(" ".join(sys.argv)) # log the config itself - with open(os.path.join(work_dir, 'config.json'), 'w') as f: + with open(os.path.join(work_dir, "config.json"), "w") as f: f.write(json.dumps(config.to_dict(), indent=4)) + class CfgNode: - """ a lightweight configuration class inspired by yacs """ + """a lightweight configuration class inspired by yacs""" + # TODO: convert to subclass from a dict like in yacs? # TODO: implement freezing to prevent shooting of own foot # TODO: additional existence/override checks when reading/writing params? @@ -41,7 +44,7 @@ def __str__(self): return self._str_helper(0) def _str_helper(self, indent): - """ need to have a helper to support nested indentation for pretty printing """ + """need to have a helper to support nested indentation for pretty printing""" parts = [] for k, v in self.__dict__.items(): if isinstance(v, CfgNode): @@ -49,12 +52,15 @@ def _str_helper(self, indent): parts.append(v._str_helper(indent + 1)) else: parts.append("%s: %s\n" % (k, v)) - parts = [' ' * (indent * 4) + p for p in parts] + parts = [" " * (indent * 4) + p for p in parts] return "".join(parts) def to_dict(self): - """ return a dict representation of the config """ - return { k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items() } + """return a dict representation of the config""" + return { + k: v.to_dict() if isinstance(v, CfgNode) else v + for k, v in self.__dict__.items() + } def merge_from_dict(self, d): self.__dict__.update(d) @@ -71,9 +77,11 @@ def merge_from_args(self, args): """ for arg in args: - keyval = arg.split('=') - assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg - key, val = keyval # unpack + keyval = arg.split("=") + assert len(keyval) == 2, ( + "expecting each override arg to be of form --arg=value, got %s" % arg + ) + key, val = keyval # unpack # first translate val into a python object try: @@ -87,16 +95,18 @@ def merge_from_args(self, args): pass # find the appropriate object to insert the attribute into - assert key[:2] == '--' - key = key[2:] # strip the '--' - keys = key.split('.') + assert key[:2] == "--" + key = key[2:] # strip the '--' + keys = key.split(".") obj = self for k in keys[:-1]: obj = getattr(obj, k) leaf_key = keys[-1] # ensure that this attribute exists - assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config" + assert hasattr( + obj, leaf_key + ), f"{key} is not an attribute that exists in the config" # overwrite the attribute print("command line overwriting config attribute %s with %s" % (key, val)) diff --git a/projects/adder/adder.py b/projects/adder/adder.py index 55f03ee1..90d9dd1e 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -16,6 +16,7 @@ # ----------------------------------------------------------------------------- + def get_config(): C = CN() @@ -23,23 +24,27 @@ def get_config(): # system C.system = CN() C.system.seed = 3407 - C.system.work_dir = './out/adder' + C.system.work_dir = "./out/adder" # data C.data = AdditionDataset.get_default_config() # model C.model = GPT.get_default_config() - C.model.model_type = 'gpt-nano' + C.model.model_type = "gpt-nano" # trainer C.trainer = Trainer.get_default_config() - C.trainer.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster + C.trainer.learning_rate = ( + 5e-4 # the model we're using is so small that we can go a bit faster + ) return C + # ----------------------------------------------------------------------------- + class AdditionDataset(Dataset): """ Creates n-digit addition problems. For example, if n=2, then an example @@ -73,26 +78,32 @@ def get_default_config(): def __init__(self, config, split): self.config = config - self.split = split # train/test + self.split = split # train/test # split up all addition problems into either training data or test data ndigit = self.config.ndigit - assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support" - num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers + assert ( + ndigit <= 3 + ), "the lines below would be very memory inefficient, in future maybe refactor to support" + num = ( + 10**ndigit + ) ** 2 # total number of possible addition problems with ndigit numbers rng = torch.Generator() rng.manual_seed(1337) perm = torch.randperm(num, generator=rng) - num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500 - self.ixes = perm[:num_test] if split == 'test' else perm[num_test:] + num_test = min( + int(num * 0.2), 500 + ) # 20% of the whole dataset, or only up to 500 + self.ixes = perm[:num_test] if split == "test" else perm[num_test:] def get_vocab_size(self): - return 10 # digits 0..9 + return 10 # digits 0..9 def get_block_size(self): # a,b,a+b, and +1 due to potential carry overflow, # but then also -1 because very last digit doesn't ever plug back # as there is no explicit token to predict, it is implied - return 3*self.config.ndigit + 1 - 1 + return 3 * self.config.ndigit + 1 - 1 def __len__(self): return self.ixes.nelement() @@ -103,24 +114,29 @@ def __getitem__(self, idx): idx = self.ixes[idx].item() nd = 10**ndigit a = idx // nd - b = idx % nd + b = idx % nd # calculate the "label" of the addition problem a + b c = a + b # encode the digits of a, b, c into strings - astr = f'%0{ndigit}d' % a - bstr = f'%0{ndigit}d' % b - cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier + astr = f"%0{ndigit}d" % a + bstr = f"%0{ndigit}d" % b + cstr = (f"%0{ndigit+1}d" % c)[::-1] # reverse c to make addition easier render = astr + bstr + cstr - dix = [int(s) for s in render] # convert each character to its token index + dix = [int(s) for s in render] # convert each character to its token index # x will be input to GPT and y will be the associated expected outputs x = torch.tensor(dix[:-1], dtype=torch.long) - y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence - y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero + y = torch.tensor( + dix[1:], dtype=torch.long + ) # predict the next token in the sequence + y[ + : ndigit * 2 - 1 + ] = -1 # we will only train in the output locations. -1 will mask loss to zero return x, y + # ----------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": # get default config and overrides from the command line, if any config = get_config() @@ -130,8 +146,8 @@ def __getitem__(self, idx): set_seed(config.system.seed) # construct train and test datasets - train_dataset = AdditionDataset(config.data, split='train') - test_dataset = AdditionDataset(config.data, split='test') + train_dataset = AdditionDataset(config.data, split="train") + test_dataset = AdditionDataset(config.data, split="test") # construct the model config.model.vocab_size = train_dataset.get_vocab_size() @@ -143,54 +159,75 @@ def __getitem__(self, idx): # helper function for the evaluation of a model def eval_split(trainer, split, max_batches=None): - dataset = {'train':train_dataset, 'test':test_dataset}[split] + dataset = {"train": train_dataset, "test": test_dataset}[split] ndigit = config.data.ndigit results = [] mistakes_printed_already = 0 - factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device) + factors = torch.tensor([[10**i for i in range(ndigit + 1)][::-1]]).to( + trainer.device + ) loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False) for b, (x, y) in enumerate(loader): x = x.to(trainer.device) # isolate the first two digits of the input sequence alone - d1d2 = x[:, :ndigit*2] + d1d2 = x[:, : ndigit * 2] # let the model sample the rest of the sequence - d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling + d1d2d3 = model.generate( + d1d2, ndigit + 1, do_sample=False + ) # using greedy argmax, not sampling # isolate the last digit of the sampled sequence - d3 = d1d2d3[:, -(ndigit+1):] - d3 = d3.flip(1) # reverse the digits to their "normal" order + d3 = d1d2d3[:, -(ndigit + 1) :] + d3 = d3.flip(1) # reverse the digits to their "normal" order # decode the integers from individual digits - d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1) - d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1) + d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1) + d2i = (d1d2[:, ndigit : ndigit * 2] * factors[:, 1:]).sum(1) d3i_pred = (d3 * factors).sum(1) - d3i_gt = d1i + d2i # manually calculate the ground truth + d3i_gt = d1i + d2i # manually calculate the ground truth # evaluate the correctness of the results in this batch - correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha + correct = ( + d3i_pred == d3i_gt + ).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha for i in range(x.size(0)): results.append(int(correct[i])) - if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense + if ( + not correct[i] and mistakes_printed_already < 5 + ): # only print up to 5 mistakes to get a sense mistakes_printed_already += 1 - print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i])) - if max_batches is not None and b+1 >= max_batches: + print( + "GPT claims that %d + %d = %d but gt is %d" + % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]) + ) + if max_batches is not None and b + 1 >= max_batches: break rt = torch.tensor(results, dtype=torch.float) - print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean())) + print( + "%s final score: %d/%d = %.2f%% correct" + % (split, rt.sum(), len(results), 100 * rt.mean()) + ) return rt.sum() # iteration callback top_score = 0 + def batch_end_callback(trainer): global top_score if trainer.iter_num % 10 == 0: - print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}") + print( + f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" + ) if trainer.iter_num % 500 == 0: # evaluate both the train and test score - train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no + train_max_batches = {1: None, 2: None, 3: 5}[ + config.data.ndigit + ] # if ndigit=2 we can afford the whole train set, ow no model.eval() with torch.no_grad(): - train_score = eval_split(trainer, 'train', max_batches=train_max_batches) - test_score = eval_split(trainer, 'test', max_batches=None) + train_score = eval_split( + trainer, "train", max_batches=train_max_batches + ) + test_score = eval_split(trainer, "test", max_batches=None) score = train_score + test_score # save the model if this is the best score we've seen so far if score > top_score: @@ -201,7 +238,7 @@ def batch_end_callback(trainer): # revert model to training mode model.train() - trainer.set_callback('on_batch_end', batch_end_callback) + trainer.set_callback("on_batch_end", batch_end_callback) # run the optimization trainer.run() diff --git a/projects/chargpt/chargpt.py b/projects/chargpt/chargpt.py index 5de925b0..29dbf01d 100644 --- a/projects/chargpt/chargpt.py +++ b/projects/chargpt/chargpt.py @@ -15,6 +15,7 @@ # ----------------------------------------------------------------------------- + def get_config(): C = CN() @@ -22,23 +23,27 @@ def get_config(): # system C.system = CN() C.system.seed = 3407 - C.system.work_dir = './out/chargpt' + C.system.work_dir = "./out/chargpt" # data C.data = CharDataset.get_default_config() # model C.model = GPT.get_default_config() - C.model.model_type = 'gpt-mini' + C.model.model_type = "gpt-mini" # trainer C.trainer = Trainer.get_default_config() - C.trainer.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster + C.trainer.learning_rate = ( + 5e-4 # the model we're using is so small that we can go a bit faster + ) return C + # ----------------------------------------------------------------------------- + class CharDataset(Dataset): """ Emits batches of characters @@ -55,10 +60,10 @@ def __init__(self, config, data): chars = sorted(list(set(data))) data_size, vocab_size = len(data), len(chars) - print('data has %d characters, %d unique.' % (data_size, vocab_size)) + print("data has %d characters, %d unique." % (data_size, vocab_size)) - self.stoi = { ch:i for i,ch in enumerate(chars) } - self.itos = { i:ch for i,ch in enumerate(chars) } + self.stoi = {ch: i for i, ch in enumerate(chars)} + self.itos = {i: ch for i, ch in enumerate(chars)} self.vocab_size = vocab_size self.data = data @@ -73,7 +78,7 @@ def __len__(self): def __getitem__(self, idx): # grab a chunk of (block_size + 1) characters from the data - chunk = self.data[idx:idx + self.config.block_size + 1] + chunk = self.data[idx : idx + self.config.block_size + 1] # encode every character to an integer dix = [self.stoi[s] for s in chunk] # return as tensors @@ -81,9 +86,10 @@ def __getitem__(self, idx): y = torch.tensor(dix[1:], dtype=torch.long) return x, y + # ----------------------------------------------------------------------------- -if __name__ == '__main__': +if __name__ == "__main__": # get default config and overrides from the command line, if any config = get_config() @@ -93,7 +99,7 @@ def __getitem__(self, idx): set_seed(config.system.seed) # construct the training dataset - text = open('input.txt', 'r').read() # don't worry we won't run out of file handles + text = open("input.txt", "r").read() # don't worry we won't run out of file handles train_dataset = CharDataset(config.data, text) # construct the model @@ -108,7 +114,9 @@ def __getitem__(self, idx): def batch_end_callback(trainer): if trainer.iter_num % 10 == 0: - print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}") + print( + f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" + ) if trainer.iter_num % 500 == 0: # evaluate both the train and test score @@ -116,9 +124,11 @@ def batch_end_callback(trainer): with torch.no_grad(): # sample from the model... context = "O God, O God!" - x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device) + x = torch.tensor( + [train_dataset.stoi[s] for s in context], dtype=torch.long + )[None, ...].to(trainer.device) y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0] - completion = ''.join([train_dataset.itos[int(i)] for i in y]) + completion = "".join([train_dataset.itos[int(i)] for i in y]) print(completion) # save the latest model print("saving model") @@ -127,7 +137,7 @@ def batch_end_callback(trainer): # revert model to training mode model.train() - trainer.set_callback('on_batch_end', batch_end_callback) + trainer.set_callback("on_batch_end", batch_end_callback) # run the optimization trainer.run() diff --git a/tests/test_huggingface_import.py b/tests/test_huggingface_import.py index dab52a82..dcd87c88 100644 --- a/tests/test_huggingface_import.py +++ b/tests/test_huggingface_import.py @@ -7,18 +7,19 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel from mingpt.model import GPT from mingpt.bpe import BPETokenizer + # ----------------------------------------------------------------------------- -class TestHuggingFaceImport(unittest.TestCase): +class TestHuggingFaceImport(unittest.TestCase): def test_gpt2(self): - model_type = 'gpt2' - device = 'cuda' if torch.cuda.is_available() else 'cpu' + model_type = "gpt2" + device = "cuda" if torch.cuda.is_available() else "cpu" prompt = "Hello!!!!!!!!!? 🤗, my dog is a little" # create a minGPT and a huggingface/transformers model model = GPT.from_pretrained(model_type) - model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too + model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too # ship both to device model.to(device) @@ -34,9 +35,11 @@ def test_gpt2(self): x1 = tokenizer(prompt).to(device) # ... with huggingface/transformers tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type) - model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning - encoded_input = tokenizer_hf(prompt, return_tensors='pt').to(device) - x2 = encoded_input['input_ids'] + model_hf.config.pad_token_id = ( + model_hf.config.eos_token_id + ) # suppress a warning + encoded_input = tokenizer_hf(prompt, return_tensors="pt").to(device) + x2 = encoded_input["input_ids"] # ensure the logits match exactly logits1, loss = model(x1) @@ -46,12 +49,13 @@ def test_gpt2(self): # now draw the argmax samples from each y1 = model.generate(x1, max_new_tokens=20, do_sample=False)[0] y2 = model_hf.generate(x2, max_new_tokens=20, do_sample=False)[0] - self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices + self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices # convert indices to strings out1 = tokenizer.decode(y1.cpu().squeeze()) out2 = tokenizer_hf.decode(y2.cpu().squeeze()) - self.assertTrue(out1 == out2) # compare the exact output strings too + self.assertTrue(out1 == out2) # compare the exact output strings too + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 85a9ee932639b466d8350d1041210f6d9c511ddd Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:14:21 -0700 Subject: [PATCH 02/12] Add default result directory to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f92cfd22..8ec38c0f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__/ *.swp .env .pylintrc +out/ From c67fb29c084ee6470524ca56d8b661fc30818af7 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:16:07 -0700 Subject: [PATCH 03/12] Use python built-in iterator cycling I've seen this pattern copied out to a dozen different ML repos, but it's just completely wrong? Instead of using 10 lines of "hack python to be smarter than python," it can all be replaced with a one-line python-provided API for repeating a single iterator forever: for iter_num, batch in enumerate(itertools.cycle(train_loader)): --- mingpt/trainer.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/mingpt/trainer.py b/mingpt/trainer.py index 1d8f493c..da0ca9a1 100644 --- a/mingpt/trainer.py +++ b/mingpt/trainer.py @@ -4,6 +4,7 @@ """ import time +import itertools from collections import defaultdict import torch @@ -43,7 +44,7 @@ def __init__(self, config, model, train_dataset): print("running on device", self.device) # variables that will be assigned to trainer class later for logging and etc - self.iter_num = 0 + self.iter_num: int = 0 self.iter_time = 0.0 self.iter_dt = 0.0 @@ -76,17 +77,8 @@ def run(self): ) model.train() - self.iter_num = 0 self.iter_time = time.time() - data_iter = iter(train_loader) - while True: - - # fetch the next batch (x, y) and re-init iterator if needed - try: - batch = next(data_iter) - except StopIteration: - data_iter = iter(train_loader) - batch = next(data_iter) + for iter_num, batch in enumerate(itertools.cycle(train_loader)): batch = [t.to(self.device) for t in batch] x, y = batch @@ -100,11 +92,12 @@ def run(self): optimizer.step() self.trigger_callbacks("on_batch_end") - self.iter_num += 1 tnow = time.time() + self.iter_num = iter_num self.iter_dt = tnow - self.iter_time self.iter_time = tnow # termination conditions - if config.max_iters is not None and self.iter_num >= config.max_iters: - break + if config.max_iters is not None: + if iter_num >= config.max_iters: + break From e6e12ec6286116f94db7bde69039236f3c00fb4d Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:22:18 -0700 Subject: [PATCH 04/12] Introduce pathlib.Path instead of os manipulation This is the "proper" python way of doing path operations. It makes file operations one-liners with half the typing vs using 10 year outdated legacy python patterns. Also reducing some unnecessary line padding and unnecessary text decoration in places too. --- mingpt/utils.py | 18 ++++++++---------- projects/adder/adder.py | 6 ++---- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mingpt/utils.py b/mingpt/utils.py index f24436aa..7d6eb77b 100644 --- a/mingpt/utils.py +++ b/mingpt/utils.py @@ -1,14 +1,12 @@ -import os import sys import json import random +from pathlib import Path from ast import literal_eval import numpy as np import torch -# ----------------------------------------------------------------------------- - def set_seed(seed): random.seed(seed) @@ -19,15 +17,16 @@ def set_seed(seed): def setup_logging(config): """monotonous bookkeeping""" - work_dir = config.system.work_dir + work_dir = Path(config.system.work_dir) + # create the work directory if it doesn't already exist - os.makedirs(work_dir, exist_ok=True) + work_dir.mkdir(parents=True, exist_ok=True) + # log the args (if any) - with open(os.path.join(work_dir, "args.txt"), "w") as f: - f.write(" ".join(sys.argv)) + (work_dir / "args.txt").write_text(" ".join(sys.argv)) + # log the config itself - with open(os.path.join(work_dir, "config.json"), "w") as f: - f.write(json.dumps(config.to_dict(), indent=4)) + (work_dir / "config.json").write_text(json.dumps(config.to_dict(), indent=4)) class CfgNode: @@ -76,7 +75,6 @@ def merge_from_args(self, args): --model.n_layer=10 --trainer.batch_size=32 """ for arg in args: - keyval = arg.split("=") assert len(keyval) == 2, ( "expecting each override arg to be of form --arg=value, got %s" % arg diff --git a/projects/adder/adder.py b/projects/adder/adder.py index 90d9dd1e..816d0063 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -5,6 +5,7 @@ import os import sys import json +from pathlib import Path import torch from torch.utils.data import Dataset @@ -14,11 +15,8 @@ from mingpt.trainer import Trainer from mingpt.utils import set_seed, setup_logging, CfgNode as CN -# ----------------------------------------------------------------------------- - def get_config(): - C = CN() # system @@ -208,6 +206,7 @@ def eval_split(trainer, split, max_batches=None): # iteration callback top_score = 0 + ckpt_path = Path(config.system.work_dir) / "model.pt" def batch_end_callback(trainer): global top_score @@ -233,7 +232,6 @@ def batch_end_callback(trainer): if score > top_score: top_score = score print(f"saving model with new top score of {score}") - ckpt_path = os.path.join(config.system.work_dir, "model.pt") torch.save(model.state_dict(), ckpt_path) # revert model to training mode model.train() From 8e111e74da0dc3efc41cc24394b48534c609099c Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:27:13 -0700 Subject: [PATCH 05/12] Allow callbacks to store state in trainer Using globals to store data between calls? What?! The proper way is for callbacks to have access to an internal aribtrary state object for their own accounting purposes. this does thusly. --- mingpt/trainer.py | 5 ++++- projects/adder/adder.py | 7 +++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mingpt/trainer.py b/mingpt/trainer.py index da0ca9a1..de8c5471 100644 --- a/mingpt/trainer.py +++ b/mingpt/trainer.py @@ -29,12 +29,15 @@ def get_default_config(): C.grad_norm_clip = 1.0 return C - def __init__(self, config, model, train_dataset): + def __init__(self, config, model, train_dataset, state=None) -> None: self.config = config self.model = model self.train_dataset = train_dataset self.callbacks = defaultdict(list) + # any state used for callback reporting across batches + self.state = state or {} + # determine the device we'll train on if config.device == "auto": self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/projects/adder/adder.py b/projects/adder/adder.py index 816d0063..aa4f74d8 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -204,12 +204,11 @@ def eval_split(trainer, split, max_batches=None): ) return rt.sum() - # iteration callback - top_score = 0 ckpt_path = Path(config.system.work_dir) / "model.pt" def batch_end_callback(trainer): - global top_score + history = trainer.state + top_score = history.get("top_score", 0) if trainer.iter_num % 10 == 0: print( @@ -230,8 +229,8 @@ def batch_end_callback(trainer): score = train_score + test_score # save the model if this is the best score we've seen so far if score > top_score: - top_score = score print(f"saving model with new top score of {score}") + history["top_score"] = score torch.save(model.state_dict(), ckpt_path) # revert model to training mode model.train() From 8e9d737cc8b18e39fa2e35d089dfbc0c8bd1a0b9 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:29:29 -0700 Subject: [PATCH 06/12] Use explicit digit counts in adder Small minor non-important fix, but it makes the intent more clear and shows people there's more python built-ins to explore. --- projects/adder/adder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/adder/adder.py b/projects/adder/adder.py index aa4f74d8..d4552894 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -5,6 +5,7 @@ import os import sys import json +import string from pathlib import Path import torch @@ -95,7 +96,7 @@ def __init__(self, config, split): self.ixes = perm[:num_test] if split == "test" else perm[num_test:] def get_vocab_size(self): - return 10 # digits 0..9 + return len(string.digits) def get_block_size(self): # a,b,a+b, and +1 due to potential carry overflow, From 8f05460427fb6eb5eacc40f910ecf5e1c7a187d8 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:30:48 -0700 Subject: [PATCH 07/12] Use modern python string formatting in adder The new syntax makes formatting intent much clearer and easier to understand than legacy python string formatting syntax. The new syntax only requires knowing like 2 concepts while the old formatting syntax requires knowing 8 different things about how strings and formatting and replacements and substitutions happen in legacy python strings. --- projects/adder/adder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/adder/adder.py b/projects/adder/adder.py index d4552894..1af96a26 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -117,9 +117,9 @@ def __getitem__(self, idx): # calculate the "label" of the addition problem a + b c = a + b # encode the digits of a, b, c into strings - astr = f"%0{ndigit}d" % a - bstr = f"%0{ndigit}d" % b - cstr = (f"%0{ndigit+1}d" % c)[::-1] # reverse c to make addition easier + astr = f"{a:0{ndigit}d}" + bstr = f"{b:0{ndigit}d}" + cstr = (f"{c:0{ndigit + 1}d}")[::-1] # reverse c to make addition easier render = astr + bstr + cstr dix = [int(s) for s in render] # convert each character to its token index # x will be input to GPT and y will be the associated expected outputs From 4d4ad74956bebb61c5cc130ba25fc363291ad5cd Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:31:49 -0700 Subject: [PATCH 08/12] Improve readability in more places Hints for improving readability: - don't smash 10+ lines together without line breaks - if you have comments for a block, make it a visual block with newlines instead of: #comment\ncode\n#comment\ncode\ncode\n#comment\ncode\n.... - don't create any globals or code under __main__, just call a helper function. this also helps when wanting to turn python entrypoints into single callable scripts with paraemters since a function entry point can be called easier (a la Fire, etc) - avoid end-of-line comments where possible because it makes lines too long then the code formatter uses worse formatting. prefer "# big comment\ncode" instead of "code # big comment" --- mingpt/trainer.py | 6 ++- projects/adder/adder.py | 86 +++++++++++++++++++++++++---------------- 2 files changed, 57 insertions(+), 35 deletions(-) diff --git a/mingpt/trainer.py b/mingpt/trainer.py index de8c5471..7c047bd5 100644 --- a/mingpt/trainer.py +++ b/mingpt/trainer.py @@ -16,10 +16,13 @@ class Trainer: @staticmethod def get_default_config(): C = CN() + # device to train on C.device = "auto" + # dataloder parameters C.num_workers = 4 + # optimizer parameters C.max_iters = None C.batch_size = 64 @@ -43,8 +46,9 @@ def __init__(self, config, model, train_dataset, state=None) -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = config.device + self.model = self.model.to(self.device) - print("running on device", self.device) + print("running on device:", self.device) # variables that will be assigned to trainer class later for logging and etc self.iter_num: int = 0 diff --git a/projects/adder/adder.py b/projects/adder/adder.py index 1af96a26..54918951 100644 --- a/projects/adder/adder.py +++ b/projects/adder/adder.py @@ -34,9 +34,9 @@ def get_config(): # trainer C.trainer = Trainer.get_default_config() - C.trainer.learning_rate = ( - 5e-4 # the model we're using is so small that we can go a bit faster - ) + + # the model we're using is so small that we can go a bit faster + C.trainer.learning_rate = 5e-4 return C @@ -75,7 +75,7 @@ def get_default_config(): C.ndigit = 2 return C - def __init__(self, config, split): + def __init__(self, config, split) -> None: self.config = config self.split = split # train/test @@ -84,15 +84,16 @@ def __init__(self, config, split): assert ( ndigit <= 3 ), "the lines below would be very memory inefficient, in future maybe refactor to support" - num = ( - 10**ndigit - ) ** 2 # total number of possible addition problems with ndigit numbers + + # total number of possible addition problems with ndigit numbers + num = (10**ndigit) ** 2 rng = torch.Generator() rng.manual_seed(1337) perm = torch.randperm(num, generator=rng) - num_test = min( - int(num * 0.2), 500 - ) # 20% of the whole dataset, or only up to 500 + + # 20% of the whole dataset, or only up to 500 + num_test = min(int(num * 0.2), 500) + self.ixes = perm[:num_test] if split == "test" else perm[num_test:] def get_vocab_size(self): @@ -114,33 +115,34 @@ def __getitem__(self, idx): nd = 10**ndigit a = idx // nd b = idx % nd + # calculate the "label" of the addition problem a + b c = a + b + # encode the digits of a, b, c into strings astr = f"{a:0{ndigit}d}" bstr = f"{b:0{ndigit}d}" cstr = (f"{c:0{ndigit + 1}d}")[::-1] # reverse c to make addition easier render = astr + bstr + cstr dix = [int(s) for s in render] # convert each character to its token index + # x will be input to GPT and y will be the associated expected outputs x = torch.tensor(dix[:-1], dtype=torch.long) - y = torch.tensor( - dix[1:], dtype=torch.long - ) # predict the next token in the sequence - y[ - : ndigit * 2 - 1 - ] = -1 # we will only train in the output locations. -1 will mask loss to zero - return x, y + # predict the next token in the sequence + y = torch.tensor(dix[1:], dtype=torch.long) -# ----------------------------------------------------------------------------- + # we will only train in the output locations. -1 will mask loss to zero + y[: ndigit * 2 - 1] = -1 + return x, y -if __name__ == "__main__": +def cmd() -> None: # get default config and overrides from the command line, if any config = get_config() config.merge_from_args(sys.argv[1:]) print(config) + setup_logging(config) set_seed(config.system.seed) @@ -165,44 +167,53 @@ def eval_split(trainer, split, max_batches=None): factors = torch.tensor([[10**i for i in range(ndigit + 1)][::-1]]).to( trainer.device ) + loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False) for b, (x, y) in enumerate(loader): x = x.to(trainer.device) + # isolate the first two digits of the input sequence alone d1d2 = x[:, : ndigit * 2] + # let the model sample the rest of the sequence - d1d2d3 = model.generate( - d1d2, ndigit + 1, do_sample=False - ) # using greedy argmax, not sampling + # using greedy argmax, not sampling + d1d2d3 = model.generate(d1d2, ndigit + 1, do_sample=False) + # isolate the last digit of the sampled sequence d3 = d1d2d3[:, -(ndigit + 1) :] d3 = d3.flip(1) # reverse the digits to their "normal" order + # decode the integers from individual digits d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1) d2i = (d1d2[:, ndigit : ndigit * 2] * factors[:, 1:]).sum(1) d3i_pred = (d3 * factors).sum(1) d3i_gt = d1i + d2i # manually calculate the ground truth + # evaluate the correctness of the results in this batch - correct = ( - d3i_pred == d3i_gt - ).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha + # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha + correct = (d3i_pred == d3i_gt).cpu() + for i in range(x.size(0)): results.append(int(correct[i])) - if ( - not correct[i] and mistakes_printed_already < 5 - ): # only print up to 5 mistakes to get a sense + + # only print up to 5 mistakes to get a sense + if not correct[i] and mistakes_printed_already < 5: mistakes_printed_already += 1 print( - "GPT claims that %d + %d = %d but gt is %d" + "GPT claims %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]) ) + if max_batches is not None and b + 1 >= max_batches: break + rt = torch.tensor(results, dtype=torch.float) + print( "%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100 * rt.mean()) ) + return rt.sum() ckpt_path = Path(config.system.work_dir) / "model.pt" @@ -213,26 +224,29 @@ def batch_end_callback(trainer): if trainer.iter_num % 10 == 0: print( - f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" + f"iter_dt {trainer.iter_dt * 1000:.2f} ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" ) if trainer.iter_num % 500 == 0: # evaluate both the train and test score - train_max_batches = {1: None, 2: None, 3: 5}[ - config.data.ndigit - ] # if ndigit=2 we can afford the whole train set, ow no + # if ndigit <= 2 we can afford the whole train set, else limit + train_max_batches = 5 if config.data.ndigit > 2 else None model.eval() + with torch.no_grad(): train_score = eval_split( trainer, "train", max_batches=train_max_batches ) test_score = eval_split(trainer, "test", max_batches=None) + score = train_score + test_score + # save the model if this is the best score we've seen so far if score > top_score: - print(f"saving model with new top score of {score}") history["top_score"] = score + print(f"saving model with new top score: {score}") torch.save(model.state_dict(), ckpt_path) + # revert model to training mode model.train() @@ -240,3 +254,7 @@ def batch_end_callback(trainer): # run the optimization trainer.run() + + +if __name__ == "__main__": + cmd() From 2f89bbb8406c6e5b94cbcc003436c669dd636cb4 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:44:23 -0700 Subject: [PATCH 09/12] Use proper README.md file name conventions --- projects/{readme.md => README.md} | 0 projects/adder/{readme.md => README.md} | 0 projects/chargpt/{readme.md => README.md} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename projects/{readme.md => README.md} (100%) rename projects/adder/{readme.md => README.md} (100%) rename projects/chargpt/{readme.md => README.md} (100%) diff --git a/projects/readme.md b/projects/README.md similarity index 100% rename from projects/readme.md rename to projects/README.md diff --git a/projects/adder/readme.md b/projects/adder/README.md similarity index 100% rename from projects/adder/readme.md rename to projects/adder/README.md diff --git a/projects/chargpt/readme.md b/projects/chargpt/README.md similarity index 100% rename from projects/chargpt/readme.md rename to projects/chargpt/README.md From eed8054132d0dd3ed0f21c8f3a655862cef1f869 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 12:51:32 -0700 Subject: [PATCH 10/12] Add poetry environment for package Also added 'adder' helper runner (see below) Enables: - automated package buliding: poetry build - automated command running - example: poetry run adder - all the other modern package management things like dependency verification, isolated install environments, etc --- .gitignore | 2 ++ pyproject.toml | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 8ec38c0f..ee36fc20 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ __pycache__/ .env .pylintrc out/ +poetry.lock +dist/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..060afec1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,20 @@ +[tool.poetry] +name = "mingpt" +version = "3.3.3" +description = "A PyTorch re-implementation of GPT, both training and inference. minGPT tries to be small, clean, interpretable and educational." +authors = [] +license = "Apache-2.0" + +[tool.poetry.dependencies] +python = "^3.8" +torch = "^1.12.0" + +[tool.poetry.dev-dependencies] +transformers = "^4.20.1" + +[tool.poetry.scripts] +adder = "projects.adder.adder:cmd" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" From bf44c172b03bfea1a1d98ded6a60eca1f90828dd Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 13:02:11 -0700 Subject: [PATCH 11/12] Cleanup chargpt and add it to poetry launcher poetry run chargpt (if you have 'input.txt' in your local directory) Also includes the basic Path/style/formatting/cmd refactors made to adder. --- .gitignore | 1 + projects/chargpt/chargpt.py | 35 +++++++++++++++++++++-------------- pyproject.toml | 1 + 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index ee36fc20..bcfeedc1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__/ out/ poetry.lock dist/ +input.txt diff --git a/projects/chargpt/chargpt.py b/projects/chargpt/chargpt.py index 29dbf01d..eb6de687 100644 --- a/projects/chargpt/chargpt.py +++ b/projects/chargpt/chargpt.py @@ -4,6 +4,7 @@ import os import sys +from pathlib import Path import torch from torch.utils.data import Dataset @@ -13,11 +14,8 @@ from mingpt.trainer import Trainer from mingpt.utils import set_seed, setup_logging, CfgNode as CN -# ----------------------------------------------------------------------------- - def get_config(): - C = CN() # system @@ -34,9 +32,9 @@ def get_config(): # trainer C.trainer = Trainer.get_default_config() - C.trainer.learning_rate = ( - 5e-4 # the model we're using is so small that we can go a bit faster - ) + + # the model we're using is so small that we can go a bit faster + C.trainer.learning_rate = 5e-4 return C @@ -87,10 +85,7 @@ def __getitem__(self, idx): return x, y -# ----------------------------------------------------------------------------- - -if __name__ == "__main__": - +def cmd(): # get default config and overrides from the command line, if any config = get_config() config.merge_from_args(sys.argv[1:]) @@ -99,7 +94,7 @@ def __getitem__(self, idx): set_seed(config.system.seed) # construct the training dataset - text = open("input.txt", "r").read() # don't worry we won't run out of file handles + text = Path("input.txt").read_text() train_dataset = CharDataset(config.data, text) # construct the model @@ -109,13 +104,13 @@ def __getitem__(self, idx): # construct the trainer object trainer = Trainer(config.trainer, model, train_dataset) + ckpt_path = Path(config.system.work_dir) / "model.pt" # iteration callback def batch_end_callback(trainer): - if trainer.iter_num % 10 == 0: print( - f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" + f"iter_dt {trainer.iter_dt * 1000:.2f} ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}" ) if trainer.iter_num % 500 == 0: @@ -124,16 +119,24 @@ def batch_end_callback(trainer): with torch.no_grad(): # sample from the model... context = "O God, O God!" + + # Note: your input.txt must have all characters in 'context' for this to work + # because this finds the index of each character in 'context' inside + # input.txt, so if input.txt doesn't have, say, capital O or ! then + # you'll get a KeyError here. x = torch.tensor( [train_dataset.stoi[s] for s in context], dtype=torch.long )[None, ...].to(trainer.device) + y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0] + completion = "".join([train_dataset.itos[int(i)] for i in y]) print(completion) + # save the latest model print("saving model") - ckpt_path = os.path.join(config.system.work_dir, "model.pt") torch.save(model.state_dict(), ckpt_path) + # revert model to training mode model.train() @@ -141,3 +144,7 @@ def batch_end_callback(trainer): # run the optimization trainer.run() + + +if __name__ == "__main__": + cmd() diff --git a/pyproject.toml b/pyproject.toml index 060afec1..dbe26c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ transformers = "^4.20.1" [tool.poetry.scripts] adder = "projects.adder.adder:cmd" +chargpt = "projects.chargpt.chargpt:cmd" [build-system] requires = ["poetry-core>=1.0.0"] From e87c30e538c3dddd2f7d592766519428d4e667c0 Mon Sep 17 00:00:00 2001 From: Matt Stancliff Date: Sun, 24 Jul 2022 13:19:34 -0700 Subject: [PATCH 12/12] Refactor BPE and add to poetry runner Usable via: poetry run bpe Fixed same issues as cleaning up everything else: - pathlib.Path everywhere instead of direct os module file manipulation - improves code readability (the entire purpose of code, after all) - improves output readability too with more logical breaks for reading/understanding what's actually happening --- mingpt/bpe.py | 94 +++++++++++++++++++++++++++++++++++++------------- pyproject.toml | 1 + 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/mingpt/bpe.py b/mingpt/bpe.py index a0c92a66..09b625a0 100644 --- a/mingpt/bpe.py +++ b/mingpt/bpe.py @@ -8,15 +8,13 @@ going on. """ -import os import json import regex as re import requests +from pathlib import Path import torch -# ----------------------------------------------------------------------------- - def bytes_to_unicode(): """ @@ -40,6 +38,7 @@ def bytes_to_unicode(): + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict + # now get the representations of the other 68 integers that do need shifting # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop n = 0 @@ -51,6 +50,7 @@ def bytes_to_unicode(): n += 1 cs = [chr(n) for n in cs] d = dict(zip(bs, cs)) + return d @@ -63,6 +63,7 @@ def get_pairs(word): for char in word[1:]: pairs.add((prev_char, char)) prev_char = char + return pairs @@ -71,11 +72,14 @@ def __init__(self, encoder, bpe_merges): # byte encoder/decoder self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + # bpe token encoder/decoder self.encoder = encoder self.decoder = {v: k for k, v in self.encoder.items()} + # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + # the splitting pattern used for pre-tokenization # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment """ @@ -119,11 +123,12 @@ def bpe(self, token): return token while True: - # find the next lowest rank bigram that can be merged bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) if bigram not in self.bpe_ranks: - break # no more bigrams are eligible to be merged + # no more bigrams are eligible to be merged + break + first, second = bigram # we will now replace all occurences of (first, second) in the list of current @@ -154,8 +159,8 @@ def bpe(self, token): word = new_word if len(word) == 1: break - else: - pairs = get_pairs(word) + + pairs = get_pairs(word) # concat all words into a string, and use ' ' as the separator. Note that # by now all characters have been byte encoded, guaranteeing that ' ' is @@ -171,18 +176,24 @@ def encode(self, text): bpe_idx = [] # pre-tokenize the input text into string tokens (words, roughly speaking) tokens = re.findall(self.pat, text) + # process each token into BPE integers for token in tokens: # encode the token as a bytes (b'') object token_bytes = token.encode("utf-8") + # translate all bytes to their unicode string representation and flatten token_translated = "".join(self.byte_encoder[b] for b in token_bytes) + # perform all the applicable bpe merges according to self.bpe_ranks token_merged = self.bpe(token_translated).split(" ") + # translate all bpe tokens to integers token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] + # extend our running list of all output integers bpe_idx.extend(token_ix) + return bpe_idx def encode_and_show_work(self, text): @@ -216,20 +227,23 @@ def decode(self, bpe_idx): """list of integers comes in, string comes out""" # inverse map the integers to get the tokens tokens_merged = [self.decoder[token] for token in bpe_idx] + # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes tokens_flat = "".join(tokens_merged) tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) + # recover the full utf-8 string text = tokens_bytes.decode("utf-8", errors="replace") + return text def get_file(local_file, remote_file): """downloads remote_file to local_file if necessary""" - if not os.path.isfile(local_file): + if not Path(local_file).is_file(): print(f"downloading {remote_file} to {local_file}") response = requests.get(remote_file) - open(local_file, "wb").write(response.content) + Path(local_file).write_bytes(response.content) def get_encoder(): @@ -237,31 +251,30 @@ def get_encoder(): Returns an instance of the GPT BPE Encoder/Decoder and handles caching of "database" files. """ - home_dir = os.path.expanduser("~") - cache_dir = os.path.join(home_dir, ".cache", "mingpt") - os.makedirs(cache_dir, exist_ok=True) + cache_dir = Path.home() / ".cache" / "mingpt" + cache_dir.mkdir(parents=True, exist_ok=True) # load encoder.json that has the raw mappings from token -> bpe index - encoder_local_file = os.path.join(cache_dir, "encoder.json") + encoder_local_file = cache_dir / "encoder.json" encoder_remote_file = ( "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json" ) get_file(encoder_local_file, encoder_remote_file) - with open(encoder_local_file, "r") as f: - encoder = json.load(f) - assert ( - len(encoder) == 50257 - ) # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + encoder = json.loads(Path(encoder_local_file).read_text()) + + # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token + assert len(encoder) == 50257 # load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab - vocab_local_file = os.path.join(cache_dir, "vocab.bpe") + vocab_local_file = cache_dir / "vocab.bpe" vocab_remote_file = ( "https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe" ) + get_file(vocab_local_file, vocab_remote_file) - with open(vocab_local_file, "r", encoding="utf-8") as f: - bpe_data = f.read() + bpe_data = vocab_local_file.read_text(encoding="utf-8") + # light postprocessing: strip the version on first line and the last line is a blank bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] assert len(bpe_merges) == 50000 # 50,000 merged tokens @@ -283,37 +296,60 @@ def __init__(self): def __call__(self, text, return_tensors="pt"): # PyTorch only; here because we want to match huggingface/transformers interface assert return_tensors == "pt" + # single string input for now, in the future potentially a list of strings assert isinstance(text, str) + # encode and create a "batch dimension" of 1 idx = [self.encoder.encode(text)] + # wrap into PyTorch tensor out = torch.tensor(idx, dtype=torch.long) + return out def decode(self, idx): # ensure a simple 1D tensor for now assert idx.ndim == 1 + # decode indices to text text = self.encoder.decode(idx.tolist()) + return text -if __name__ == "__main__": +def cmd(): + def brk(): + """Visual breaking of different content sections for easier reading.""" + print() + + def section(): + """Visual separation of a large output block for easier reading.""" + print("============================================================") # here is an encoding example text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗" e = get_encoder() r = e.encode_and_show_work(text) - print("Original text is:") + print("Original text:") print(text) - print("First the text gets pre-tokenized, broken up into chunks, the outcome is:") + + brk() + + print("First the text gets pre-tokenized then broken up into chunks:") print(r["tokens"]) + + brk() + # ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗'] print("Then we iterate over each chunk and process them in turn...") + + section() for part in r["parts"]: print(part) + section() + # {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]} # {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]} # {'token': ' I', 'token_bytes': b' I', 'token_translated': 'ĠI', 'token_merged': ['ĠI'], 'token_ix': [314]} @@ -332,8 +368,18 @@ def decode(self, idx): # {'token': 'D', 'token_bytes': b'D', 'token_translated': 'D', 'token_merged': ['D'], 'token_ix': [35]} # {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'Ġð٤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]} # (refer to the code inside Encoder.encode for what these intermediates are) + + brk() + print("and the final outcome is concatenating and flattening all the token_ix:") print(r["bpe_idx"]) + + brk() + # [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245] # this would then become the integer input sequence to the transformer print("ready to feed into a Transformer!") + + +if __name__ == "__main__": + cmd() diff --git a/pyproject.toml b/pyproject.toml index dbe26c3f..bd43485e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ transformers = "^4.20.1" [tool.poetry.scripts] adder = "projects.adder.adder:cmd" chargpt = "projects.chargpt.chargpt:cmd" +bpe = "mingpt.bpe:cmd" [build-system] requires = ["poetry-core>=1.0.0"]