-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
87 lines (70 loc) · 2.96 KB
/
utils.py
File metadata and controls
87 lines (70 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import numpy as np
from sklearn.cluster import KMeans
from config import PARAS
def loss_function_simple(embedding, target):
    """
    Affinity-based loss for one batch, computed in the low-memory factored form.

    :param embedding: N * TF * Embedding Dim
    :param target: N * TF * 1 (vocal)
    :return: scalar loss value for the batch
    """

    def _weight_diag(labels):
        # Batched diagonal D^(1/2) built from the row sums of Y·Yᵀ.
        affinity = torch.bmm(labels, labels.transpose(1, 2))
        row_sums = affinity.sum(dim=2)  # per-row sum of the batched affinity
        return torch.diag_embed(row_sums).sqrt()

    def _sq_frobenius(matrix):
        # Squared 2-norm over all elements (Frobenius norm squared).
        return torch.norm(matrix, 2) ** 2

    weights = _weight_diag(target)
    batch, bins, _ = embedding.shape
    emb_t = embedding.transpose(1, 2)
    tgt_t = target.transpose(1, 2)
    vdv = _sq_frobenius(torch.bmm(torch.bmm(emb_t, weights), embedding))
    vdy = _sq_frobenius(torch.bmm(torch.bmm(emb_t, weights), target))
    ydy = _sq_frobenius(torch.bmm(torch.bmm(tgt_t, weights), target))
    return abs(vdv - 2 * vdy + ydy) / (batch * bins)
def loss_function(embedding, target, scale=2019):
    """
    Original affinity loss for one batch (may need large GPU memory).

    Computes ||VᵀV||_F² + ||YᵀY||_F² - 2·||VᵀY||_F², normalized by n·tf·scale.

    :param embedding: N * TF * Embedding Dim
    :param target: N * TF * 1 (vocal)
    :param scale: extra normalization constant; defaults to 2019, the value
        that was previously hard-coded in the divisor
    :return: scalar loss value for the batch
    """
    n, tf, _ = embedding.shape
    loss1 = torch.bmm(torch.transpose(embedding, 1, 2), embedding)
    loss2 = torch.bmm(torch.transpose(target, 1, 2), target)
    loss3 = torch.bmm(torch.transpose(embedding, 1, 2), target)
    result = torch.sum(torch.norm(loss1, 2) ** 2 + torch.norm(loss2, 2) ** 2 - 2 * torch.norm(loss3, 2) ** 2)
    return result / (n * tf * scale)
# Shared binary-cross-entropy criterion used by the mask-based losses below.
criterion = torch.nn.BCELoss()
def mask_scale_loss(mask: torch.Tensor, target: torch.Tensor):
    """
    Binary-cross-entropy loss between the predicted voice mask and the target.

    :param mask: binary mask generated by the model, N * T * F * 2
        (voice prob, background prob)
    :param target: target mask, N * T * F * 1
    :return: BCE loss over the voice channel only
    """
    # Keep only the first (voice-probability) channel of the two-channel mask.
    voice_channel = torch.unbind(mask, dim=3)[0]
    return criterion(voice_channel, target)
def mask_scale_loss_unet(mask: torch.Tensor, target: torch.Tensor):
    """BCE loss for the U-Net variant, where the mask already matches the target's shape."""
    loss = criterion(mask, target)
    return loss
def embedding_to_mask(embedding_out):
    """
    Convert an embedding output into complementary binary masks via 2-way k-means.

    :param embedding_out: tensor, TF * Embedding Dim (TF == N_MEL * N_MEL)
    :return: (mask, r_mask) — the k-means label mask reshaped to N_MEL * N_MEL
        and its complement (1 - mask).
        NOTE(review): k-means assigns cluster labels arbitrarily, so which of
        the two masks covers the voice source is not fixed between calls.
    """
    # detach().cpu() so .numpy() also works for tensors that require grad or
    # live on the GPU (the original raised a RuntimeError in both cases).
    # The original view-to-3D followed by an in-place numpy resize is collapsed
    # into a single 2D view; both produce the same row-major (TF, E_DIM) data.
    flat = embedding_out.detach().cpu().view(PARAS.N_MEL ** 2, PARAS.E_DIM).numpy()
    k_means_client = KMeans(n_clusters=2, random_state=0).fit(flat)
    mask = np.resize(k_means_client.labels_.copy(), (PARAS.N_MEL, PARAS.N_MEL))
    r_mask = np.ones(mask.shape) - mask
    return mask, r_mask