Skip to content

Loading the state_dict from a checkpoint: inference does not work properly #72

@qichilu

Description

@qichilu

Below is the code I use to load the checkpoint and run inference. It does not work: there is no error message, but no audio is produced. Is anything wrong with this code? Please help:

from __future__ import annotations

import importlib
import os
from typing import Any, Dict, Tuple, Union, Optional

import torch
import yaml
from torch import nn

from vocos.heads import FourierHead
from vocos.models import Backbone
from vocos.feature_extractors import FeatureExtractor

def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation. A value that is not a
            tuple is passed as a single positional argument.
        init: Dict of the form {"class_path": ..., "init_args": ...}. ``class_path`` is a
            dotted ``module.ClassName`` path; ``init_args`` is optional (and may be None)
            and holds keyword arguments for the constructor.

    Returns:
        The instantiated class object.

    Raises:
        KeyError: If ``init`` has no ``"class_path"`` entry.
        ValueError: If ``class_path`` has no module part (e.g. ``"ClassName"``).
        ImportError: If the module named by ``class_path`` cannot be imported.
        AttributeError: If the module has no attribute with the given class name.
    """
    # An empty ``init_args:`` entry in a YAML file loads as None, which cannot be **-unpacked.
    kwargs = init.get("init_args") or {}
    if not isinstance(args, tuple):
        args = (args,)
    class_path = init["class_path"]
    class_module, sep, class_name = class_path.rpartition(".")
    if not sep or not class_module:
        raise ValueError(f"class_path must be of the form 'module.ClassName', got {class_path!r}")
    module = importlib.import_module(class_module)
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)

class Vocos(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
    ):
        """
        Args:
            feature_extractor: Module stored as ``self.feature_extractor``. It is not used by
                ``forward``, which expects precomputed features.
            backbone: Module applied first to the features in ``forward``.
            head: Module applied to the backbone output in ``forward`` to produce audio.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "Vocos":
        """
        Create a new (untrained) Vocos model from hyperparameters stored in a YAML file.

        The ``feature_extractor``, ``backbone`` and ``head`` sections (each of the form
        ``{"class_path": ..., "init_args": ...}``) may sit at the top level of the file, or be
        nested under ``model.init_args`` as in a training configuration.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            A freshly instantiated model; no weights are loaded.

        Raises:
            KeyError: If any of the three component sections is missing.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # Training configs nest the component definitions under model.init_args.
        if "feature_extractor" not in config and "model" in config:
            config = config["model"].get("init_args") or {}
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        return cls(feature_extractor=feature_extractor, backbone=backbone, head=head)

    @classmethod
    def from_pretrained(cls, config_path: str, ckpt_path: str) -> "Vocos":
        """
        Create a Vocos model from a YAML config and load trained weights into it.

        ``ckpt_path`` may hold either a training checkpoint with the weights under a
        ``"state_dict"`` key, or a plain state dict. Checkpoint entries whose keys are not
        in this model's own ``state_dict()`` (e.g. weights of training-only modules) are
        skipped. The load is strict, so every parameter/buffer of the model must be present.

        Args:
            config_path: Path to the YAML configuration (see ``from_hparams``).
            ckpt_path: Path to the checkpoint file.

        Returns:
            The model with weights loaded, switched to eval mode.

        Raises:
            RuntimeError: From ``load_state_dict`` if model keys are missing or shapes mismatch.
        """
        model = cls.from_hparams(config_path)
        model_keys = model.state_dict().keys()
        # weights_only=True restricts unpickling to tensors and plain containers.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        # Training checkpoints nest the weights under "state_dict"; a plain state dict is used as-is.
        raw_state_dict = ckpt.get("state_dict", ckpt)
        state_dict = {k: v for k, v in raw_state_dict.items() if k in model_keys}
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, features_input: torch.Tensor) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output. Runs under ``torch.inference_mode``,
        so no autograd graph is recorded.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                                     C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input)
        audio_output = self.head(x)
        return audio_output

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions