diff --git a/lipsync/__init__.py b/lipsync/__init__.py index 0b25cb1..a3c3fec 100644 --- a/lipsync/__init__.py +++ b/lipsync/__init__.py @@ -1,5 +1,8 @@ -from lipsync.lipsync import LipSync +""" +The lipsync package initialization. +""" +from lipsync.lipsync import LipSync __all__ = [ 'LipSync', diff --git a/lipsync/audio.py b/lipsync/audio.py index ec49ce2..eb72e1e 100644 --- a/lipsync/audio.py +++ b/lipsync/audio.py @@ -1,30 +1,40 @@ +""" +Audio processing utilities for lipsync. + +Includes loading WAV files, applying preemphasis, +and computing mel-spectrograms with optional normalization. +""" + import librosa import librosa.filters import numpy as np from scipy import signal + from lipsync.hparams import HParams hp = HParams() def load_wav(path: str, sr: int) -> np.ndarray: - """Load a WAV file. + """ + Load a WAV file using librosa. Args: path (str): Path to the WAV file. sr (int): Sampling rate to load the audio at. Returns: - np.ndarray: Audio time series as a 1D numpy array. + np.ndarray: Audio time series as a 1D NumPy array. """ return librosa.core.load(path, sr=sr)[0] def preemphasis_func(wav: np.ndarray, k: float, preemphasize: bool = True) -> np.ndarray: - """Apply a preemphasis filter to the waveform. + """ + Apply a preemphasis filter to the waveform. Args: - wav (np.ndarray): Input waveform as a 1D numpy array. + wav (np.ndarray): Input waveform as a 1D NumPy array. k (float): Preemphasis coefficient. preemphasize (bool): Whether to apply preemphasis or not. @@ -35,17 +45,19 @@ def preemphasis_func(wav: np.ndarray, k: float, preemphasize: bool = True) -> np # This increases the magnitude of high-frequency components. if preemphasize: return signal.lfilter([1, -k], [1], wav) + return wav def melspectrogram(wav: np.ndarray) -> np.ndarray: - """Compute the mel-spectrogram of a waveform. + """ + Compute the mel-spectrogram of a waveform. Args: wav (np.ndarray): Input waveform array. Returns: - np.ndarray: Mel-spectrogram as a 2D numpy array (num_mels x time). + np.ndarray: Mel-spectrogram as a 2D NumPy array (num_mels x time). """ D = _stft(preemphasis_func(wav, hp.preemphasis, hp.preemphasize)) S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db @@ -56,19 +68,26 @@ def melspectrogram(wav: np.ndarray) -> np.ndarray: def _stft(y: np.ndarray) -> np.ndarray: - """Compute the STFT of the given waveform. + """ + Compute the STFT of the given waveform. Args: y (np.ndarray): Input waveform. Returns: - np.ndarray: Complex STFT of y. Shape is (1 + n_fft/2, time). + np.ndarray: Complex STFT of shape (1 + n_fft//2, time). """ - return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_size, win_length=hp.win_size) + return librosa.stft( + y=y, + n_fft=hp.n_fft, + hop_length=hp.hop_size, + win_length=hp.win_size + ) def _linear_to_mel(spectrogram: np.ndarray) -> np.ndarray: - """Convert a linear-scale spectrogram to mel-scale. + """ + Convert a linear-scale spectrogram to mel-scale. Args: spectrogram (np.ndarray): Linear frequency spectrogram. @@ -81,7 +100,8 @@ def _linear_to_mel(spectrogram: np.ndarray) -> np.ndarray: def _build_mel_basis() -> np.ndarray: - """Construct a mel-filter bank. + """ + Construct a mel filter bank. Returns: np.ndarray: Mel filter bank matrix. @@ -97,7 +117,8 @@ def _build_mel_basis() -> np.ndarray: def _amp_to_db(x: np.ndarray) -> np.ndarray: - """Convert amplitude to decibels. + """ + Convert amplitude to decibels. Args: x (np.ndarray): Amplitude values. 
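Reviewer note (not part of the patch): a minimal sketch of how the audio helpers above fit together, assuming the default HParams (16 kHz input, 80 mel bins, matching what the Wav2Lip audio encoder expects); the file path is a placeholder.

    from lipsync import audio

    # Load a mono waveform at the 16 kHz rate the rest of the pipeline assumes.
    wav = audio.load_wav('speech.wav', sr=16000)

    # Preemphasis, STFT, mel projection and (optional) normalization all happen
    # inside melspectrogram(); the result is a (num_mels, time) array,
    # i.e. 80 x T with the default hparams.
    mel = audio.melspectrogram(wav)
    print(mel.shape)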
@@ -110,7 +131,8 @@ def _amp_to_db(x: np.ndarray) -> np.ndarray: def _normalize(spec: np.ndarray) -> np.ndarray: - """Normalize the mel-spectrogram. + """ + Normalize the mel-spectrogram. Args: spec (np.ndarray): Decibel-scaled mel-spectrogram. @@ -139,7 +161,9 @@ def _normalize(spec: np.ndarray) -> np.ndarray: if hp.symmetric_mels: # Symmetric range normalization - return (2 * hp.max_abs_value) * ((spec - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + return (2 * hp.max_abs_value) * ( + (spec - hp.min_level_db) / (-hp.min_level_db) + ) - hp.max_abs_value else: # Asymmetric range normalization return hp.max_abs_value * ((spec - hp.min_level_db) / (-hp.min_level_db)) diff --git a/lipsync/helpers.py b/lipsync/helpers.py index 3a3375b..ba71198 100644 --- a/lipsync/helpers.py +++ b/lipsync/helpers.py @@ -1,3 +1,7 @@ +""" +Helper utilities for lipsync, including video frame reading and face bounding box extraction. +""" + import av import numpy as np from typing import Tuple, List @@ -45,28 +49,30 @@ def read_frames(face: str) -> Tuple[List[np.ndarray], int]: raise ValueError(f"An error occurred while reading the video file: {e}") from e -def get_face_box(landmarks: list) -> Tuple[int, int, int, int]: +def get_face_box(landmarks: list, face_index: int = 0) -> Tuple[int, int, int, int]: """ Extracts and returns the bounding box coordinates of a detected face. Args: landmarks (list): A list containing facial landmarks where the third element - (index 2) represents the bounding box coordinates. + (index 2) represents the bounding box coordinates. + face_index (int, optional): The index of the face to extract bounding box coordinates. Returns: Tuple[int, int, int, int]: The bounding box coordinates (x1, y1, x2, y2) of the face. Raises: ValueError: If the landmarks list is improperly structured or does not contain - the expected bounding box. + the expected bounding box. """ try: # Extract the face box from the landmarks - face_box = landmarks[2][0] # Access the bounding box coordinates + face_box = landmarks[2][face_index] # Access the bounding box coordinates face_box = np.clip(face_box, 0, None) # Ensure no negative values # Convert bounding box values to integers x1, y1, x2, y2 = map(int, face_box[:-1]) # Exclude the confidence score (last value) return x1, y1, x2, y2 + except (IndexError, TypeError, ValueError) as e: raise ValueError("Invalid landmarks structure. Could not extract face box.") from e diff --git a/lipsync/lipsync.py b/lipsync/lipsync.py index f14260f..a8ffee5 100644 --- a/lipsync/lipsync.py +++ b/lipsync/lipsync.py @@ -1,3 +1,8 @@ +""" +Main class to handle the lip-syncing pipeline, combining video (face) frames +and audio to produce a synchronized output using the Wav2Lip model. +""" + import numpy as np import cv2 import os @@ -6,41 +11,46 @@ import torch import tempfile import pickle +from typing import List, Tuple, Union + +import face_alignment + from lipsync import audio from lipsync.helpers import read_frames, get_face_box from lipsync.models import load_model -from typing import List, Tuple, Union -import face_alignment class LipSync: """ Class for lip-syncing videos using the Wav2Lip model. 
+ + This class encapsulates: + - Face detection + - Audio preprocessing and mel spectrogram creation + - Model inference + - Final video rendering with synced audio """ # Default parameters checkpoint_path: str = '' - static: bool = False fps: float = 25.0 - pads: List[int] = [0, 10, 0, 0] + pads: List[int] = [0, 10, 0, 0] # Face bounding box padding in [y1_pad, y2_pad, x1_pad, x2_pad] wav2lip_batch_size: int = 128 - resize_factor: int = 1 - crop: List[int] = [0, -1, 0, -1] - box: List[int] = [-1, -1, -1, -1] - rotate: bool = False nosmooth: bool = False save_cache: bool = True cache_dir: str = tempfile.gettempdir() - _filepath: str = '' img_size: int = 96 mel_step_size: int = 16 device: str = 'cpu' ffmpeg_loglevel: str = 'verbose' model: str = 'wav2lip' + _filepath: str = '' + def __init__(self, **kwargs): """ - Initializes LipSync with custom parameters. + Initializes LipSync with user-defined or default parameters. + Automatically checks if CUDA is available when device='cuda'. """ device = kwargs.get('device', self.device) self.device = 'cuda' if (device == 'cuda' and torch.cuda.is_available()) else 'cpu' @@ -52,17 +62,28 @@ def __init__(self, **kwargs): @staticmethod def get_smoothened_boxes(boxes: np.ndarray, t: int) -> np.ndarray: """ - Smoothens bounding boxes over a temporal window. + Smoothens bounding boxes over a temporal window by averaging. + + Args: + boxes (np.ndarray): Array of bounding boxes of shape (N, 4). + t (int): Window size for smoothing. + + Returns: + np.ndarray: Smoothed bounding boxes. """ for i in range(len(boxes)): window_end = min(i + t, len(boxes)) window = boxes[i:window_end] boxes[i] = np.mean(window, axis=0) + return boxes def get_cache_filename(self) -> str: """ - Generates a filename for caching face detection results. + Generates a filename for caching face detection results based on the input file name. + + Returns: + str: Cache file path. """ filename = os.path.basename(self._filepath) return os.path.join(self.cache_dir, f'{filename}.pk') @@ -70,6 +91,9 @@ def get_cache_filename(self) -> str: def get_from_cache(self) -> Union[List, bool]: """ Retrieves face detection results from cache if available. + + Returns: + Union[List, bool]: Cached face detection data or False if cache is unavailable. """ if not self.save_cache: return False @@ -78,12 +102,17 @@ def get_from_cache(self) -> Union[List, bool]: if os.path.isfile(cache_filename): with open(cache_filename, 'rb') as cached_file: return pickle.load(cached_file) - return False def detect_faces_in_frames(self, images: List[np.ndarray]) -> List[Tuple[int, int, int, int]]: """ Detect faces in the given frames using face_alignment. + + Args: + images (List[np.ndarray]): A list of frames to detect faces in. + + Returns: + List[Tuple[int, int, int, int]]: A list of bounding box coordinates per frame.
""" detector = face_alignment.FaceAlignment( landmarks_type=face_alignment.LandmarksType.TWO_D, @@ -93,6 +122,8 @@ def detect_faces_in_frames(self, images: List[np.ndarray]) -> List[Tuple[int, in predictions = [] for i in tqdm(range(0, len(images)), desc="Face Detection"): + # get_landmarks_from_image returns + # a list of [landmarks, landmark score, bounding-box-info] landmarks = detector.get_landmarks_from_image(images[i], return_bboxes=True) predictions.append(get_face_box(landmarks)) @@ -101,27 +132,37 @@ def detect_faces_in_frames(self, images: List[np.ndarray]) -> List[Tuple[int, in def process_face_boxes(self, predictions: List, images: List[np.ndarray]) -> List[List]: """ - Process face bounding boxes, apply smoothing, and crop faces. + Process face bounding boxes, apply smoothing, and crop faces from frames. + + Args: + predictions (List): Bounding box predictions for each frame. + images (List[np.ndarray]): Original frames to crop from. + + Returns: + List[List]: A list containing [cropped_face, (y1, y2, x1, x2)] for each frame. """ pady1, pady2, padx1, padx2 = self.pads img_h, img_w = images[0].shape[:2] - # Convert predictions to bounding boxes + # Convert predictions to bounding boxes with padding results = [] for rect, image in zip(predictions, images): if rect is None: raise ValueError('Face not detected! Ensure all frames contain a face.') - y1 = max(0, rect[1] - pady1) - y2 = min(img_h, rect[3] + pady2) - x1 = max(0, rect[0] - padx1) - x2 = min(img_w, rect[2] + padx2) + + x1, y1, x2, y2 = rect + y1 = max(0, y1 - pady1) + y2 = min(img_h, y2 + pady2) + x1 = max(0, x1 - padx1) + x2 = min(img_w, x2 + padx2) results.append([x1, y1, x2, y2]) - # Smooth bounding boxes if needed + # Smooth bounding boxes if smoothing is enabled boxes = np.array(results) if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, t=5) + # Crop faces from images cropped_results = [] for (x1, y1, x2, y2), image in zip(boxes, images): face_img = image[int(y1): int(y2), int(x1): int(x2)] @@ -131,7 +172,15 @@ def process_face_boxes(self, predictions: List, images: List[np.ndarray]) -> Lis def face_detect(self, images: List[np.ndarray]) -> List[Tuple[np.ndarray, Tuple[int, int, int, int]]]: """ - Performs face detection on a list of images. + Performs face detection on a list of images and returns cropped face regions + along with the bounding box coordinates. + + Args: + images (List[np.ndarray]): A list of frames to detect faces in. + + Returns: + List[Tuple[np.ndarray, Tuple[int, int, int, int]]]: + Each element is [face_img, (y1, y2, x1, x2)]. """ cache = self.get_from_cache() if cache: @@ -153,16 +202,33 @@ def datagen( mels: List[np.ndarray] ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray], List[Tuple[int, int, int, int]]]: """ - Generator that yields batches of images and mel spectrogram chunks. + Generator that yields batches of face images and mel spectrograms. + + Args: + frames (List[np.ndarray]): Video frames from the source face video. + mels (List[np.ndarray]): List of mel spectrogram chunks. + + Yields: + Tuple[np.ndarray, np.ndarray, List[np.ndarray], List[Tuple[int, int, int, int]]]: + - img_batch_np: Array of masked/unmasked face images ready for the model. + - mel_batch_np: Array of mel spectrogram data. + - frame_batch: Original frames corresponding to each face. + - coords_batch: Bounding box coordinates for each face. 
""" - face_det_results = self._get_face_detections(frames) + # Detect faces/crop them once + face_det_results = self.face_detect(frames) + batch_size = self.wav2lip_batch_size img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + total_frames = len(frames) for i, m in enumerate(mels): - idx = 0 if self.static else (i % len(frames)) - frame_to_save = frames[idx] - face, coords = face_det_results[idx] + # Loop over mels in a cyclical manner w.r.t frames + index = i % total_frames + frame_to_save = frames[index] + face, coords = face_det_results[index] + + # Resize the cropped face to the expected input size face_resized = cv2.resize(face, (self.img_size, self.img_size)) img_batch.append(face_resized) @@ -177,22 +243,10 @@ def datagen( frame_batch.clear() coords_batch.clear() - # Yield remaining batch if any + # Yield any leftover samples if len(img_batch) > 0: yield self._prepare_batch(img_batch, mel_batch, frame_batch, coords_batch) - def _get_face_detections(self, frames: List[np.ndarray]) -> List[List]: - """ - Retrieve or compute face detections. - """ - if self.box[0] == -1: - # No manual bounding box provided, detect faces - return self.face_detect(frames if not self.static else [frames[0]]) - else: - # Use provided bounding box - y1, y2, x1, x2 = self.box - return [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames] - def _prepare_batch( self, img_batch: List[np.ndarray], @@ -201,16 +255,29 @@ def _prepare_batch( coords_batch: List[Tuple[int, int, int, int]] ) -> Tuple[np.ndarray, np.ndarray, List[np.ndarray], List[Tuple[int, int, int, int]]]: """ - Prepares a batch of images and mel spectrograms for inference. + Prepares a single batch of images and mel spectrograms for inference. + + Args: + img_batch (List[np.ndarray]): List of face images. + mel_batch (List[np.ndarray]): List of mel spectrograms. + frame_batch (List[np.ndarray]): Original frames corresponding to the faces. + coords_batch (List[Tuple[int, int, int, int]]): Coordinates for face placement. + + Returns: + Tuple of (img_batch_np, mel_batch_np, frame_batch, coords_batch). """ img_batch_np = np.asarray(img_batch, dtype=np.uint8) mel_batch_np = np.asarray(mel_batch, dtype=np.float32) - # Mask the lower half of the image + # Mask the lower half of the face in the input so the model learns to fill the mouth region half = self.img_size // 2 img_masked = img_batch_np.copy() img_masked[:, half:] = 0 + + # Concatenate original and masked images across the channel dimension img_batch_np = np.concatenate((img_masked, img_batch_np), axis=3) / 255.0 + + # Add channel dimension for mel spectrogram mel_batch_np = mel_batch_np[..., np.newaxis] return img_batch_np, mel_batch_np, frame_batch, coords_batch @@ -219,6 +286,12 @@ def _prepare_batch( def create_temp_file(ext: str) -> str: """ Creates a temporary file with a specific extension. + + Args: + ext (str): Desired file extension (e.g. 'avi', 'wav'). + + Returns: + str: Full path to the created temporary file (empty). """ temp_fd, filename = tempfile.mkstemp() os.close(temp_fd) @@ -226,35 +299,70 @@ def create_temp_file(ext: str) -> str: def sync(self, face: str, audio_file: str, outfile: str) -> str: """ - Performs lip-syncing on the input video/image using the provided audio. + Performs lip-syncing on the input video/image with the provided audio. + + Args: + face (str): Path to an image or video file containing the face to be lip-synced. + audio_file (str): Path to an audio file. + outfile (str): Path to save the output video file. 
+ + Returns: + str: The path to the final output video file with audio and lip-synced faces. + + Raises: + ValueError: If the input face file is invalid or cannot be processed. """ self._filepath = face + + # 1. Load the video/image frames and get FPS full_frames, fps = self._load_input_face(face) + + # 2. Prepare/extract raw audio from the input audio file if needed audio_file = self._prepare_audio(audio_file) + + # 3. Generate the mel spectrogram mel = self._generate_mel_spectrogram(audio_file) + + # 4. Split mel spectrogram into chunks for each frame mel_chunks = self._split_mel_chunks(mel, fps) + + # 5. Load the Wav2Lip model model = self._load_model_for_inference() + # 6. Setup VideoWriter temp_result_avi = self.create_temp_file('avi') out = self._prepare_video_writer(temp_result_avi, full_frames[0].shape[:2], fps) + # 7. Perform inference self._perform_inference(model, full_frames, mel_chunks, out) out.release() + # 8. Merge the new video with the original audio self._merge_audio_video(audio_file, temp_result_avi, outfile) + return outfile def _load_input_face(self, face: str) -> Tuple[List[np.ndarray], float]: """ Loads the input face (video or image) and returns frames and fps. + + Args: + face (str): Path to a video or image file. + + Returns: + (List[np.ndarray], float): A list of frames and the fps. + + Raises: + ValueError: If the face path is invalid or the file is not found. """ if not os.path.isfile(face): raise ValueError('face argument must be a valid file path.') - if face.split('.')[-1].lower() in ['jpg', 'png', 'jpeg']: - self.static = True + ext = face.split('.')[-1].lower() + if ext in ['jpg', 'png', 'jpeg']: + # Single image; create one-frame "video" full_frames = [cv2.imread(face)] - fps = self.fps + fps = self.fps # Use default or user-specified fps for single image else: full_frames, fps = read_frames(face) @@ -262,7 +370,13 @@ def _load_input_face(self, face: str) -> Tuple[List[np.ndarray], float]: def _prepare_audio(self, audio_file: str) -> str: """ - Prepares (extracts) raw audio if not in .wav format. + Prepares (extracts) a .wav file from any given audio format using ffmpeg if needed. + + Args: + audio_file (str): Path to the audio file. + + Returns: + str: Path to the .wav file (might be newly created). """ if not audio_file.endswith('.wav'): wav_filename = self.create_temp_file('wav') @@ -278,6 +392,15 @@ def _prepare_audio(self, audio_file: str) -> str: def _generate_mel_spectrogram(audio_file: str) -> np.ndarray: """ Generates the mel spectrogram from the given audio file. + + Args: + audio_file (str): Path to a .wav file. + + Returns: + np.ndarray: The mel spectrogram. + + Raises: + ValueError: If the mel spectrogram contains NaN values. """ wav = audio.load_wav(audio_file, 16000) mel = audio.melspectrogram(wav) @@ -288,6 +411,9 @@ def _generate_mel_spectrogram(audio_file: str) -> np.ndarray: def _load_model_for_inference(self) -> torch.nn.Module: """ Loads the lip sync model for inference. + + Returns: + torch.nn.Module: The Wav2Lip model loaded in eval mode. """ model = load_model(self.model, self.device, self.checkpoint_path) return model @@ -295,7 +421,15 @@ def _load_model_for_inference(self) -> torch.nn.Module: @staticmethod def _prepare_video_writer(filename: str, frame_shape: Tuple[int, int], fps: float) -> cv2.VideoWriter: """ - Prepares the VideoWriter for output. + Prepares a cv2.VideoWriter to save frames. + + Args: + filename (str): Name of the output file. + frame_shape (Tuple[int, int]): (height, width) of the frames. 
+ fps (float): Frames per second. + + Returns: + cv2.VideoWriter: A VideoWriter object. """ frame_h, frame_w = frame_shape return cv2.VideoWriter( @@ -313,19 +447,29 @@ def _perform_inference( out: cv2.VideoWriter, ): """ - Runs the inference loop: generates data, passes through model, and writes results. + Runs the inference loop: feeds batches of frames and mel spectrograms + through the model, and writes the predicted frames to the output. + + Args: + model (torch.nn.Module): The Wav2Lip model. + full_frames (List[np.ndarray]): List of frames from the input video. + mel_chunks (List[np.ndarray]): List of mel spectrogram chunks. + out (cv2.VideoWriter): Video writer to save generated frames. """ data_generator = self.datagen(full_frames.copy(), mel_chunks) total_batches = int(np.ceil(len(mel_chunks) / self.wav2lip_batch_size)) steps = tqdm(data_generator, total=total_batches, desc="Lip-sync Inference") for (img_batch_np, mel_batch_np, frames, coords) in steps: + # Prepare input tensors img_batch_t = torch.FloatTensor(np.transpose(img_batch_np, (0, 3, 1, 2))).to(self.device) mel_batch_t = torch.FloatTensor(np.transpose(mel_batch_np, (0, 3, 1, 2))).to(self.device) with torch.no_grad(): + # Model forward pass pred = model(mel_batch_t, img_batch_t) + # Write predicted frames self._write_predicted_frames(pred, frames, coords, out) @staticmethod @@ -337,8 +481,16 @@ def _write_predicted_frames( ): """ Writes the predicted frames (lipsynced faces) into the output video. + + Args: + pred (torch.Tensor): Model output of shape (N, 3, H, W). + frames (List[np.ndarray]): Original frames. + coords (List[Tuple[int, int, int, int]]): Face coordinates. + out (cv2.VideoWriter): Video writer to append the frames. """ + # Convert predicted tensor to NumPy pred_np = (pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0).astype(np.uint8) + for p, f, c in zip(pred_np, frames, coords): y1, y2, x1, x2 = c p_resized = cv2.resize(p, (x2 - x1, y2 - y1)) @@ -347,7 +499,16 @@ def _write_predicted_frames( def _merge_audio_video(self, audio_file: str, temp_video: str, outfile: str): """ - Merges the generated video with the input audio. + Merges the generated video (temp_video) with the input audio (audio_file) + into the final output (outfile) using ffmpeg. + + Args: + audio_file (str): Path to the .wav file. + temp_video (str): Path to the intermediate .avi video file. + outfile (str): Path for the final output video file. + + Raises: + subprocess.CalledProcessError: If ffmpeg merging fails. """ command = ( f'ffmpeg -y -i "{audio_file}" -i "{temp_video}" -strict -2 ' @@ -357,20 +518,30 @@ def _merge_audio_video(self, audio_file: str, temp_video: str, outfile: str): def _split_mel_chunks(self, mel: np.ndarray, fps: float) -> List[np.ndarray]: """ - Splits the mel spectrogram into fixed-size chunks. - """ + Splits the mel spectrogram into fixed-size chunks according to the fps. + Args: + mel (np.ndarray): Mel spectrogram of shape (n_mels, time). + fps (float): Frames per second from the input video. + + Returns: + List[np.ndarray]: A list of mel spectrogram chunks, each chunk + typically of shape (n_mels, mel_step_size). 
+ """ mel_chunks = [] mel_length = mel.shape[1] - mel_idx_multiplier = 80.0 / fps + mel_idx_multiplier = 80.0 / fps # Hard-coded for 80 mel frames per video second i = 0 while True: start_idx = int(i * mel_idx_multiplier) end_idx = start_idx + self.mel_step_size + if end_idx > mel_length: + # If we exceed the mel length, append the last chunk (padded if needed) mel_chunks.append(mel[:, -self.mel_step_size:]) break + mel_chunks.append(mel[:, start_idx:end_idx]) i += 1 diff --git a/lipsync/models/__init__.py b/lipsync/models/__init__.py index 2433db4..ab5d055 100644 --- a/lipsync/models/__init__.py +++ b/lipsync/models/__init__.py @@ -1,17 +1,21 @@ -from lipsync.models.wav2lip import Wav2Lip +""" +Model registration and loading utilities for lipsync. +""" + from typing import Dict, Any import torch +from lipsync.models.wav2lip import Wav2Lip -# Registry for available models +# A registry of available models MODEL_REGISTRY: Dict[str, Any] = { 'wav2lip': Wav2Lip, } -def _load(checkpoint_path, device): +def _load_checkpoint(checkpoint_path: str, device: str) -> Any: """ - Loads a model checkpoint from the specified path. + Loads a model checkpoint from the specified path onto the given device. Args: checkpoint_path (str): The path to the model checkpoint file. @@ -22,26 +26,29 @@ def _load(checkpoint_path, device): Raises: AssertionError: If the device is not 'cpu' or 'cuda'. + FileNotFoundError: If checkpoint file not found. """ + # Ensure valid device assert device in ['cpu', 'cuda'], "Device must be 'cpu' or 'cuda'" - # Load checkpoint on the specified device - if device == 'cuda': - return torch.load(checkpoint_path, weights_only=True) - - return torch.load( - checkpoint_path, - map_location=lambda storage, _: storage, - weights_only=True - ) + # Try loading checkpoint on the specified device + try: + checkpoint = torch.load( + checkpoint_path, + map_location=torch.device(device), + weights_only=True + ) + return checkpoint + except FileNotFoundError as e: + raise FileNotFoundError(f"Checkpoint file '{checkpoint_path}' not found.") from e -def load_model(model_name: str, device: str, checkpoint: str): +def load_model(model_name: str, device: str, checkpoint: str) -> torch.nn.Module: """ Loads and initializes a model with the given checkpoint and device. Args: - model_name (str): The name of the model to load. + model_name (str): The name of the model to load, e.g. 'wav2lip'. device (str): The device to load the model on, either 'cpu' or 'cuda'. checkpoint (str): The path to the model checkpoint. @@ -51,16 +58,22 @@ def load_model(model_name: str, device: str, checkpoint: str): Raises: KeyError: If the model name is not found in the model registry. 
""" - # Retrieve the model class from the registry + # Retrieve the model class from the registry, converting name to lower in case of mismatch cls = MODEL_REGISTRY[model_name.lower()] # Initialize the model model = cls() - # Load the checkpoint and set model state - checkpoint = _load(checkpoint, device) - model.load_state_dict(checkpoint) + # Load the checkpoint + checkpoint_dict = _load_checkpoint(checkpoint, device) + + # Load state dict into model + model.load_state_dict(checkpoint_dict) + + # Move model to the specified device model = model.to(device) + + # Put model into evaluation mode return model.eval() diff --git a/lipsync/models/conv.py b/lipsync/models/conv.py index 46c3274..1fd8a8b 100644 --- a/lipsync/models/conv.py +++ b/lipsync/models/conv.py @@ -1,3 +1,7 @@ +""" +Custom convolutional modules with optional residual connections for lipsync models. +""" + import torch from torch import nn @@ -30,9 +34,19 @@ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False): self.residual = residual def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the Conv2d module. + + Args: + x (torch.Tensor): Input tensor of shape (N, cin, H, W). + + Returns: + torch.Tensor: Output tensor of shape (N, cout, H', W'). + """ out = self.conv_block(x) if self.residual: out = out + x + return self.act(out) @@ -62,5 +76,14 @@ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0): self.act = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the Conv2dTranspose module. + + Args: + x (torch.Tensor): Input tensor of shape (N, cin, H, W). + + Returns: + torch.Tensor: Output tensor of shape (N, cout, H', W'). + """ out = self.conv_block(x) return self.act(out) diff --git a/lipsync/models/wav2lip.py b/lipsync/models/wav2lip.py index 99be1b7..65dadac 100644 --- a/lipsync/models/wav2lip.py +++ b/lipsync/models/wav2lip.py @@ -1,17 +1,30 @@ +""" +Definition of the Wav2Lip model, which generates lip-synced video frames from audio +and face frames. It includes both the audio encoder and the face encoder/decoder modules. +""" + import torch from torch import nn + from lipsync.models.conv import Conv2dTranspose, Conv2d class Wav2Lip(nn.Module): - """Wav2Lip model for generating lip-synced videos. + """ + Wav2Lip model for generating lip-synced videos. This model takes as input sequences of audio and corresponding face frames and produces synthesized video frames where the lip movements are synchronized with the given audio. """ def __init__(self): - """Initializes the Wav2Lip model modules.""" + """ + Initializes the Wav2Lip model modules: + - Face encoder blocks + - Audio encoder + - Face decoder blocks + - Output block + """ super(Wav2Lip, self).__init__() # Face encoder blocks @@ -107,6 +120,7 @@ def __init__(self): ), ]) + # Output block self.output_block = nn.Sequential( Conv2d(80, 32, kernel_size=3, stride=1, padding=1), nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), @@ -118,7 +132,8 @@ def forward( audio_sequences: torch.Tensor, face_sequences: torch.Tensor ) -> torch.Tensor: - """Runs the forward pass of the Wav2Lip model. + """ + Runs the forward pass of the Wav2Lip model. 
Args: audio_sequences (torch.Tensor): The input audio sequences of shape (B, T, 1, 80, 16) @@ -135,8 +150,15 @@ def forward( # Reshape sequences if input is batched over time if input_dim_size > 4: - BT = audio_sequences.size(0) * audio_sequences.size(1) - audio_sequences = audio_sequences.view(-1, 1, audio_sequences.size(3), audio_sequences.size(4)) + BT = B * audio_sequences.size(1) + audio_sequences = audio_sequences.view( + BT, + audio_sequences.size(2), + audio_sequences.size(3), + audio_sequences.size(4) + ) + + # Face shape: (B, C, T, H, W) => (B, T, C, H, W) => (BT, C, H, W) _, C, T, H, W = face_sequences.size() face_sequences = face_sequences.permute(0, 2, 1, 3, 4).contiguous().view(-1, C, H, W) @@ -149,9 +171,7 @@ def forward( for block in self.face_encoder_blocks: x = block(x) feats.append(x) - - # Reverse feats for easier decoder access - feats.reverse() + feats.reverse() # Reverse the order for decoder # Decode with audio embedding x = audio_embedding @@ -164,8 +184,6 @@ def forward( # Reshape back to original format if needed if input_dim_size > 4: - # x: (B*T, C, H, W) - # Determine T from audio_embedding's batch: B*T total_frames = audio_embedding.size(0) T = total_frames // B x = x.view(B, T, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
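Reviewer note (not part of the patch): with static, resize_factor, crop, box and rotate removed, the public surface reduces to constructing LipSync and calling sync(). A minimal end-to-end sketch follows, assuming keyword arguments override the class-level defaults listed earlier in the diff; the checkpoint path and media file names are placeholders.

    from lipsync import LipSync

    lip = LipSync(
        model='wav2lip',
        checkpoint_path='weights/wav2lip.pth',  # placeholder checkpoint location
        device='cuda',       # falls back to 'cpu' automatically if CUDA is unavailable
        save_cache=True,     # face-detection results cached under tempfile.gettempdir()
        nosmooth=False,      # keep temporal smoothing of face boxes
    )

    result = lip.sync(
        face='input.mp4',          # video, or a single jpg/png/jpeg image
        audio_file='speech.wav',   # non-wav inputs are converted via ffmpeg
        outfile='result.mp4',
    )
    print(result)  # path to the merged, lip-synced output video

Timing note: _split_mel_chunks pairs video frame i with the 16-frame mel window starting at int(i * 80 / fps), i.e. roughly 0.2 s of audio per generated frame, so audio/video alignment holds for any source fps.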