Skip to content

Commit 7860084

Browse files
committed
refactor speechLM
1 parent a6a2df8 commit 7860084

8 files changed

Lines changed: 94 additions & 180 deletions

File tree

README.md

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -62,11 +62,8 @@ audio_values = decoder(units)
6262
import torch
6363
import torchaudio
6464
from textless.data.speech_encoder import SpeechEncoder
65-
from tokenizers import Tokenizer
6665
from transformers import LlamaForCausalLM
6766

68-
from src.speechlm.utils import convert_units_to_unicode
69-
7067
wav_path = "/path/to/wav"
7168

7269
encoder = SpeechEncoder.by_name(
@@ -77,21 +74,14 @@ encoder = SpeechEncoder.by_name(
7774
need_f0=False,
7875
).cuda()
7976

80-
# BPE tokenizer
81-
tokenizer = Tokenizer.from_file("/path/to/pretrained/tokenizer.json")
82-
8377
model = LlamaForCausalLM.from_pretrained("/path/to/pretrained/model").cuda()
8478

8579
# load a waveform
8680
waveform, sr = torchaudio.load(wav_path)
8781
waveform = torchaudio.functional.resample(waveform, sr, 16000)
8882

8983
# encode a waveform into pseudo-phonetic units
90-
units = encoder(waveform.cuda())["units"].tolist()
91-
unicodes = convert_units_to_unicode(units)
92-
93-
# BPE
94-
input_ids = tokenizer.encode(unicodes).ids
84+
input_ids = encoder(waveform.cuda())["units"].tolist()
9585
input_ids = torch.tensor([input_ids], device="cuda") + 2 # 0: pad, 1: EOS
9686

9787
# Speech LM

configs/speechlm/hubert.yaml

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,11 @@ dataset:
22
wav_dir_train: "data/librilight"
33
ext_audio: ".flac"
44

5-
unicode_train: "data/speechlm/hubert/unicode/train"
6-
train_file: "data/speechlm/hubert/unit/train.txt"
5+
train: "ryota-komatsu/librilight"
76
units_per_sample: 125
87

9-
swuggy_dev_file: "data/speechlm/hubert/unit/lexical/dev.json"
10-
sblimp_dev_file: "data/speechlm/hubert/unit/syntactic/dev.json"
11-
swuggy_test_file: "data/speechlm/hubert/unit/lexical/test.json"
12-
sblimp_test_file: "data/speechlm/hubert/unit/syntactic/test.json"
8+
swuggy: "ryota-komatsu/swuggy" # lexical
9+
sblimp: "ryota-komatsu/sblimp" # syntactic
1310

1411
APP_DIR: "data/zr-data"
1512
result_dir: "results/speechlm/hubert"
@@ -19,7 +16,7 @@ dataloader:
1916

2017
model:
2118
path: "models/speechlm/hubert"
22-
vocab_size: 8192 # BPE vocab size
19+
vocab_size: ${s2u.vocab_size}
2320
hidden_size: 768
2421
intermediate_size: 2048 # 4 * hidden_size * 2 / 3
2522
num_hidden_layers: 12
@@ -45,7 +42,4 @@ s2u:
4542
dense_model_name: "hubert-base-ls960"
4643
quantizer_model_name: "kmeans"
4744
vocab_size: 100
48-
49-
tokenizer_path: "models/speechlm/hubert/tokenizer.json"
50-
5145
num_workers: 16

main_speechlm.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,14 @@
22
from omegaconf import OmegaConf
33

44
from src.speechlm.eval import evaluate
5-
from src.speechlm.tokenize import encode, tokenize, tokenize_slm21
5+
from src.speechlm.tokenize import tokenize_slm21, tokenize_trainset
66
from src.speechlm.train import train
77

88

99
class TaskRunner:
10-
def encode(self, config: str = "configs/speechlm/hubert.yaml", spkids: str = "1-9"):
10+
def tokenize_trainset(self, config: str = "configs/speechlm/hubert.yaml"):
1111
config = OmegaConf.load(config)
12-
encode(config, spkids)
13-
14-
def tokenize(self, config: str = "configs/speechlm/hubert.yaml"):
15-
config = OmegaConf.load(config)
16-
tokenize(config)
12+
tokenize_trainset(config)
1713

1814
def tokenize_slm21(self, config: str = "configs/speechlm/hubert.yaml"):
1915
config = OmegaConf.load(config)
@@ -29,8 +25,7 @@ def eval(self, config: str = "configs/speechlm/hubert.yaml"):
2925

3026
def __call__(self, config: str = "configs/speechlm/hubert.yaml", spkids: str = "1-9"):
3127
config = OmegaConf.load(config)
32-
encode(config, spkids)
33-
tokenize(config)
28+
tokenize_trainset(config, spkids)
3429
tokenize_slm21(config)
3530
train(config)
3631

src/speechlm/data.py

Lines changed: 24 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import random
2-
from typing import Any, Dict
2+
from typing import Any, Dict, Optional
33

44
import torch
55
import torchaudio
@@ -40,45 +40,37 @@ def collate_fn(batch):
4040
}
4141

4242

43-
class UnitDataset(torch.utils.data.Dataset):
44-
def __init__(
45-
self,
46-
files,
47-
units_per_sample: int,
48-
num_special_tokens: int = 2,
49-
eos_token_id: int = 1,
50-
):
51-
self.input_ids = []
52-
for file in files:
53-
with open(file) as f:
54-
for units in f:
55-
units = units.rstrip().split()
56-
units = torch.tensor([int(u) + num_special_tokens for u in units] + [eos_token_id])
57-
self.input_ids.append(units)
58-
59-
self.units_per_sample = units_per_sample
43+
def get_collate_fn(
44+
num_special_tokens: int = 2,
45+
pad_token_id: int = 0,
46+
units_per_sample: Optional[int] = None,
47+
):
48+
def collate_fn(batch) -> Dict[str, torch.LongTensor]:
49+
input_ids = []
50+
names = []
6051

61-
def __len__(self) -> int:
62-
return len(self.input_ids)
52+
for item in batch:
53+
units = torch.tensor(item["units"]) + num_special_tokens
6354

64-
def __getitem__(self, n: int) -> Dict[str, torch.Tensor]:
65-
input_ids = self.input_ids[n]
66-
attention_mask = torch.ones_like(input_ids)
55+
if units_per_sample:
56+
diff = len(units) - units_per_sample
6757

68-
diff = len(input_ids) - self.units_per_sample
58+
if diff > 0:
59+
start = random.randrange(diff)
60+
units = units[start : start + units_per_sample]
6961

70-
if diff > 0:
71-
start = random.randrange(diff)
72-
input_ids = input_ids[start : start + self.units_per_sample]
73-
attention_mask = attention_mask[start : start + self.units_per_sample]
74-
else:
75-
input_ids = torch.nn.functional.pad(input_ids, (0, -diff))
76-
attention_mask = torch.nn.functional.pad(attention_mask, (0, -diff))
62+
input_ids.append(units)
63+
names.append(item["id"])
7764

78-
labels = input_ids.masked_fill(input_ids.eq(0), -100)
65+
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
66+
attention_mask = input_ids.ne(pad_token_id).long()
67+
labels = input_ids.masked_fill(input_ids.eq(pad_token_id), -100)
7968

8069
return {
8170
"input_ids": input_ids,
8271
"attention_mask": attention_mask,
8372
"labels": labels,
73+
"names": names,
8474
}
75+
76+
return collate_fn

src/speechlm/eval.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,10 @@
44
import pandas as pd
55
import torch
66
import torch.nn.functional as F
7+
from datasets import load_dataset
78
from transformers import LlamaForCausalLM
89

9-
from .utils import load_named_units_from_json
10+
from .data import get_collate_fn
1011

1112

1213
def evaluate(config):
@@ -22,17 +23,21 @@ def evaluate(config):
2223

2324
_eval(
2425
model,
25-
config.dataset.swuggy_test_file,
26+
config.dataset.swuggy,
27+
"test",
2628
Path(config.dataset.result_dir) / "lexical/test.txt",
2729
config.dataloader.batch_size_per_device,
2830
num_special_tokens,
31+
config.model.pad_token_id,
2932
)
3033
_eval(
3134
model,
32-
config.dataset.sblimp_test_file,
35+
config.dataset.sblimp,
36+
"test",
3337
Path(config.dataset.result_dir) / "syntactic/test.txt",
3438
config.dataloader.batch_size_per_device,
3539
num_special_tokens,
40+
config.model.pad_token_id,
3641
)
3742

3843
subprocess.run(
@@ -71,15 +76,24 @@ def evaluate(config):
7176
def _eval(
7277
model: LlamaForCausalLM,
7378
in_file,
79+
split: str,
7480
out_file,
7581
batch_size: int,
7682
num_special_tokens: int = 2,
83+
pad_token_id: int = 0,
7784
):
85+
dataset = load_dataset(in_file, split=split)
86+
loader = torch.utils.data.DataLoader(
87+
dataset,
88+
batch_size,
89+
collate_fn=get_collate_fn(num_special_tokens=num_special_tokens, pad_token_id=pad_token_id),
90+
)
91+
7892
with open(out_file, "w") as f:
79-
for batch in load_named_units_from_json(in_file, batch_size, num_special_tokens):
93+
for batch in loader:
8094
# Speech LM
8195
input_ids = batch["input_ids"].cuda()
82-
labels = input_ids.masked_fill(input_ids.eq(0), -100)
96+
labels = batch["labels"].cuda()
8397
logits = model(input_ids=input_ids, labels=labels).logits.transpose(1, 2)
8498

8599
labels = F.pad(labels, (0, 1), value=-100)

src/speechlm/tokenize.py

Lines changed: 28 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,11 @@
1-
import glob
2-
import json
31
from pathlib import Path
42

53
import torch
4+
from datasets import Dataset, DatasetDict, Features, Sequence, Value
65
from textless.data.speech_encoder import SpeechEncoder
7-
from tokenizers import Tokenizer
8-
from tokenizers.models import BPE
9-
from tokenizers.trainers import BpeTrainer
106
from tqdm import tqdm
117

128
from .data import SpeechDataset
13-
from .utils import convert_units_to_unicode, shift_unit
14-
15-
16-
def tokenize(config):
17-
Path(config.s2u.tokenizer_path).parent.mkdir(parents=True, exist_ok=True)
18-
19-
files = glob.glob(config.dataset.unicode_train + "*")
20-
initial_alphabet = [chr(shift_unit(unit)) for unit in range(config.s2u.vocab_size)]
21-
trainer = BpeTrainer(vocab_size=config.model.vocab_size, initial_alphabet=initial_alphabet)
22-
tokenizer = Tokenizer(BPE())
23-
tokenizer.train(files=files, trainer=trainer)
24-
tokenizer.save(config.s2u.tokenizer_path)
25-
26-
Path(config.dataset.train_file).parent.mkdir(parents=True, exist_ok=True)
27-
with open(config.dataset.train_file, "w") as f:
28-
for file in files:
29-
with open(file) as g:
30-
for unicodes in g:
31-
unicodes = unicodes.rstrip()
32-
units = tokenizer.encode(unicodes).ids
33-
units = " ".join(str(u) for u in units)
34-
35-
f.write(f"{units}\n")
369

3710

3811
def tokenize_slm21(config):
@@ -63,36 +36,41 @@ def tokenize_slm21(config):
6336
deduplicate=True,
6437
need_f0=False,
6538
).cuda()
66-
tokenizer = Tokenizer.from_file(config.s2u.tokenizer_path)
6739

68-
_tokenize_slm21(encoder, tokenizer, config.dataset.swuggy_dev_file, swuggy_dev_loader)
69-
_tokenize_slm21(encoder, tokenizer, config.dataset.sblimp_dev_file, sblimp_dev_loader)
70-
_tokenize_slm21(encoder, tokenizer, config.dataset.swuggy_test_file, swuggy_test_loader)
71-
_tokenize_slm21(encoder, tokenizer, config.dataset.sblimp_test_file, sblimp_test_loader)
40+
swuggy_dev = _tokenize(encoder, swuggy_dev_loader)
41+
sblimp_dev = _tokenize(encoder, sblimp_dev_loader)
42+
swuggy_test = _tokenize(encoder, swuggy_test_loader)
43+
sblimp_test = _tokenize(encoder, sblimp_test_loader)
7244

45+
swuggy = DatasetDict({"dev": swuggy_dev, "test": swuggy_test})
46+
sblimp = DatasetDict({"dev": sblimp_dev, "test": sblimp_test})
7347

74-
def _tokenize_slm21(
48+
swuggy.push_to_hub(config.dataset.swuggy)
49+
sblimp.push_to_hub(config.dataset.sblimp)
50+
51+
52+
def _tokenize(
7553
encoder: SpeechEncoder,
76-
tokenizer: Tokenizer,
77-
file,
7854
data_loader: torch.utils.data.DataLoader,
7955
):
80-
Path(file).parent.mkdir(parents=True, exist_ok=True)
81-
82-
dataset = dict()
83-
84-
for item in tqdm(data_loader):
85-
outputs = encoder(item["input_values"].cuda())
86-
unicodes = convert_units_to_unicode(outputs["units"].tolist())
87-
input_ids = tokenizer.encode(unicodes).ids
56+
features = Features(
57+
{
58+
"id": Value("string"),
59+
"units": Sequence(Value("int32")),
60+
}
61+
)
62+
63+
def generate_dataset():
64+
for item in tqdm(data_loader):
65+
outputs = encoder(item["input_values"].cuda())
66+
units = outputs["units"].tolist()
8867

89-
dataset[item["name"][0]] = input_ids
68+
yield {"id": item["name"][0], "units": units}
9069

91-
with open(file, "w") as f:
92-
json.dump(dataset, f)
70+
return Dataset.from_generator(generate_dataset, features=features)
9371

9472

95-
def encode(config, spk_ids: str = "1-9"):
73+
def tokenize_trainset(config, spk_ids: str = "1-9"):
9674
wav_dir_train = Path(config.dataset.wav_dir_train)
9775
train_paths = wav_dir_train.glob(f"*/[{spk_ids}]*/**/*" + config.dataset.ext_audio)
9876
train_set = SpeechDataset(train_paths)
@@ -106,15 +84,5 @@ def encode(config, spk_ids: str = "1-9"):
10684
need_f0=False,
10785
).cuda()
10886

109-
_encode(encoder, config.dataset.unicode_train + f"{spk_ids}", train_loader)
110-
111-
112-
def _encode(encoder: SpeechEncoder, file, data_loader: torch.utils.data.DataLoader):
113-
Path(file).parent.mkdir(parents=True, exist_ok=True)
114-
with open(file, "w") as f:
115-
for item in tqdm(data_loader):
116-
outputs = encoder(item["input_values"].cuda())
117-
118-
unicodes = convert_units_to_unicode(outputs["units"].tolist())
119-
120-
f.write(f"{unicodes}\n")
87+
trainset = _tokenize(encoder, train_loader)
88+
trainset.push_to_hub(config.dataset.train, split=f"train{spk_ids}")

0 commit comments

Comments
 (0)