Discussed in #55
Originally posted by sauravp June 9, 2023
I am trying to train this on CPU (on a small dataset) to validate some ideas.
```python
import torch
from torch.utils.data import Dataset
from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer

device = 'cpu'

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    max_seq_len = 512
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)
mulan.to(device)
mulan.eval()

# toy data: 5 random waveforms of 1024 samples, 5 random token sequences of length 512
wavs = torch.randn(5, 1024).to(device)
texts = torch.randint(0, 20000, (5, 512)).to(device)
#print(wavs.shape, texts.shape)

class TextAudioDataset(Dataset):
    def __init__(self, wavs, texts):
        super().__init__()
        # __len__ must return a non-negative int, so validate the lengths here
        # instead of returning -1 (len() raises ValueError on negative values)
        assert len(wavs) == len(texts), 'wavs and texts must have the same length'
        self.wavs = wavs
        self.texts = texts

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, idx):
        return self.wavs[idx], self.texts[idx]

trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = TextAudioDataset(wavs, texts),
    batch_size = 2
)
trainer.to(device)
trainer.train()
```

I am getting the following error:
```
RuntimeError: stft input and window must be on the same device but got self on mps:0 and window on cpu
```
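For context, the message is raised by `torch.stft`, which requires the input signal and the window tensor to be on the same device. Here the input ended up on `mps:0` (Apple Silicon) even though `device = 'cpu'` was set, so something in the training loop appears to move the batches. The sketch below is plain PyTorch, not musiclm-pytorch code; the helper name `stft_on_input_device` is hypothetical and only illustrates the rule by building the Hann window on the input's own device:

```python
import torch

def stft_on_input_device(x: torch.Tensor, n_fft: int = 128, win_length: int = 24) -> torch.Tensor:
    # hypothetical helper: create the window on x.device so torch.stft sees matching devices
    window = torch.hann_window(win_length, device=x.device)
    # default hop_length is n_fft // 4; return_complex=True is required for real input
    return torch.stft(x, n_fft=n_fft, win_length=win_length, window=window, return_complex=True)

wavs = torch.randn(5, 1024, device="cpu")
spec = stft_on_input_device(wavs)
print(spec.shape, spec.device)  # torch.Size([5, 65, 33]) cpu
```

With `n_fft = 128` there are 65 frequency bins, and with the default `center=True` and hop of 32 a 1024-sample signal yields 33 frames.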
Is there a way to run the entire thing on CPU?