Runtime Error on CPU #56

@sauravp

Discussed in #55

Originally posted by sauravp June 9, 2023
I am trying to train this on CPU (on a small dataset) to validate some ideas.

import torch
from musiclm_pytorch import MusicLM, MuLaNTrainer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer

import random
import numpy as np

device = 'cpu'


audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    max_seq_len = 512
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

mulan.to(device)
mulan.eval()

wavs = torch.randn(5, 1024).to(device)
texts = torch.randint(0, 20000, (5, 512)).to(device)
#print(wavs.shape, texts.shape)

from torch.utils.data import Dataset

class TextAudioDataset(Dataset):
    def __init__(self, wavs, texts):
        super().__init__()
        # every waveform needs exactly one matching text sequence
        assert len(wavs) == len(texts), 'wavs and texts must have the same length'
        self.wavs = wavs
        self.texts = texts

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, idx):
        return self.wavs[idx], self.texts[idx]


trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = TextAudioDataset(wavs, texts),
    batch_size = 2
)

trainer.to(device)

trainer.train()

I am getting the following error:

RuntimeError: stft input and window must be on the same device but got self on mps:0 and window on cpu

Is there a way to run the entire thing on CPU?
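
For anyone hitting the same message: it says the audio tensor passed to `torch.stft` was on `mps:0` while the STFT window was still on `cpu`, even though the script only ever asks for `'cpu'`. That suggests something other than the script is moving the batches, though the thread does not confirm what. A quick way to see where things actually live is to list the devices of a module's parameters and buffers and compare them with the device of the input batch. Below is a minimal sketch using plain PyTorch only. `devices_in_use` and `TinySpectrogram` are hypothetical names made up for illustration and are not part of `musiclm_pytorch`; the toy module simply holds an STFT window as a buffer, which is an assumption about how a spectrogram layer is typically built.

```python
import torch
from torch import nn

def devices_in_use(module: nn.Module) -> dict:
    """Map each device string to the parameter and buffer names stored on it."""
    found = {}
    named = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in named:
        found.setdefault(str(tensor.device), []).append(name)
    return found

class TinySpectrogram(nn.Module):
    """Hypothetical stand-in for a spectrogram layer (not the library's class)."""
    def __init__(self, n_fft = 128, win_length = 24):
        super().__init__()
        self.n_fft = n_fft
        self.win_length = win_length
        # registering the window as a buffer makes it follow module.to(device)
        self.register_buffer('window', torch.hann_window(win_length))

    def forward(self, wav):
        # torch.stft requires `wav` and `window` to be on the same device
        spec = torch.stft(
            wav,
            n_fft = self.n_fft,
            win_length = self.win_length,
            window = self.window,
            return_complex = True
        )
        return spec.abs()

toy = TinySpectrogram()
wavs = torch.randn(5, 1024)

print(devices_in_use(toy))   # {'cpu': ['window']}
print(wavs.device)           # cpu
spec = toy(wavs)             # works because input and window share a device
```

Running `devices_in_use(mulan)` right before `trainer.train()`, and printing the `.device` of a batch inside the training step, should show which side ends up on `mps:0`. If the trainer picks its device through Hugging Face Accelerate (an assumption here, not something this thread confirms), `Accelerator(cpu=True)` is the documented switch for forcing CPU execution; whether and how `MuLaNTrainer` exposes that option would need to be checked in its source.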
