Skip to content

Add GAN structure and training #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion configs/default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,45 @@
data:
input_dir: "./data/gta"
output_dir: "./data/mels"
valid_input_dir: "./data/valid_gta"

train:
rep_discriminator: 1
discriminator_train_start_steps: 10000
num_workers: 8
batch_size: 16
optimizer: 'adam'
adam:
lr: 0.0001
beta1: 0.5
beta2: 0.9
--- # NOTE(review): this '---' starts a second YAML document — confirm the config loader reads multi-document streams (yaml.safe_load raises on more than one document; yaml.safe_load_all would be needed)
audio:
n_mel_channels: 80
segment_length: 16000
pad_short: 2000
filter_length: 1024
hop_length: 256 # WARNING: this can't be changed.
win_length: 1024
sampling_rate: 22050
mel_fmin: 0.0
mel_fmax: 8000.0

model:
feat_match: 10.0
lambda_adv: 2.5
use_subband_stft_loss: False
feat_loss: False
out_channels: 1
generator_ratio: [8, 8, 4] # for 256 hop size and 22050 sample rate
mult: 256
n_residual_layers: 4
num_D : 3
ndf : 16
n_layers: 3
downsampling_factor: 4
disc_out: 512

train: "/mnt/Karan/ResUnet/data/training"
valid: "/mnt/Karan/ResUnet/data/testing"
log: "logs"
Expand All @@ -7,6 +49,6 @@ checkpoints: "checkpoints"

batch_size: 16
lr: 0.001
RESNET_PLUS_PLUS: True
RESNET_PLUS_PLUS: False
IMAGE_SIZE: 1500
CROP_SIZE: 224
24 changes: 24 additions & 0 deletions core/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Patch-style CNN critic over a single mel-spectrogram band.

    Three stride-1 3x3 conv + LeakyReLU stages (1 -> 16 -> 32 -> 64 channels)
    followed by a 1-channel conv head, so the output score map has the same
    spatial size as the input. No final sigmoid: the raw score is used
    directly because training follows the Least Squares GAN objective
    (https://arxiv.org/abs/1611.04076).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        widths = [1, 16, 32, 64]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            stages.append(nn.LeakyReLU())
        self.discriminator = nn.Sequential(*stages)
        self.out = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a (B, 1, H, W) input to an unbounded (B, 1, H, W) score map."""
        features = self.discriminator(x)
        return self.out(features)
22 changes: 22 additions & 0 deletions core/multiscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
from utils.utils import weights_init
from .discriminator import Discriminator



class MultiScaleDiscriminator(nn.Module):
    """Three independent critics, one per mel-frequency band.

    The input is split along the mel-channel axis into rows 0-19, 20-39 and
    40-79; each critic scores the same 40-frame temporal window beginning at
    ``start``. All sub-modules are re-initialised with ``weights_init``.
    """

    # (row_lo, row_hi) mel-channel band per sub-discriminator.
    _BANDS = ((0, 20), (20, 40), (40, 80))

    def __init__(self):
        super().__init__()
        self.disc1 = Discriminator()
        self.disc2 = Discriminator()
        self.disc3 = Discriminator()

        self.apply(weights_init)

    def forward(self, x, start):
        """Return a list of three score maps, one per band, for the
        40-frame window x[..., start:start + 40]."""
        window = x[:, :, :, start:start + 40]
        critics = (self.disc1, self.disc2, self.disc3)
        return [critic(window[:, :, lo:hi, :])
                for critic, (lo, hi) in zip(critics, self._BANDS)]
2 changes: 1 addition & 1 deletion core/res_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, channel, filters=[64, 128, 256, 512]):

self.output_layer = nn.Sequential(
nn.Conv2d(filters[0], 1, 1, 1),
nn.Sigmoid(),
# nn.Sigmoid(),
)

def forward(self, x):
Expand Down
2 changes: 1 addition & 1 deletion core/res_unet_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, channel, filters=[32, 64, 128, 256, 512]):

self.aspp_out = ASPP(filters[1], filters[0])

self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid())
self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1)) # , nn.Sigmoid())

def forward(self, x):
x1 = self.input_layer(x) + self.input_skip(x)
Expand Down
53 changes: 53 additions & 0 deletions dataset/mel_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader


def create_dataloader(hp, train):
    """Build a DataLoader over (GTA mel, ground-truth mel) pairs.

    Args:
        hp: hierarchical config; the training loader reads
            ``hp.train.batch_size`` and ``hp.train.num_workers``.
        train: True for the shuffled, batched training loader;
            False for a sequential, batch-size-1 validation loader.

    Returns:
        torch.utils.data.DataLoader over a MelFromDisk dataset.
    """
    dataset = MelFromDisk(hp, train)

    if train:
        # Use the worker count from the config (train.num_workers) instead of
        # hardcoding 0, which silently disabled parallel loading.
        return DataLoader(dataset=dataset, batch_size=hp.train.batch_size,
                          shuffle=True, num_workers=hp.train.num_workers,
                          pin_memory=True, drop_last=True)
    # Validation: deterministic order, keep every sample, single process.
    return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                      num_workers=0, pin_memory=False, drop_last=False)


class MelFromDisk(Dataset):
    """Dataset of (GTA mel, ground-truth mel) segment pairs stored as .npy.

    Files are matched by basename: for every ``<id>.npy`` found (recursively)
    under the input directory, the ground-truth target is expected at
    ``<output_dir>/<id>.npy``. Each item is a random ``mel_segment_length``-
    frame crop taken at the same offset from both mels.
    """

    def __init__(self, hp, train):
        self.hp = hp
        self.train = train
        # Training reads GTA mels; validation reads from valid_input_dir.
        self.path = hp.data.input_dir if train else hp.data.valid_input_dir
        self.wav_list = glob.glob(os.path.join(self.path, '**', '*.npy'), recursive=True)
        self.mel_segment_length = hp.model.idim
        # Index mapping kept for shuffle_mapping(); not used for lookup here.
        self.mapping = [i for i in range(len(self.wav_list))]

    def __len__(self):
        return len(self.wav_list)

    def __getitem__(self, idx):
        """Return (mel_gta, mel_gt), each [n_mels, mel_segment_length].

        Mels shorter than one segment are right-padded with zeros; the
        original code crashed on them (random.randint(0, negative)).
        """
        # Renamed from `id`, which shadowed the builtin.
        file_id = os.path.basename(self.wav_list[idx]).split(".")[0]

        input_mel_path = "{}/{}.npy".format(self.hp.data.input_dir, file_id)
        output_mel_path = "{}/{}.npy".format(self.hp.data.output_dir, file_id)

        mel_gt = torch.from_numpy(np.load(output_mel_path))
        mel_gta = torch.from_numpy(np.load(input_mel_path))
        # NOTE(review): assumes both mels have the same frame count — confirm
        # against the preprocessing that writes these files.

        shortfall = self.mel_segment_length - mel_gta.size(1)
        if shortfall > 0:
            # Too short for one segment: zero-pad on the right and take the
            # whole (padded) mel instead of raising.
            mel_gta = torch.nn.functional.pad(mel_gta, (0, shortfall))
            gt_shortfall = self.mel_segment_length - mel_gt.size(1)
            if gt_shortfall > 0:
                mel_gt = torch.nn.functional.pad(mel_gt, (0, gt_shortfall))
            mel_start = 0
        else:
            mel_start = random.randint(0, -shortfall)
        mel_end = mel_start + self.mel_segment_length
        mel_gta = mel_gta[:, mel_start:mel_end]
        mel_gt = mel_gt[:, mel_start:mel_end]

        return mel_gta, mel_gt

    def shuffle_mapping(self):
        """Shuffle the index mapping in place (mapping is not read by
        __getitem__; presumably consumed by an external sampler — TODO confirm)."""
        random.shuffle(self.mapping)
Loading