
Refactor for modern 2022 python style and usage #80

Open · wants to merge 12 commits into master
4 changes: 4 additions & 0 deletions .gitignore
@@ -3,3 +3,7 @@ __pycache__/
*.swp
.env
.pylintrc
out/
poetry.lock
dist/
input.txt
198 changes: 132 additions & 66 deletions mingpt/bpe.py
@@ -8,14 +8,13 @@
going on.
"""

import os
import json
import regex as re
import requests
from pathlib import Path

import torch

# -----------------------------------------------------------------------------

def bytes_to_unicode():
"""
@@ -33,21 +32,28 @@ def bytes_to_unicode():
like 'Ā', or 'Ġ', etc.
"""
# the 188 integers that render fine in their original form and need no shifting
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict

# now get the representations of the other 68 integers that do need shifting
# each will get mapped to chr(256 + n), where n grows from 0...67 in the loop
n = 0
for b in range(2**8):
if b not in bs:
# if this byte is "ugly" then map it to the next available "nice" character
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
d = dict(zip(bs, cs))

return d
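# For example (illustrative): printable bytes map to themselves, while bytes that
# don't render nicely (control characters, space, etc.) get shifted past 255:
#   bytes_to_unicode()[ord("A")] == "A"
#   bytes_to_unicode()[ord(" ")] == "Ġ"  # space is byte 32, the 33rd shifted byte -> chr(256 + 32)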


def get_pairs(word):
"""
Return all bigrams as a set of tuples, of consecutive elements in the iterable word.
@@ -57,19 +63,23 @@ def get_pairs(word):
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char

return pairs
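# For example (illustrative):
#   get_pairs(("h", "e", "l", "l", "o")) == {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}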

class Encoder:

class Encoder:
def __init__(self, encoder, bpe_merges):
# byte encoder/decoder
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

# bpe token encoder/decoder
self.encoder = encoder
self.decoder = {v:k for k,v in self.encoder.items()}
self.decoder = {v: k for k, v in self.encoder.items()}

# bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

# the splitting pattern used for pre-tokenization
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment
"""
@@ -89,7 +99,9 @@ def __init__(self, encoder, bpe_merges):
- we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens
- we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces
"""
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
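# For example (illustrative), this pattern pre-tokenizes
#   "Hello've world 123!!!"
# into the chunks ["Hello", "'ve", " world", " 123", "!!!"]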
self.cache = {}

def bpe(self, token):
@@ -104,18 +116,19 @@ def bpe(self, token):
if token in self.cache:
return self.cache[token]

word = tuple(token) # individual characters that make up the token, in a tuple
pairs = get_pairs(word) # get all bigrams
word = tuple(token) # individual characters that make up the token, in a tuple
pairs = get_pairs(word) # get all bigrams

if not pairs:
return token

while True:

# find the next lowest rank bigram that can be merged
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break # no more bigrams are eligible to be merged
# no more bigrams are eligible to be merged
break

first, second = bigram

# we will now replace all occurrences of (first, second) in the list of current
@@ -134,8 +147,8 @@ def bpe(self, token):
break

# if this occurrence is also followed by second, then merge them into one
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
@@ -146,154 +159,197 @@ def bpe(self, token):
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)

pairs = get_pairs(word)

# concat all words into a string, and use ' ' as the separator. Note that
# by now all characters have been byte encoded, guaranteeing that ' ' is
# not used in the actual data and is a 'special' delimiter character
word = ' '.join(word)
word = " ".join(word)

# cache the result and return
self.cache[token] = word
return word
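# For example (values taken from the demo output at the bottom of this file):
# bpe("Hello") merges all the way down to the single known token "Hello", while the
# byte-encoded emoji "Ġð٤Ĺ" only merges partially and comes back as "ĠðŁ ¤ Ĺ".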

def encode(self, text):
""" string goes in, list of integers comes out """
"""string goes in, list of integers comes out"""
bpe_idx = []
# pre-tokenize the input text into string tokens (words, roughly speaking)
tokens = re.findall(self.pat, text)

# process each token into BPE integers
for token in tokens:
# encode the token as a bytes (b'') object
token_bytes = token.encode('utf-8')
token_bytes = token.encode("utf-8")

# translate all bytes to their unicode string representation and flatten
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
token_translated = "".join(self.byte_encoder[b] for b in token_bytes)

# perform all the applicable bpe merges according to self.bpe_ranks
token_merged = self.bpe(token_translated).split(' ')
token_merged = self.bpe(token_translated).split(" ")

# translate all bpe tokens to integers
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]

# extend our running list of all output integers
bpe_idx.extend(token_ix)

return bpe_idx
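# For example (values taken from the demo output at the bottom of this file):
#   encode("Hello!!") == [15496, 3228]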

def encode_and_show_work(self, text):
""" debugging function, same as encode but returns all intermediate work """
"""debugging function, same as encode but returns all intermediate work"""
bpe_idx = []
parts = []
tokens = re.findall(self.pat, text)
for token in tokens:
token_bytes = token.encode('utf-8')
token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
token_merged = self.bpe(token_translated).split(' ')
token_bytes = token.encode("utf-8")
token_translated = "".join(self.byte_encoder[b] for b in token_bytes)
token_merged = self.bpe(token_translated).split(" ")
token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
bpe_idx.extend(token_ix)
parts.append({
'token': token,
'token_bytes': token_bytes,
'token_translated': token_translated,
'token_merged': token_merged,
'token_ix': token_ix,
})
parts.append(
{
"token": token,
"token_bytes": token_bytes,
"token_translated": token_translated,
"token_merged": token_merged,
"token_ix": token_ix,
}
)
out = {
'bpe_idx': bpe_idx, # the actual output sequence
'tokens': tokens, # result of pre-tokenization
'parts': parts, # intermediates for each token part
"bpe_idx": bpe_idx, # the actual output sequence
"tokens": tokens, # result of pre-tokenization
"parts": parts, # intermediates for each token part
}
return out

def decode(self, bpe_idx):
""" list of integers comes in, string comes out """
"""list of integers comes in, string comes out"""
# inverse map the integers to get the tokens
tokens_merged = [self.decoder[token] for token in bpe_idx]

# inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes
tokens_flat = ''.join(tokens_merged)
tokens_flat = "".join(tokens_merged)
tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat])

# recover the full utf-8 string
text = tokens_bytes.decode('utf-8', errors='replace')
text = tokens_bytes.decode("utf-8", errors="replace")

return text
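# For example, inverting the encode example above (illustrative):
#   decode([15496, 3228]) == "Hello!!"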


def get_file(local_file, remote_file):
""" downloads remote_file to local_file if necessary """
if not os.path.isfile(local_file):
"""downloads remote_file to local_file if necessary"""
if not Path(local_file).is_file():
print(f"downloading {remote_file} to {local_file}")
response = requests.get(remote_file)
open(local_file, "wb").write(response.content)
Path(local_file).write_bytes(response.content)


def get_encoder():
"""
Returns an instance of the GPT BPE Encoder/Decoder
and handles caching of "database" files.
"""
home_dir = os.path.expanduser('~')
cache_dir = os.path.join(home_dir, '.cache', 'mingpt')
os.makedirs(cache_dir, exist_ok=True)
cache_dir = Path.home() / ".cache" / "mingpt"
cache_dir.mkdir(parents=True, exist_ok=True)

# load encoder.json that has the raw mappings from token -> bpe index
encoder_local_file = os.path.join(cache_dir, 'encoder.json')
encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json'
encoder_local_file = cache_dir / "encoder.json"
encoder_remote_file = (
"https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json"
)
get_file(encoder_local_file, encoder_remote_file)
with open(encoder_local_file, 'r') as f:
encoder = json.load(f)
assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token
encoder = json.loads(Path(encoder_local_file).read_text())

# 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token
assert len(encoder) == 50257

# load vocab.bpe that contains the bpe merges, i.e. the bpe tree structure
# in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab
vocab_local_file = os.path.join(cache_dir, 'vocab.bpe')
vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe'
vocab_local_file = cache_dir / "vocab.bpe"
vocab_remote_file = (
"https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe"
)

get_file(vocab_local_file, vocab_remote_file)
with open(vocab_local_file, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_data = vocab_local_file.read_text(encoding="utf-8")

# light postprocessing: the first line is a version header and the last line is blank, so strip both
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
assert len(bpe_merges) == 50000 # 50,000 merged tokens
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
assert len(bpe_merges) == 50000 # 50,000 merged tokens

# construct the Encoder object and return
enc = Encoder(encoder, bpe_merges)
return enc
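# For example (illustrative), encoding and decoding round-trips ordinary text:
#   e = get_encoder()
#   e.decode(e.encode("hello world")) == "hello world"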


# -----------------------------------------------------------------------------


class BPETokenizer:
""" PyTorch-aware class that wraps the Encoder above """
"""PyTorch-aware class that wraps the Encoder above"""

def __init__(self):
self.encoder = get_encoder()

def __call__(self, text, return_tensors='pt'):
def __call__(self, text, return_tensors="pt"):
# PyTorch only; here because we want to match huggingface/transformers interface
assert return_tensors == 'pt'
assert return_tensors == "pt"

# single string input for now, in the future potentially a list of strings
assert isinstance(text, str)

# encode and create a "batch dimension" of 1
idx = [self.encoder.encode(text)]

# wrap into PyTorch tensor
out = torch.tensor(idx, dtype=torch.long)

return out

def decode(self, idx):
# ensure a simple 1D tensor for now
assert idx.ndim == 1

# decode indices to text
text = self.encoder.decode(idx.tolist())

return text
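# Example usage (a minimal sketch; token ids taken from the demo output below):
#   tokenizer = BPETokenizer()
#   idx = tokenizer("Hello!!")       # tensor([[15496, 3228]]), shape (1, 2)
#   text = tokenizer.decode(idx[0])  # "Hello!!"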


if __name__ == '__main__':
def cmd():
def brk():
"""Visual breaking of different content sections for easier reading."""
print()

def section():
"""Visual separation of a large output block for easier reading."""
print("============================================================")

# here is an encoding example
text = "Hello!! I'm Andrej Karpathy. It's 2022. w00t :D 🤗"
e = get_encoder()
r = e.encode_and_show_work(text)

print("Original text is:")
print("Original text:")
print(text)
print("First the text gets pre-tokenized, broken up into chunks, the outcome is:")
print(r['tokens'])

brk()

print("First the text gets pre-tokenized then broken up into chunks:")
print(r["tokens"])

brk()

# ['Hello', '!!', ' I', "'m", ' Andrej', ' Karpathy', '.', ' It', "'s", ' 2022', '.', ' w', '00', 't', ' :', 'D', ' 🤗']
print("Then we iterate over each chunk and process them in turn...")
for part in r['parts']:

section()
for part in r["parts"]:
print(part)
section()

# {'token': 'Hello', 'token_bytes': b'Hello', 'token_translated': 'Hello', 'token_merged': ['Hello'], 'token_ix': [15496]}
# {'token': '!!', 'token_bytes': b'!!', 'token_translated': '!!', 'token_merged': ['!!'], 'token_ix': [3228]}
# {'token': ' I', 'token_bytes': b' I', 'token_translated': 'ĠI', 'token_merged': ['ĠI'], 'token_ix': [314]}
@@ -312,8 +368,18 @@ def decode(self, idx):
# {'token': 'D', 'token_bytes': b'D', 'token_translated': 'D', 'token_merged': ['D'], 'token_ix': [35]}
# {'token': ' 🤗', 'token_bytes': b' \xf0\x9f\xa4\x97', 'token_translated': 'Ġð٤Ĺ', 'token_merged': ['ĠðŁ', '¤', 'Ĺ'], 'token_ix': [12520, 97, 245]}
# (refer to the code inside Encoder.encode for what these intermediates are)

brk()

print("and the final outcome is concatenating and flattening all the token_ix:")
print(r['bpe_idx'])
print(r["bpe_idx"])

brk()

# [15496, 3228, 314, 1101, 10948, 73, 509, 5117, 10036, 13, 632, 338, 33160, 13, 266, 405, 83, 1058, 35, 12520, 97, 245]
# this would then become the integer input sequence to the transformer
print("ready to feed into a Transformer!")


if __name__ == "__main__":
cmd()