
Conversation

@SamuelJanas SamuelJanas commented Sep 17, 2023

MIDI-86
This is a prototype version of the blogpost. I'll be adding results and images in the TODO sections. Let me know If we would like to include anything else here. I was trying to make it a quick read that would leave you with some intuition on the topic. I also tried to share some overview on the results of the project to make it seem more appealing(?).

@SamuelJanas SamuelJanas self-assigned this Sep 17, 2023
@SamuelJanas SamuelJanas marked this pull request as draft September 17, 2023 11:43

SamuelJanas commented Sep 17, 2023

I'm also getting this error trying to run hexo server inside the container:

 Nunjucks Error: _posts/Hello-MIDI.md [Line 8, Column 4] unknown block tag: algrtmImgBanner

Tried removing lines with this block tag, but it seems to be an issue with all of them. Let me know if it's something that I'm doing incorrectly or if it's reproducible.

@SamuelJanas SamuelJanas marked this pull request as ready for review September 17, 2023 12:15
@SamuelJanas SamuelJanas requested a review from roszcz September 17, 2023 12:15

roszcz commented Sep 17, 2023

I'm also getting this error trying to run hexo server inside the container:

 Nunjucks Error: _posts/Hello-MIDI.md [Line 8, Column 4] unknown block tag: algrtmImgBanner

Tried removing lines with this block tag, but it seems to be an issue with all of them. Let me know if it's something that I'm doing incorrectly or if it's reproducible.

Sorry, I failed to include this information in the README - you have to pull the "theme" code, which is a submodule in this repository:

git submodule init
git submodule update
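
Or equivalently, in a single step:

git submodule update --init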

@roszcz roszcz left a comment

Thanks!

I think this post would work better if it were only about ECG data - I feel like the appearance of MIDI will make it confusing to most readers. If you remove it, you'll have to provide a bit more detail about ECG signals - I think it's worth explaining what a multichannel 1D signal is in this context :)
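
For example, it could be illustrated with something as simple as a tensor shape (the numbers below are placeholders, not the project's actual lead count or sampling rate):

import torch

# A multichannel 1D signal: several parallel amplitude series over the same time axis,
# one channel per ECG lead, e.g. 12 leads sampled at 500 Hz for 10 seconds.
ecg = torch.randn(12, 500 * 10)      # shape (channels, time)

# With a batch dimension, the usual layout for 1D convolutions is (batch, channels, time):
batch = torch.randn(32, 12, 5000)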

However, an intriguing issue arises from this architecture: sometimes the encoder doesn't capture as much information in the latent space as one might expect. This is because decoders, especially those with high capacity, can become exceedingly proficient at "filling in the gaps," or reconstructing missing or ambiguous information. As a result, the encoder might not learn a rich or informative latent space. Instead, the decoder compensates for the encoder's shortcomings, essentially becoming too good at its job for the encoder to improve.

#### How VQ-VAEs Address This Issue
VQ-VAEs introduce an additional layer of complexity with vector quantization, which addresses this issue. In VQ-VAEs, the encoder's output is not used directly for decoding. Instead, it is mapped to the nearest vector in a predefined codebook. This discrete representation, sourced from the codebook, is then used by the decoder for reconstruction.

roszcz (Member) commented:

Can you add a code snippet demonstrating this? As minimalistic as possible :)

SamuelJanas (Author) replied:

What do you think about

import torch
import torch.nn as nn

# Toy encoder: projects a 10-dim input to a 5-dim latent.
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.layer = nn.Linear(10, 5)

    def forward(self, x):
        return self.layer(x)

# Learnable codebook of num_codes vectors, each code_dim-dimensional.
class Codebook(nn.Module):
    def __init__(self, num_codes, code_dim):
        super(Codebook, self).__init__()
        self.codebook = nn.Parameter(torch.randn(num_codes, code_dim))

    def forward(self, z):
        # Squared Euclidean distance between every latent and every code.
        distances = ((z.unsqueeze(1) - self.codebook.detach().unsqueeze(0))**2).sum(-1)
        # Nearest code for each latent; argmin is non-differentiable,
        # so no gradient flows back to z through this lookup.
        indices = torch.argmin(distances, dim=1)
        return self.codebook[indices]

# Toy decoder: maps the quantized 5-dim code back to 10 dims.
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.layer = nn.Linear(5, 10)

    def forward(self, z):
        return self.layer(z)

And the training could be showcased as follows:

encoder = Encoder()
codebook = Codebook(num_codes=16, code_dim=5)  # code_dim matches the encoder output; sizes are illustrative
decoder = Decoder()
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(codebook.parameters()) + list(decoder.parameters()),
    lr=1e-3,
)

for epoch in range(1000):
    # Simulated input data
    x = torch.randn(32, 10)

    # Forward pass
    z_e = encoder(x)
    z_q = codebook(z_e)
    x_recon = decoder(z_q)

    # Reconstruction loss
    loss = ((x - x_recon)**2).mean()

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f"Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}")

I feel like it's an elegant way to show how the codebook works on a toy example. If you had something else in mind, let me know.
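
One thing maybe worth flagging in the post: in the snippet above no gradient actually reaches the encoder, because the argmin lookup is non-differentiable. The full VQ-VAE recipe (van den Oord et al.) handles this with a straight-through estimator plus codebook and commitment losses - a minimal sketch reusing the toy modules above (beta is just a typical weight, not anything from our code):

# Inside the training loop, replacing the plain reconstruction loss:
z_e = encoder(x)                     # continuous latents
z_q = codebook(z_e)                  # nearest codebook vectors (no gradient to z_e)

# Straight-through estimator: forward pass uses z_q, backward copies gradients to z_e.
z_st = z_e + (z_q - z_e).detach()
x_recon = decoder(z_st)

recon_loss = ((x - x_recon) ** 2).mean()
codebook_loss = ((z_q - z_e.detach()) ** 2).mean()    # moves codes toward encoder outputs
commitment_loss = ((z_e - z_q.detach()) ** 2).mean()  # keeps the encoder committed to its code
beta = 0.25                                           # commitment weight

loss = recon_loss + codebook_loss + beta * commitment_loss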

@SamuelJanas SamuelJanas requested a review from roszcz September 19, 2023 08:55