diff --git a/diffusion/aniportrait/ani_portrait.py b/diffusion/aniportrait/ani_portrait.py new file mode 100644 index 000000000..388fcd61e --- /dev/null +++ b/diffusion/aniportrait/ani_portrait.py @@ -0,0 +1,274 @@ +from math import ceil +import random +import numpy as np +import cv2 +import sys +from typing import Any +from tqdm import tqdm + +import onnxruntime +from ani_portrait_utils import get_model_file_names +from lmk_extractor import LMKExtractor +from facemesh_v2_utils import matrix_to_euler_and_translation, smooth_pose_seq, crop_face, euler_and_translation_to_matrix +from scipy.interpolate import interp1d + +sys.path.append("../../util") +from detector_utils import load_image +from arg_utils import get_base_parser, update_parser # noqa: E402 +from model_utils import check_and_download_models # noqa: E402 +from scheduling_ddim import DDIMScheduler +import ailia +from audio_processor import prepare_audio_feature + +REMOTE_PATH = "https://storage.googleapis.com/ailia-models/AniPortrait/" +FACEMESH_REMOTE_PATH = "https://storage.googleapis.com/ailia-models/facemesh_v2" +MODES = ["Audio2Video", "Video2Video"] +INPUT_IMAGE = "lyl.png" +OUTPUT_IMAGE = "" +REF_IMAGE_SAMPLE = "lyl.png" +AUDIO_SAMPLE = "lyl.wav" +HEAD_POSE_SAMPLE = "pose_ref_video.mp4" + +parser = get_base_parser("gpt2 text generation", INPUT_IMAGE, OUTPUT_IMAGE) +parser.add_argument( + "--onnx", + action="store_true", + help="By default, the ailia SDK is used, but with this option, you can switch to using ONNX Runtime", +) +parser.add_argument("-r", "--reference_image", type=str, default=REF_IMAGE_SAMPLE) +parser.add_argument("-hp", "--head_pose_reference_video", type=str, default=None) +parser.add_argument("-a", "--audio", default=AUDIO_SAMPLE) +parser.add_argument("-v", "--source_video") +parser.add_argument("-s", "--steps", type=int, default=25) +parser.add_argument("-S", "--seed", type=int, default=42) +parser.add_argument("-vs", "--video_size", type=int, default=512) 
def get_head_pose(lmk_extractor: LMKExtractor, video_path, target_fps=30):
    """Extract a smoothed head-pose sequence from a reference video.

    Every decoded frame is passed to ``lmk_extractor``; the first element of
    its result is taken as that frame's transformation matrix.  Each matrix
    is expressed relative to the first frame, split into Euler angles and a
    translation vector, resampled to ``target_fps`` and smoothed with a
    5-sample moving average.

    Args:
        lmk_extractor: Callable applied to each BGR frame.  It must return a
            2-tuple whose first element is the frame's transformation matrix
            (assumed to be 4x4 -- TODO confirm against ``LMKExtractor``).
        video_path: Path of the head-pose reference video.
        target_fps: Frame rate of the returned sequence (default 30, the
            value that used to be hard-coded).

    Returns:
        ``np.ndarray`` with 6 columns: columns 0-2 are the Euler angles and
        columns 3-5 the translation returned by
        ``matrix_to_euler_and_translation``.

    Raises:
        ValueError: If the video cannot be opened, yields no frame, reports
            a non-positive frame rate, or the extractor returns ``None`` for
            a frame.
    """
    cap = cv2.VideoCapture(video_path)
    trans_mats = []
    try:
        if not cap.isOpened():
            raise ValueError(f"cannot open head pose reference video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ret, img = cap.read()
            if not ret:
                break
            result = lmk_extractor(img)
            if result is None:
                # crop_face() guards against a None result as well, so the
                # extractor is expected to be able to return it.
                raise ValueError(
                    f"no landmarks found in frame {len(trans_mats)} of {video_path}"
                )
            frame_trans_mat, _ = result
            trans_mats.append(frame_trans_mat)
    finally:
        # Release the capture even when the extractor raises.
        cap.release()

    if not trans_mats:
        raise ValueError(f"no frames could be read from {video_path}")
    if not fps or fps <= 0:
        raise ValueError(f"invalid frame rate {fps!r} reported for {video_path}")

    trans_mats = np.array(trans_mats)
    # Use the number of frames actually decoded.  CAP_PROP_FRAME_COUNT comes
    # from container metadata and may differ, which would make the time axis
    # and the pose array disagree in length.
    num_frames = trans_mats.shape[0]

    # Compute the pose of every frame relative to the first one.
    inv_first = np.linalg.inv(trans_mats[0])
    pose_arr = np.zeros((num_frames, 6))
    for i, frame_mat in enumerate(trans_mats):
        euler_angles, translation_vector = matrix_to_euler_and_translation(
            inv_first @ frame_mat
        )
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector

    if num_frames < 2:
        # A single pose cannot be interpolated; just return it (smoothed).
        return smooth_pose_seq(pose_arr, window_size=5)

    # Resample from the source frame rate to ``target_fps``.
    duration = num_frames / fps
    old_time = np.linspace(0, duration, num_frames)
    new_time = np.linspace(0, duration, int(num_frames * target_fps / fps))
    pose_arr_interp = interp1d(old_time, pose_arr, axis=0)(new_time)

    return smooth_pose_seq(pose_arr_interp, window_size=5)
def create_perspective_matrix(aspect_ratio):
    """Build a flattened 4x4 perspective projection matrix.

    The matrix uses a 63 degree vertical field of view, a near plane at 1 and
    a far plane at 10000.  The Y scale is negated so that the image origin is
    in the top-left corner.

    Args:
        aspect_ratio: Width divided by height of the target image.

    Returns:
        ``np.ndarray`` of 16 ``float32`` values; callers reshape it to 4x4.
    """
    k_degrees_to_radians = np.pi / 180.0
    near = 1
    far = 10_000
    perspective_matrix = np.zeros(16, dtype=np.float32)

    f = 1.0 / np.tan(k_degrees_to_radians * 63 / 2.0)

    denom = 1.0 / (near - far)
    perspective_matrix[0] = f / aspect_ratio
    perspective_matrix[5] = f
    perspective_matrix[10] = (near + far) * denom
    perspective_matrix[11] = -1.0
    perspective_matrix[14] = 1.0 * far * near * denom

    # Flip the Y axis: image coordinates grow downwards.
    perspective_matrix[5] *= -1.0
    return perspective_matrix


def project_points(points_3d, trans_mat, pose_vectors, image_shape):
    """Project per-frame 3D points to 2D pixel coordinates.

    Args:
        points_3d: Array of shape ``(L, N, 3)``: ``N`` points for each of
            ``L`` frames.
        trans_mat: 4x4 base transformation applied to every frame.
        pose_vectors: Per-frame pose vectors.  Entries 0-2 are passed as the
            Euler angles and entries 3-5 as the translation of
            ``euler_and_translation_to_matrix`` (the layout produced by
            ``get_head_pose``).
        image_shape: ``(height, width)`` of the target image.

    Returns:
        ``np.ndarray`` of shape ``(L, N, 2)`` holding ``(x, y)`` pixel
        coordinates.
    """
    P = create_perspective_matrix(image_shape[1] / image_shape[0]).reshape(4, 4).T
    num_frames, num_points, _ = points_3d.shape
    projected_points = np.zeros((num_frames, num_points, 2))

    for i in range(num_frames):
        frame_points = points_3d[i]
        ones = np.ones((frame_points.shape[0], 1))
        homogeneous = np.hstack((frame_points, ones))
        # The helper takes both the rotation and the translation part of the
        # pose; the previous call used a misspelled name and dropped the
        # translation argument.
        pose_mat = euler_and_translation_to_matrix(
            pose_vectors[i][:3], pose_vectors[i][3:6]
        )
        transformed = homogeneous @ (trans_mat @ pose_mat).T @ P
        # Perspective divide, then map from [-1, 1] to pixel coordinates.
        frame_2d = transformed[:, :2] / transformed[:, 3, np.newaxis]
        frame_2d[:, 0] = (frame_2d[:, 0] + 1) * 0.5 * image_shape[1]
        frame_2d[:, 1] = (frame_2d[:, 1] + 1) * 0.5 * image_shape[0]
        projected_points[i] = frame_2d

    return projected_points
ref_image.shape[0]), lmks, normed=True) + + sample = prepare_audio_feature(args.audio, wav2vec_model_path="./pretrained_model/wav2vec2-base-960h") + sample["audio_feature"] = sample["audio_feature"].astype(np.float32) + sample["audio_feature"] = np.expand_dims(sample["audio_feature"], axis=0) + + # inference + if args.onnx: + pred = nets["a2m_model"].run( + ["output"], + {"input_value": sample["audio_feature"], "seq_len": [sample["seq_len"]]} + ) + print(f"{pred=}") + else: + pred = nets["a2m_model"].predict(sample["audio_feature"]) + + pred = pred.squeeze() + pred = pred.reshape(pred.shape[0], -1, 3) + pred = pred + lmks3d + + if args.head_pose_reference_video is not None: + pose_seq = get_head_pose(lmk_extractor, args.head_pose_reference_video) + mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0) + pose_seq = np.tile(mirrored_pose_seq, (sample["seq_len"] // len(mirrored_pose_seq) + 1, 1))[:sample["seq_len"]] + else: + chunk_duration = 5 + sr = 16_000 + fps = 30 + chunk_size = sr * chunk_duration + + audio_chunks = [] + + for i in range(ceil(sample["audio_feature"].shape[1] / chunk_size)): + audio_chunks.append( + sample["audio_feature"][0, i * chunk_size:(i + 1) * chunk_size].reshape(1, -1) + ) + + seq_len_list = [chunk_duration * fps] * (len(audio_chunks) - 1) + [sample["seq_len"] % (chunk_duration * fps)] + + audio_chunks[-2] = np.concatenate((audio_chunks[-2], audio_chunks[-1]), axis=1) + seq_len_list[-2] = seq_len_list[-2] + seq_len_list[-1] + del audio_chunks[-1] + del seq_len_list[-1] + + pose_seq = [] + + for audio, seq_len in zip(audio_chunks, seq_len_list): + print(f"{audio.shape=}") + input(">>>") + if args.onnx: + pose_seq_chunk = nets["a2p_model"].run( + ["output"], + {"input_value": audio, "seq_len": [seq_len], "id_seed": [random.randint(0, 99)]} + ) + print(f"{pose_seq_chunk=}") + else: + pose_seq_chunk = nets["a2p_model"].predict(audio) + + pose_seq_chunk = pose_seq_chunk.squeeze() + pose_seq_chunk[:, :3] *= 0.5 + 
def get_model_file_names() -> dict[str, dict[str, dict[str, str]]]:
    """Return the weight and model file names of every network, by group.

    The result maps ``group -> model name -> {"weight": ..., "model": ...}``,
    where the weight file is ``<name>.onnx`` and the model file is
    ``<name>.onnx.prototxt``.
    """
    model_groups = {
        "aniportrait": (
            "a2m_model",
            "a2p_model",
            "denoising_unet",
            "encoder",
            "image_encoder",
            "pose_guider",
            "reference_unet",
        ),
        "facemesh_v2": (
            "face_landmarks_detector",
            "face_detector",
        ),
        "audio": (
            "wav2vec2feature_extractor",
        ),
    }
    # Derive both file names from the model name so they cannot drift apart.
    return {
        group: {
            name: {"weight": f"{name}.onnx", "model": f"{name}.onnx.prototxt"}
            for name in names
        }
        for group, names in model_groups.items()
    }


def load_canonical_model(path="canonical_model.obj") -> np.ndarray:
    """Load the vertex positions of a Wavefront OBJ file.

    Only ``v`` records are read, and only their first three coordinates are
    kept, so vertex lines that carry extra trailing values no longer raise.

    Args:
        path: OBJ file to read.  Defaults to ``canonical_model.obj`` in the
            current working directory, the previously hard-coded location.

    Returns:
        ``np.ndarray`` of shape ``(num_vertices, 3)``.

    Raises:
        ValueError: If a ``v`` record has fewer than three coordinates or a
            coordinate is not a number.
    """
    vertices = []
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            parts = line.split()
            if not parts or parts[0] != "v":
                continue
            if len(parts) < 4:
                raise ValueError(f"malformed vertex record: {line.rstrip()!r}")
            vertices.append([float(value) for value in parts[1:4]])

    # reshape keeps the (N, 3) layout even when the file has no vertices.
    return np.array(vertices, dtype=float).reshape(-1, 3)
prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None): + data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path) + + input_value = data_preprocessor.extract_feature(wav_file) + seq_len = math.ceil(len(input_value)/sampling_rate*fps) + return { + "audio_feature": input_value, + "seq_len": seq_len + } + + diff --git a/diffusion/aniportrait/configuration_utils.py b/diffusion/aniportrait/configuration_utils.py new file mode 100644 index 000000000..51462b9ee --- /dev/null +++ b/diffusion/aniportrait/configuration_utils.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" ConfigMixin base class and utilities.""" + +import functools +import inspect +from collections import OrderedDict +from typing import Any, Dict + + +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + +class ConfigMixin: + def register_to_config(self, **kwargs): + if not hasattr(self, "_internal_dict"): + internal_dict = kwargs + else: + internal_dict = {**self._internal_dict, **kwargs} + + self._internal_dict = FrozenDict(internal_dict) + + @classmethod + def from_config(cls, config_dict, **kwargs): + init_dict, _, hidden_dict = cls.extract_init_dict(config_dict, **kwargs) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + return model + + @staticmethod + def _get_init_keys(cls): + return set(dict(inspect.signature(cls.__init__).parameters).keys()) + + @classmethod + def extract_init_dict(cls, config_dict, **kwargs): + # 0. Copy origin config dict + original_dict = dict(config_dict.items()) + + # 1. Retrieve expected config attributes from __init__ signature + expected_keys = cls._get_init_keys(cls) + expected_keys.remove("self") + + # remove private attributes + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + # 3. 
Create keyword arguments that will be passed to __init__ from expected keyword arguments + init_dict = {} + for key in expected_keys: + # if config param is passed to kwarg and is present in config dict + # it should overwrite existing config dict key + if key in kwargs and key in config_dict: + config_dict[key] = kwargs.pop(key) + + if key in kwargs: + # overwrite key + init_dict[key] = kwargs.pop(key) + elif key in config_dict: + # use value from config dict + init_dict[key] = config_dict.pop(key) + + # 6. Define unused keyword arguments + unused_kwargs = {**config_dict, **kwargs} + + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict + + @property + def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ + return self._internal_dict + + +def register_to_config(init): + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable + + Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! + """ + + @functools.wraps(init) + def inner_init(self, *args, **kwargs): + # Ignore private kwargs in the init. 
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} + + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in new_kwargs + } + ) + + new_kwargs = {**config_init_kwargs, **new_kwargs} + getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) + + return inner_init diff --git a/diffusion/aniportrait/detection_utils.py b/diffusion/aniportrait/detection_utils.py new file mode 100644 index 000000000..6b579b920 --- /dev/null +++ b/diffusion/aniportrait/detection_utils.py @@ -0,0 +1,189 @@ +from collections import namedtuple + +import numpy as np + +from math_utils import sigmoid +from nms_utils import packed_nms + +IMAGE_SIZE = 128 + + +def calc_scale(min_scale, max_scale, stride_index, num_strides): + if num_strides == 1: + return (min_scale + max_scale) * 0.5 + else: + return min_scale + (max_scale - min_scale) * 1.0 * stride_index / (num_strides - 1.0) + + +def get_anchor(): + num_layers = 4 + strides = [8, 16, 16, 16] + opt_aspect_ratios = [1.0] + min_scale = 0.1484375 + max_scale = 0.75 + input_size_height = IMAGE_SIZE + input_size_width = IMAGE_SIZE + anchor_offset_x = 0.5 + anchor_offset_y = 0.5 + interpolated_scale_aspect_ratio = 1.0 + + Anchor = namedtuple('Anchor', ['x_center', 'y_center', 'w', 'h']) + anchors = [] + + layer_id = 0 + while layer_id < num_layers: + anchor_height = [] + anchor_width = [] + aspect_ratios = [] + scales = [] + + last_same_stride_layer = layer_id + while last_same_stride_layer < len(strides) \ + and strides[last_same_stride_layer] == 
def decode_boxes(raw_boxes, anchors, image_size=None):
    """Decode raw detector regressions into normalized boxes and keypoints.

    Each output row is
    ``(xmin, ymin, xmax, ymax, key1_x, key1_y, ..., key6_x, key6_y)``.
    Centers and keypoints are divided by the input size, scaled by the
    anchor size and offset by the anchor center; widths and heights are
    divided by the input size and scaled by the anchor size.

    Args:
        raw_boxes: 2-D array with at least ``len(anchors)`` rows and 16
            columns ``(x_center, y_center, w, h, 6 x (kx, ky))``.
        anchors: Sequence of ``(x_center, y_center, w, h)`` anchors, one per
            box (``get_anchor`` returns such a list).
        image_size: Detector input size used as the coordinate scale.
            ``None`` (default) uses the module constant ``IMAGE_SIZE``.

    Returns:
        ``np.ndarray`` of shape ``(len(anchors), 16)`` and dtype float64.

    Raises:
        ValueError: If ``raw_boxes`` has fewer rows than ``anchors`` or fewer
            than 16 columns.
    """
    scale = float(IMAGE_SIZE if image_size is None else image_size)
    num_coords = 16

    anchor_arr = np.asarray(anchors, dtype=np.float64).reshape(-1, 4)
    # One box per anchor instead of a hard-coded count.
    num_boxes = anchor_arr.shape[0]

    raw = np.asarray(raw_boxes, dtype=np.float64)
    if raw.ndim != 2 or raw.shape[0] < num_boxes or raw.shape[1] < num_coords:
        raise ValueError(
            f"raw_boxes must have shape (>= {num_boxes}, >= {num_coords}), "
            f"got {raw.shape}"
        )
    raw = raw[:num_boxes, :num_coords]

    anchor_x = anchor_arr[:, 0]
    anchor_y = anchor_arr[:, 1]
    anchor_w = anchor_arr[:, 2]
    anchor_h = anchor_arr[:, 3]

    x_center = raw[:, 0] / scale * anchor_w + anchor_x
    y_center = raw[:, 1] / scale * anchor_h + anchor_y
    w = raw[:, 2] / scale * anchor_w
    h = raw[:, 3] / scale * anchor_h

    boxes = np.zeros((num_boxes, num_coords))
    boxes[:, 0] = x_center - w / 2.0  # xmin
    boxes[:, 1] = y_center - h / 2.0  # ymin
    boxes[:, 2] = x_center + w / 2.0  # xmax
    boxes[:, 3] = y_center + h / 2.0  # ymax
    # Keypoints: even columns from 4 on are x, odd columns from 5 on are y.
    boxes[:, 4::2] = raw[:, 4::2] / scale * anchor_w[:, None] + anchor_x[:, None]
    boxes[:, 5::2] = raw[:, 5::2] / scale * anchor_h[:, None] + anchor_y[:, None]

    return boxes
+ boxes, scores = weighted_nms(boxes, scores) + + def project_fn(x, y): + return ( + x * project_mat[0][0] + y * project_mat[0][1] + project_mat[0][3], + x * project_mat[1][0] + y * project_mat[1][1] + project_mat[1][3] + ) + + # DetectionProjectionCalculator + for box in boxes: + xmin, ymin, xmax, ymax = box[:4] + ps = [ + project_fn(*p) for p in [ + [xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax] + ] + ] + + left_top = min(p[0] for p in ps), min(p[1] for p in ps) + right_bottom = max(p[0] for p in ps), max(p[1] for p in ps) + box[[0, 1]] = left_top + box[[2, 3]] = right_bottom + for i in range(6): + kx, ky = 4 + i * 2, 4 + i * 2 + 1 + box[[kx, ky]] = project_fn(box[kx], box[ky]) + + return boxes, scores diff --git a/diffusion/aniportrait/facemesh_v2_utils.py b/diffusion/aniportrait/facemesh_v2_utils.py new file mode 100644 index 000000000..1ec2202f7 --- /dev/null +++ b/diffusion/aniportrait/facemesh_v2_utils.py @@ -0,0 +1,239 @@ +from scipy.spatial.transform import Rotation as R +import math +import sys +from collections import namedtuple + +import cv2 +import numpy as np + +sys.path.append("../../util") +from image_utils import normalize_image + +from detection_utils import IMAGE_SIZE as IMAGE_DET_SIZE + +ROI = namedtuple("ROI", ["x_center", "y_center", "width", "height", "rotation"]) +NUM_LANDMARKS = 478 +IMAGE_SIZE = 256 + + +def warp_perspective( + img, roi: ROI, + dst_width, dst_height, + keep_aspect_ratio=True): + im_h, im_w, _ = img.shape + + v_pad = h_pad = 0 + if keep_aspect_ratio: + dst_aspect_ratio = dst_height / dst_width + roi_aspect_ratio = roi.height / roi.width + + if dst_aspect_ratio > roi_aspect_ratio: + new_height = roi.width * dst_aspect_ratio + new_width = roi.width + v_pad = (1 - roi_aspect_ratio / dst_aspect_ratio) / 2 + else: + new_width = roi.height / dst_aspect_ratio + new_height = roi.height + h_pad = (1 - dst_aspect_ratio / roi_aspect_ratio) / 2 + + roi = ROI(roi.x_center, roi.y_center, new_width, new_height, roi.rotation) + + a 
= roi.width + b = roi.height + c = math.cos(roi.rotation) + d = math.sin(roi.rotation) + e = roi.x_center + f = roi.y_center + g = 1 / im_w + h = 1 / im_h + + project_mat = [ + [a * c * g, -b * d * g, 0.0, (-0.5 * a * c + 0.5 * b * d + e) * g], + [a * d * h, b * c * h, 0.0, (-0.5 * b * c - 0.5 * a * d + f) * h], + [0.0, 0.0, a * g, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + + rotated_rect = ( + (roi.x_center, roi.y_center), + (roi.width, roi.height), + roi.rotation * 180. / math.pi + ) + pts1 = cv2.boxPoints(rotated_rect) + + pts2 = np.float32([[0, dst_height], [0, 0], [dst_width, 0], [dst_width, dst_height]]) + M = cv2.getPerspectiveTransform(pts1, pts2) + img = cv2.warpPerspective( + img, M, (dst_width, dst_height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT) + + return img, project_mat, roi, (h_pad, v_pad) + + + +def preprocess_det(img): + im_h, im_w, _ = img.shape + + """ + resize & padding + """ + roi = ROI(0.5 * im_w, 0.5 * im_h, im_w, im_h, 0) + dst_width = dst_height = IMAGE_DET_SIZE + img, matrix, *_ = warp_perspective( + img, roi, + dst_width, dst_height) + + """ + normalize & reshape + """ + img = normalize_image(img, normalize_type='127.5') + img = np.expand_dims(img, axis=0) + img = img.astype(np.float32) + + return img, matrix + + +def post_processing(input_tensors, roi, pad): + num_landmarks = NUM_LANDMARKS + num_dimensions = 3 + + # TensorsToFaceLandmarksGraph + input_tensors = input_tensors.reshape(-1) + output_landmarks = np.zeros((num_landmarks, num_dimensions)) + for i in range(num_landmarks): + offset = i * num_dimensions + output_landmarks[i] = input_tensors[offset:offset + 3] + + norm_landmarks = output_landmarks / 256 + + # LandmarkLetterboxRemovalCalculator + h_pad, v_pad = pad + left = h_pad + top = v_pad + left_and_right = h_pad * 2 + top_and_bottom = v_pad * 2 + for landmark in norm_landmarks: + new_x = (landmark[0] - left) / (1 - left_and_right) + new_y = (landmark[1] - top) / (1 - top_and_bottom) + new_z = landmark[2] / (1 - 
def crop_face(img, lmk_extractor, expand=1.5):
    """Crop ``img`` to a square region around the detected face.

    If the landmark bounding box covers at least 15% of the image, the image
    is only center-cropped to a square along its longer side.  Otherwise a
    square of ``expand`` times the larger bounding-box side is cut out around
    the face, padding with black where the square leaves the image.

    Args:
        img: Image array of shape ``(H, W, C)`` (cv2 BGR).
        lmk_extractor: Callable returning ``None`` when it has no result, or
            a sequence whose element 1 is an array of landmarks with
            normalized x in column 0 and normalized y in column 1.
        expand: Scale factor applied to the face bounding box.

    Returns:
        The cropped image, ``img`` itself when it is already square and the
        face is large enough, or ``None`` when the extractor returns ``None``.
    """
    result = lmk_extractor(img)  # cv2 BGR

    if result is None:
        return None

    H, W, _ = img.shape
    # Work on a copy so the extractor's array is not scaled in place.
    lmks = np.array(result[1])
    lmks[:, 0] *= W
    lmks[:, 1] *= H

    x_min = np.min(lmks[:, 0])
    x_max = np.max(lmks[:, 0])
    y_min = np.min(lmks[:, 1])
    y_max = np.max(lmks[:, 1])

    width = x_max - x_min
    height = y_max - y_min

    if width * height >= W * H * 0.15:
        if W == H:
            return img
        size = min(H, W)
        offset = int((max(H, W) - size) / 2)
        # Slice by explicit length: ``offset:-offset`` is empty when offset
        # is 0 and one pixel too long when the size difference is odd.
        if size == H:
            return img[:, offset:offset + size]
        return img[offset:offset + size, :]

    center_x = x_min + width / 2
    center_y = y_min + height / 2

    width *= expand
    height *= expand

    size = max(width, height)

    x_min = int(center_x - size / 2)
    x_max = int(center_x + size / 2)
    y_min = int(center_y - size / 2)
    y_max = int(center_y + size / 2)

    # Pad the image wherever the square extends past its borders, then shift
    # the crop window by the padding added on the top/left.
    top = max(0, -y_min)
    bottom = max(0, y_max - img.shape[0])
    left = max(0, -x_min)
    right = max(0, x_max - img.shape[1])
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

    cropped_img = img[y_min + top:y_max + top, x_min + left:x_max + left]

    return cropped_img
"""
Feature extractor class for Wav2Vec2
"""

from typing import List, Optional, Union

import numpy as np

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)


class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Wav2Vec2 feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, defaults to 1):
            The feature dimension of the extracted features.
        sampling_rate (`int`, defaults to 16000):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        padding_value (`float`, defaults to 0.0):
            The value that is used to fill the padding values.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
            improve the performance for some models, *e.g.*,
            [wav2vec2-lv60](https://huggingface.co/models?search=lv60).
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Whether or not [`~Wav2Vec2FeatureExtractor.__call__`] should return `attention_mask`.

            Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
            `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask`
            should be passed.

            For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
            [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be
            passed for batched inference.
    """

    model_input_names = ["input_values", "attention_mask"]

    def __init__(
        self,
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        return_attention_mask=False,
        do_normalize=True,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.return_attention_mask = return_attention_mask
        self.do_normalize = do_normalize

    @staticmethod
    def zero_mean_unit_var_norm(
        input_values: List[np.ndarray], attention_mask: Optional[List[np.ndarray]], padding_value: float = 0.0
    ) -> List[np.ndarray]:
        """
        Normalize every array in the list to zero mean and unit variance.

        Args:
            input_values: Arrays to normalize, one per sequence.
            attention_mask: Optional per-sequence masks. When given, the sum of each mask along its last axis is
                used as the number of leading valid samples: statistics are computed on that prefix only and
                everything after it is overwritten with `padding_value`. When `None`, each array is normalized
                over its full length.
            padding_value: Value written into the padded tail of each normalized array.

        Returns:
            A list with one normalized array per input array.
        """
        if attention_mask is not None:
            attention_mask = np.array(attention_mask, np.int32)
            normed_input_values = []

            for vector, length in zip(input_values, attention_mask.sum(-1)):
                # 1e-7 keeps the division finite for constant (zero-variance) input.
                normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value

                normed_input_values.append(normed_slice)
        else:
            normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

        return normed_input_values

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
                stereo, i.e. single float per timestep.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`):
                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific feature_extractor's default.

                [What are attention masks?](../glossary#attention-mask)

                Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
                [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
                `attention_mask`. For such models, `input_values` should simply be padded with 0 and no
                `attention_mask` should be passed.

                For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
                [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should
                be passed for batched inference.

            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. Accepted for interface
                compatibility only: this method does not read it (see the note at the top of the body).

        Returns:
            The padded (and, if `do_normalize` is set, normalized) inputs as a `BatchFeature`, converted with
            `convert_to_tensors` when `return_tensors` is given.
        """
        # NOTE: the sampling-rate validation (ValueError on mismatch with `self.sampling_rate`, warning when
        # `sampling_rate` is omitted) and the mono-channel check are disabled in this copy, so neither
        # `sampling_rate` nor the channel count of `raw_speech` is verified here.

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )
        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]

        # convert into correct format for padding
        encoded_inputs = BatchFeature({"input_values": raw_speech})

        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        # convert input values to correct format (float32)
        input_values = padded_inputs["input_values"]
        if not isinstance(input_values[0], np.ndarray):
            padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
        elif (
            not isinstance(input_values, np.ndarray)
            and isinstance(input_values[0], np.ndarray)
            and input_values[0].dtype == np.dtype(np.float64)
        ):
            padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
        elif isinstance(input_values, np.ndarray) and input_values.dtype == np.dtype(np.float64):
            padded_inputs["input_values"] = input_values.astype(np.float32)

        # convert attention_mask to correct format (int32)
        attention_mask = padded_inputs.get("attention_mask")
        if attention_mask is not None:
            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

        # zero-mean and unit-variance normalization
        if self.do_normalize:
            # Without padding there is no padded tail to exclude, so the mask is ignored.
            attention_mask = (
                attention_mask
                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
                else None
            )
            padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
                padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
            )

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs


# --- diffusion/aniportrait/lmk_extractor.py (new file) ---
import math
import cv2
import numpy as np

import onnxruntime
import ailia
from ani_portrait_utils import load_canonical_model
from facemesh_v2_utils import preprocess_det, ROI, post_processing, NUM_LANDMARKS, preprocess
from detection_utils import face_detection


def normalize_points(points):
    """Center a point set on its centroid and scale it to unit RMS radius.

    Args:
        points: Array of points, one point per row.

    Returns:
        Tuple `(normalized_points, centroid, scale)` where `centroid` is the mean over
        rows and `scale` is the root-mean-square distance of the centered points from
        the origin. No guard is applied for `scale == 0` (all points identical).
    """
    # Compute the centroid.
    centroid = np.mean(points, axis=0)
    # Center the points around the origin.
    centered_points = points - centroid
    # Compute the scale (root-mean-square distance from the origin).
    scale = np.sqrt(np.sum(centered_points ** 2) / centered_points.shape[0])
    # Normalize.
    normalized_points = centered_points / scale
    return normalized_points, centroid, scale


def estimate_pose(src_points, dst_points) -> np.ndarray:
    """Estimate the similarity transform (scale, rotation, translation) mapping src onto dst.

    The rotation comes from an SVD of the cross-covariance of the normalized point
    sets, the scale from the ratio of their RMS radii.

    Args:
        src_points: Source points, one 3-D point per row.
        dst_points: Destination points, same row count and order as `src_points`.

    Returns:
        A 4x4 matrix `M` in column-vector convention: `M[:3, :3] = scale * R` and
        `M[:3, 3] = t`, so that `dst ~= scale * R @ src + t` for each point.
    """
    src_points = np.array(src_points)
    dst_points = np.array(dst_points)

    # Normalize both point sets.
    src_points, src_centroid, src_scale = normalize_points(src_points)
    dst_points, dst_centroid, dst_scale = normalize_points(dst_points)

    # Cross-covariance matrix: sum over points of dst_i * src_i^T.
    covariance_matrix = np.dot(dst_points.T, src_points)

    # Singular value decomposition.
    U, S, Vt = np.linalg.svd(covariance_matrix)

    # Rotation that maps src onto dst when applied to column vectors.
    rotation_matrix = np.dot(U, Vt)

    # Force a proper rotation (det(R) = +1) instead of a reflection.
    if np.linalg.det(rotation_matrix) < 0:
        Vt[-1, :] *= -1
        rotation_matrix = np.dot(U, Vt)

    # Scale factor.
    scale = dst_scale / src_scale

    # Translation. R acts on column vectors, so the source centroid must be mapped
    # with `R @ c`; `c @ R` would apply the inverse rotation and give a translation
    # inconsistent with the rotation stored in the matrix below.
    translation = dst_centroid - scale * np.dot(rotation_matrix, src_centroid)

    # Assemble the homogeneous transformation matrix.
    transformation_matrix = np.eye(4)
    transformation_matrix[:3, :3] = scale * rotation_matrix
    transformation_matrix[:3, 3] = translation
    return transformation_matrix


class LMKExtractor:
    """Face-landmark extractor that chains a face detector and a landmark model.

    Both models are either ONNX Runtime sessions or ailia networks; `is_onnx`
    selects which inference API is called.
    """

    def __init__(
        self,
        face_landmarks_detector: onnxruntime.InferenceSession | ailia.Net,
        face_detector: onnxruntime.InferenceSession | ailia.Net,
        is_onnx: bool,
    ):
        self.face_landmarks_detector = face_landmarks_detector
        self.face_detector = face_detector
        # Reference points used as the source set in `estimate_pose`
        # (presumably a 468x3 canonical face mesh; confirm in ani_portrait_utils).
        self.CANONICAL_MODEL = load_canonical_model()
        self.IS_ONNX = is_onnx

    def _predict(self, img):
        """Detect faces in a BGR image and return landmarks for each detected face.

        Returns:
            The per-face results of `post_processing` stacked along a new first axis,
            or an empty array of shape `(0, NUM_LANDMARKS, 3)` when no face is found.
        """
        im_h, im_w, _ = img.shape
        img = img[:, :, ::-1]  # BGR -> RGB

        det_input, matrix = preprocess_det(img)

        # feedforward
        if not self.IS_ONNX:
            output = self.face_detector.predict([det_input])
        else:
            output = self.face_detector.run(None, {'input': det_input})
        detections, scores = output

        boxes, scores = face_detection(detections, scores, matrix)
        if len(boxes) == 0:
            return np.zeros((0, NUM_LANDMARKS, 3))

        landmarks_list = []
        for box in boxes:
            # DetectionsToRectsCalculator
            rect_width = box[2] - box[0]
            rect_height = box[3] - box[1]
            center_x = (box[0] + box[2]) / 2
            center_y = (box[1] + box[3]) / 2

            # box[4:8] are treated as two keypoints in normalized image coordinates;
            # the ROI is rotated so that the line between them is horizontal.
            x0, y0 = box[4] * im_w, box[5] * im_h
            x1, y1 = box[6] * im_w, box[7] * im_h
            angle = 0 - math.atan2(-(y1 - y0), x1 - x0)
            # Wrap the angle into [-pi, pi).
            angle = angle - 2 * math.pi * math.floor((angle - (-math.pi)) / (2 * math.pi))

            # RectTransformationCalculator
            scale_x = scale_y = 1.5
            rect_width = rect_width * scale_x
            rect_height = rect_height * scale_y

            roi = ROI(
                center_x * im_w, center_y * im_h,
                rect_width * im_w, rect_height * im_h,
                angle)
            # Keep the full frame in `img`: every face must be cropped from the
            # original image, not from the crop produced for the previous face.
            face_input, roi, pad = preprocess(img, roi)

            # feedforward
            if not self.IS_ONNX:
                output = self.face_landmarks_detector.predict([face_input])
            else:
                output = self.face_landmarks_detector.run(None, {'input_12': face_input})
            landmark_tensors, presence_flag_tensors, _ = output

            norm_rect = ROI(
                roi.x_center / im_w, roi.y_center / im_h,
                roi.width / im_w, roi.height / im_h,
                angle)
            landmarks = post_processing(landmark_tensors, norm_rect, pad)
            landmarks_list.append(landmarks)

        landmarks = np.stack(landmarks_list, axis=0)

        return landmarks

    def __call__(self, img: np.ndarray):
        """Return the estimated pose matrix and the raw landmarks for one image.

        Only the first detected face is used for the pose: its first 468 landmarks
        are aligned against `CANONICAL_MODEL` with `estimate_pose`.

        Returns:
            Tuple `(trans_mat, detection_result)`: the 4x4 matrix from `estimate_pose`
            and the full landmark array for all detected faces.

        Raises:
            IndexError: If no face is detected (the landmark array is then empty and
                indexing its first face fails).
        """
        frame = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        detection_result = self._predict(frame)
        detection_result_crop = detection_result[0, :468, :].reshape(468, 3)
        trans_mat = estimate_pose(self.CANONICAL_MODEL, detection_result_crop)
        return trans_mat, detection_result


# (the diff also adds the binary files diffusion/aniportrait/lyl.wav and
#  diffusion/aniportrait/pose_ref_video.mp4)

# --- diffusion/aniportrait/prepare_data.py (new file) ---
import os
import math

import librosa
import numpy as np
from transformers import Wav2Vec2FeatureExtractor


class DataProcessor:
    """Loads an audio file and runs it through a pretrained Wav2Vec2 feature extractor."""

    def __init__(self, sampling_rate, wav2vec_model_path):
        # local_files_only=True: the extractor config is read from disk only.
        self._processor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
        self._sampling_rate = sampling_rate

    def extract_feature(self, audio_path):
        """Load `audio_path` at the configured sampling rate and return the extractor's
        `input_values` with singleton dimensions squeezed out."""
        speech_array, sampling_rate = librosa.load(audio_path, sr=self._sampling_rate)
        input_value = np.squeeze(self._processor(speech_array, sampling_rate=sampling_rate).input_values)
        return input_value


def prepare_audio_feature(wav_file, fps=30, sampling_rate=16000, wav2vec_model_path=None):
    """Extract the audio feature of `wav_file` and the matching number of video frames.

    Args:
        wav_file: Path of the audio file to load.
        fps: Video frame rate used to convert the audio duration into a frame count.
        sampling_rate: Sampling rate the audio is loaded at.
        wav2vec_model_path: Location passed to `Wav2Vec2FeatureExtractor.from_pretrained`.

    Returns:
        Dict with `"audio_feature"` (the extracted input values) and `"seq_len"`
        (`ceil(len(audio_feature) / sampling_rate * fps)`).
    """
    data_preprocessor = DataProcessor(sampling_rate, wav2vec_model_path)

    input_value = data_preprocessor.extract_feature(wav_file)
    seq_len = math.ceil(len(input_value) / sampling_rate * fps)
    return {
        "audio_feature": input_value,
        "seq_len": seq_len
    }
# --- diffusion/aniportrait/scheduling_ddim.py (new file) ---
from typing import List, Optional, Union

import numpy as np

from configuration_utils import ConfigMixin, register_to_config


class DDIMScheduler(ConfigMixin):
    """NumPy implementation of a DDIM scheduler (https://arxiv.org/abs/2010.02502).

    The constructor builds the beta/alpha tables; `set_timesteps` selects the
    inference timesteps; `step` performs one reverse-diffusion update.

    NOTE(review): the methods read constructor arguments through `self.config`;
    this relies on `register_to_config` / `ConfigMixin` (defined in
    `configuration_utils`, not visible here) exposing them that way.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
    ):
        """Build the noise schedule.

        Args:
            num_train_timesteps: Number of diffusion steps the model was trained with.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
            beta_schedule: `"linear"` or `"scaled_linear"`.
            trained_betas: Explicit betas; when given they take precedence over
                `beta_schedule`.
            clip_sample, thresholding, dynamic_thresholding_ratio, clip_sample_range,
            sample_max_value: Accepted but not used by `step` in this implementation
                (no clipping or thresholding is applied).
            steps_offset: Offset added to the timesteps for `"leading"` spacing.
            prediction_type: `"epsilon"`, `"sample"` or `"v_prediction"`.
            timestep_spacing: `"linspace"`, `"leading"` or `"trailing"`.

        Raises:
            NotImplementedError: If `beta_schedule` is not supported and no
                `trained_betas` are given.
        """
        # One exclusive chain: a separate `if` for "scaled_linear" would send the
        # default "linear" schedule (and any `trained_betas`) into the `else` branch.
        if trained_betas is not None:
            self.betas = np.array(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(
                beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # Cumulative alpha used for the "previous" step once it falls below timestep 0.
        self.final_alpha_cumprod = self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)

    def _get_variance(self, timestep, prev_timestep):
        """Return the DDIM variance term for the transition `timestep -> prev_timestep`."""
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.

        Raises:
            ValueError: If `num_inference_steps` exceeds the number of training timesteps, or
                if the configured `timestep_spacing` is not supported.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1].copy().astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = timesteps

    def step(
        self,
        model_output: np.ndarray,
        timestep: int,
        sample: np.ndarray,
        eta: float = 0.0,
    ):
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: Direct output of the diffusion model for `sample` at `timestep`.
            timestep: Current discrete timestep in the diffusion chain.
            sample: Current noisy sample.
            eta: Weight of the variance term in the update. No random noise is added by this
                implementation, so only `eta = 0.0` reproduces formula (12) exactly.

        Returns:
            The sample at the previous timestep.

        Raises:
            ValueError: If `set_timesteps` has not been called, or if the configured
                `prediction_type` is not supported.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0": not implemented here, the prediction is
        # used as is regardless of `clip_sample` / `thresholding`.

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        return prev_sample


# --- diffusion/aniportrait/wav2vec2feature_extractor.py (new file, no content in this diff) ---