# audio2mesh.py
# Forked from thesujitroy/LipSync3D.
import torch
from torch import nn
from torch.nn import functional as F
import math
class Audio2mesh(nn.Module):
    '''
    Audio encoder followed by a geometry decoder.

    The encoder is a stack of 2-D convolutions that reduces a two-channel
    spectrogram to a 32-dimensional feature vector; the decoder is a small
    fully connected network that maps that vector to 468 * 3 = 1404 values,
    which are reshaped to per-vertex (x, y, z) triples and added to a
    caller-supplied reference mesh.
    '''

    # Number of mesh vertices and coordinates per vertex produced by the decoder.
    NUM_VERTICES = 468
    VERTEX_DIMS = 3
    # Features per sample leaving the encoder for a (N, 2, 256, 24) input:
    # 4 channels * 4 rows * 2 columns.
    ENCODER_FEATURES = 32

    def __init__(self):
        super().__init__()

        # Audio encoder.
        # Six convolutions with kernel (3, 1) and stride (2, 1) halve the
        # first spatial axis each time (256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4),
        # then six convolutions with kernel (1, 3) and stride (1, 2) shrink
        # the second spatial axis (24 -> 13 -> 8 -> 5 -> 4 -> 3 -> 2).
        self.audio_encoder = nn.Sequential(
            nn.Conv2d(2, 72, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(72, 108, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(108, 162, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(162, 243, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(243, 256, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, (3, 1), (2, 1), (1, 0)),
            nn.LeakyReLU(),
            nn.Conv2d(256, 128, (1, 3), (1, 2), (0, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(128, 64, (1, 3), (1, 2), (0, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(64, 32, (1, 3), (1, 2), (0, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, (1, 3), (1, 2), (0, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(16, 8, (1, 3), (1, 2), (0, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(8, 4, (1, 3), (1, 2), (0, 1)),
            nn.LeakyReLU(),
            # Collapse (N, 4, 4, 2) to (N, 32). Flattening per sample (rather
            # than reshaping to [-1, 32]) means a wrongly sized spectrogram
            # fails in the decoder instead of silently mixing batch entries.
            nn.Flatten(start_dim=1),
        )

        # Geometry decoder: 32 -> 150 -> 1404, with dropout in between.
        self.geometry_decoder = nn.Sequential(
            nn.Linear(self.ENCODER_FEATURES, 150),
            nn.Dropout(0.5),
            nn.Linear(150, self.NUM_VERTICES * self.VERTEX_DIMS),
        )

    def forward(self, audio_spectogram, verticref):
        '''
        Forward pass of the model.

        Input:
            audio_spectogram: tensor of shape (N, 2, 256, 24), channels
                first. An unbatched (2, 256, 24) tensor is also accepted
                and treated as a batch of one.
            verticref: tensor added to the decoder output; it must
                broadcast against shape (N, 468, 3).
        Returns:
            Tensor of shape (N, 468, 3): verticref plus the decoder output.
        '''
        # The flatten step needs an explicit batch axis to yield one
        # 32-feature row per sample.
        if audio_spectogram.dim() == 3:
            audio_spectogram = audio_spectogram.unsqueeze(0)

        proj = self.audio_encoder(audio_spectogram)
        vertices = self.geometry_decoder(proj)
        vertices = vertices.reshape(-1, self.NUM_VERTICES, self.VERTEX_DIMS)
        # The decoder output is applied as an offset to the reference vertices.
        vertices_mod = verticref + vertices
        return vertices_mod