Skip to content

Commit 6abb305

Browse files
authored
Merge pull request #5 from Nospoko/MIDI-126/finetuning
Midi 126/finetuning
2 parents 129fd82 + fee4c32 commit 6abb305

File tree

17 files changed

+314
-36
lines changed

17 files changed

+314
-36
lines changed

configs/T5denoise-dstart.yaml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
train:
2+
num_epochs: 5
3+
accum_iter: 5
4+
batch_size: 2
5+
base_lr: 3e-5
6+
warmup: 4000
7+
finetune: False
8+
9+
model_name: T5
10+
dataset_name: 'roszcz/maestro-v1-sustain'
11+
target: denoise
12+
seed: 26
13+
14+
overfit: False
15+
16+
tokens_per_note: single
17+
time_quantization_method: dstart
18+
masking_probability: 0.2
19+
mask: tokens
20+
21+
encoder: velocity
22+
time_bins: 100
23+
24+
dataset:
25+
sequence_len: 128
26+
sequence_step: 42
27+
28+
quantization:
29+
dstart: 5
30+
duration: 5
31+
velocity: 3
32+
33+
device: "cuda:0"
34+
35+
log: True
36+
log_frequency: 10
37+
run_name: midi-T5-${now:%Y-%m-%d-%H-%M}
38+
project: "midi-hf-transformer"
39+
40+
pre_defined_model: null
41+
model:
42+
d_model: 512
43+
d_kv: 64
44+
d_ff: 2048
45+
num_layers: 6
46+
num_heads: 8

configs/T5denoise.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ train:
44
batch_size: 8
55
base_lr: 3e-5
66
warmup: 4000
7+
finetune: False
78

89
model_name: T5
910
dataset_name: 'roszcz/maestro-v1-sustain'
@@ -17,14 +18,17 @@ time_quantization_method: start
1718
masking_probability: 0.15
1819
mask: notes
1920

21+
encoder: velocity
22+
time_bins: 100
23+
2024
dataset:
2125
sequence_duration: 5
2226
sequence_step: 2
2327

2428
quantization:
25-
start: 20
29+
start: 50
2630
duration: 5
27-
velocity: 5
31+
velocity: 3
2832

2933
device: "cuda:0"
3034

@@ -33,6 +37,8 @@ log_frequency: 10
3337
run_name: midi-T5-${now:%Y-%m-%d-%H-%M}
3438
project: "midi-hf-transformer"
3539

40+
pre_defined_model: null
41+
3642
model:
3743
d_model: 512
3844
d_kv: 64

configs/T5start.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ log_frequency: 10
3131
run_name: midi-T5-${now:%Y-%m-%d-%H-%M}
3232
project: "midi-hf-transformer"
3333

34+
pre_defined_model: null
35+
3436
model:
3537
d_model: 512
3638
d_kv: 64

configs/T5velocity-dstart.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ train:
22
num_epochs: 5
33
accum_iter: 10
44
batch_size: 8
5-
base_lr: 1e-5
5+
base_lr: 3e-5
6+
finetune: True
67
warmup: 4000
78

9+
10+
pretrained_checkpoint: midi-T5-2023-11-15-17-18.pt
811
model_name: T5
912
dataset_name: 'roszcz/maestro-v1-sustain'
1013
target: velocity
1114
seed: 26
15+
time_bins: 100
1216

1317
overfit: False
1418

@@ -30,6 +34,8 @@ log_frequency: 10
3034
run_name: midi-T5-${now:%Y-%m-%d-%H-%M}
3135
project: "midi-hf-transformer"
3236

37+
pre_defined_model: null
38+
3339
model:
3440
d_model: 512
3541
d_kv: 64

configs/T5velocity.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ train:
44
batch_size: 8
55
base_lr: 3e-4
66
warmup: 4000
7+
finetune: True
78

9+
pretrained_checkpoint: midi-T5-2023-11-11-10-29.pt
810
model_name: T5
911
dataset_name: 'roszcz/maestro-v1-sustain'
1012
target: velocity
@@ -30,6 +32,8 @@ log_frequency: 10
3032
run_name: midi-T5-${now:%Y-%m-%d-%H-%M}
3133
project: "midi-hf-transformer"
3234

35+
pre_defined_model: null
36+
3337
model:
3438
d_model: 512
3539
d_kv: 64

configs/architectures/large.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
d_model: 512
2+
d_kv: 64
3+
d_ff: 2048
4+
num_layers: 6
5+
num_heads: 8

configs/architectures/mid.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
d_model: 256
2+
d_kv: 32
3+
d_ff: 1024
4+
num_layers: 6
5+
num_heads: 8

configs/architectures/small.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
d_model: 256
2+
d_kv: 32
3+
d_ff: 512
4+
num_layers: 4
5+
num_heads: 4

dashboard/denoise/main.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from utils import vocab_size
1717
from data.midiencoder import QuantizedMidiEncoder
18-
from data.multitokencoder import MultiVelocityEncoder
18+
from data.multitokencoder import MultiMidiEncoder
1919
from data.quantizer import MidiQuantizer, MidiATQuantizer
2020
from data.dataset import MaskedMidiDataset, load_cache_dataset
2121
from data.maskedmidiencoder import MaskedMidiEncoder, MaskedNoteEncoder
@@ -105,7 +105,7 @@ def model_predictions_review(
105105
n_dstart_bins=dataset_cfg.quantization.dstart,
106106
)
107107
if train_cfg.tokens_per_note == "multiple":
108-
base_tokenizer = MultiVelocityEncoder(
108+
base_tokenizer = MultiMidiEncoder(
109109
quantization_cfg=train_cfg.dataset.quantization,
110110
time_quantization_method=train_cfg.time_quantization_method,
111111
)
@@ -164,6 +164,9 @@ def model_predictions_review(
164164

165165
# predict velocities and get src, tgt and model output
166166
print("Making predictions ...")
167+
168+
# widget id for streamlit_pianoroll widget
169+
key = 0
167170
for record_id in idxs:
168171
# Numpy to int :(
169172
record: dict = dataset.get_complete_record(int(record_id))
@@ -186,21 +189,24 @@ def model_predictions_review(
186189
pred_piece = MidiPiece(df)
187190

188191
except ValueError:
189-
generated_df = pd.DataFrame([[23, 1, 1, 1, 1]], columns=midi_columns)
192+
generated_df = pd.DataFrame([[23.0, 1.0, 1.0, 1.0, 1.0]], columns=midi_columns)
190193
generated_df["mask"] = [False]
191194
pred_piece = MidiPiece(generated_df)
192195

193196
pred_piece.source = true_piece.source.copy()
194197

195198
# create a dashboard
196-
st.json(record_source)
199+
st.json(record_source, expanded=False)
197200
cols = st.columns(2)
198201

199202
source_tokens: list[str] = [dataset.encoder.vocab[idx] for idx in src_token_ids]
200203
tgt_tokens: list[str] = [dataset.encoder.vocab[idx] for idx in record["target_token_ids"]]
201204
generated_tokens: list[str] = [dataset.encoder.vocab[idx] for idx in generated_token_ids]
205+
202206
with cols[0]:
203-
from_fortepyan(true_piece)
207+
fig = ff.view.draw_pianoroll_with_velocities(true_piece)
208+
st.pyplot(fig)
209+
from_fortepyan(true_piece, key=key)
204210
# Unchanged
205211
st.markdown("**Source tokens:**")
206212
st.markdown(source_tokens)
@@ -211,9 +217,10 @@ def model_predictions_review(
211217
# Predicted
212218
fig = ff.view.draw_dual_pianoroll(pred_piece)
213219
st.pyplot(fig)
214-
from_fortepyan(pred_piece)
220+
from_fortepyan(pred_piece, key=key + 1)
215221
st.markdown("**Predicted tokens:**")
216222
st.markdown(generated_tokens)
223+
key += 2
217224

218225

219226
if __name__ == "__main__":

dashboard/download_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from huggingface_hub import hf_hub_download
22

3-
FILENAME_VELOCITY = "midi-T5-2023-10-20-16-03.pt"
4-
FILENAME_DENOISE = "midi-T5-2023-11-07-12-53.pt"
3+
FILENAME_VELOCITY = "velocity-T5-2023-11-11-10-29.pt"
4+
FILENAME_DENOISE = "midi-T5-2023-11-11-10-29.pt"
5+
56

67
hf_hub_download(
78
repo_id="wmatejuk/midi-T5-velocity",
89
filename=FILENAME_VELOCITY,
910
local_dir="checkpoints/velocity",
1011
local_dir_use_symlinks=False,
1112
)
12-
1313
hf_hub_download(
1414
repo_id="wmatejuk/midi-T5-denoise",
1515
filename=FILENAME_DENOISE,

0 commit comments

Comments
 (0)