1515
1616from utils import vocab_size
1717from data .midiencoder import QuantizedMidiEncoder
18- from data .multitokencoder import MultiVelocityEncoder
18+ from data .multitokencoder import MultiMidiEncoder
1919from data .quantizer import MidiQuantizer , MidiATQuantizer
2020from data .dataset import MaskedMidiDataset , load_cache_dataset
2121from data .maskedmidiencoder import MaskedMidiEncoder , MaskedNoteEncoder
@@ -105,7 +105,7 @@ def model_predictions_review(
105105 n_dstart_bins = dataset_cfg .quantization .dstart ,
106106 )
107107 if train_cfg .tokens_per_note == "multiple" :
108- base_tokenizer = MultiVelocityEncoder (
108+ base_tokenizer = MultiMidiEncoder (
109109 quantization_cfg = train_cfg .dataset .quantization ,
110110 time_quantization_method = train_cfg .time_quantization_method ,
111111 )
@@ -164,6 +164,9 @@ def model_predictions_review(
164164
165165 # predict velocities and get src, tgt and model output
166166 print ("Making predictions ..." )
167+
168+ # widget id for streamlit_pianoroll widget
169+ key = 0
167170 for record_id in idxs :
168171 # Numpy to int :(
169172 record : dict = dataset .get_complete_record (int (record_id ))
@@ -186,21 +189,24 @@ def model_predictions_review(
186189 pred_piece = MidiPiece (df )
187190
188191 except ValueError :
189- generated_df = pd .DataFrame ([[23 , 1 , 1 , 1 , 1 ]], columns = midi_columns )
192+ generated_df = pd .DataFrame ([[23.0 , 1.0 , 1.0 , 1.0 , 1.0 ]], columns = midi_columns )
190193 generated_df ["mask" ] = [False ]
191194 pred_piece = MidiPiece (generated_df )
192195
193196 pred_piece .source = true_piece .source .copy ()
194197
195198 # create a dashboard
196- st .json (record_source )
199+ st .json (record_source , expanded = False )
197200 cols = st .columns (2 )
198201
199202 source_tokens : list [str ] = [dataset .encoder .vocab [idx ] for idx in src_token_ids ]
200203 tgt_tokens : list [str ] = [dataset .encoder .vocab [idx ] for idx in record ["target_token_ids" ]]
201204 generated_tokens : list [str ] = [dataset .encoder .vocab [idx ] for idx in generated_token_ids ]
205+
202206 with cols [0 ]:
203- from_fortepyan (true_piece )
207+ fig = ff .view .draw_pianoroll_with_velocities (true_piece )
208+ st .pyplot (fig )
209+ from_fortepyan (true_piece , key = key )
204210 # Unchanged
205211 st .markdown ("**Source tokens:**" )
206212 st .markdown (source_tokens )
@@ -211,9 +217,10 @@ def model_predictions_review(
211217 # Predicted
212218 fig = ff .view .draw_dual_pianoroll (pred_piece )
213219 st .pyplot (fig )
214- from_fortepyan (pred_piece )
220+ from_fortepyan (pred_piece , key = key + 1 )
215221 st .markdown ("**Predicted tokens:**" )
216222 st .markdown (generated_tokens )
223+ key += 2
217224
218225
219226if __name__ == "__main__" :
0 commit comments