diff --git a/README.md b/README.md
index 34634aeaf..8f3e40409 100644
--- a/README.md
+++ b/README.md
@@ -348,6 +348,7 @@ If you would like to try on your computer:
| [
](generative_adversarial_networks/encoder4editing/) | [encoder4editing](/generative_adversarial_networks/encoder4editing/) | [Designing an Encoder for StyleGAN Image Manipulation](https://github.com/omertov/encoder4editing) | Pytorch | 1.2.10 and later | |
| [
](generative_adversarial_networks/lipgan/) | [lipgan](/generative_adversarial_networks/lipgan/) | [LipGAN](https://github.com/Rudrabha/LipGAN) | Keras | 1.2.15 and later | [JP](https://medium.com/axinc/lipgan-%E3%83%AA%E3%83%83%E3%83%97%E3%82%B7%E3%83%B3%E3%82%AF%E5%8B%95%E7%94%BB%E3%82%92%E7%94%9F%E6%88%90%E3%81%99%E3%82%8B%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%83%A2%E3%83%87%E3%83%AB-57511508eaff) |
| [
](generative_adversarial_networks/live_portrait/) | [live_portrait](/generative_adversarial_networks/live_portrait)| [LivePortrait](https://github.com/KwaiVGI/LivePortrait) | Pytorch | 1.5.0 and later | [JP](https://medium.com/axinc/live-portrait-1%E6%9E%9A%E3%81%AE%E7%94%BB%E5%83%8F%E3%82%92%E5%8B%95%E3%81%8B%E3%81%9B%E3%82%8Bai%E3%83%A2%E3%83%87%E3%83%AB-8eaa7d3eb683)|
+| [
](generative_adversarial_networks/sadtalker/) | [sadtalker](/generative_adversarial_networks/sadtalker/) | [SadTalker](https://github.com/OpenTalker/SadTalker) | Pytorch | 1.5.0 and later | |
## Hand detection
diff --git a/generative_adversarial_networks/sadtalker/LICENSE b/generative_adversarial_networks/sadtalker/LICENSE
new file mode 100644
index 000000000..153196259
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/LICENSE
@@ -0,0 +1,209 @@
+Tencent is pleased to support the open source community by making SadTalker available.
+
+Copyright (C), a Tencent company. All rights reserved.
+
+SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.
+
+Terms of the Apache License Version 2.0:
+---------------------------------------------
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/generative_adversarial_networks/sadtalker/README.md b/generative_adversarial_networks/sadtalker/README.md
new file mode 100644
index 000000000..55da1805c
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/README.md
@@ -0,0 +1,74 @@
+# SadTalker
+
+## Input
+
+![input image](input.png)
+
+(Image from https://github.com/OpenTalker/SadTalker/blob/main/examples/source_image/art_1.png)
+
+[input.wav](input.wav)
+
+(Audio from https://github.com/OpenTalker/SadTalker/blob/main/examples/driven_audio/bus_chinese.wav)
+
+## Output
+
+The output is an mp4 video file.
+
+## Requirements
+
+This model requires `ffmpeg` and additional Python modules.
+
+```bash
+pip3 install -r requirements.txt
+```
+
+## Usage
+Automatically downloads the onnx and prototxt files on the first run.
+It is necessary to be connected to the Internet while downloading.
+
+For the sample image,
+```bash
+$ python3 sadtalker.py
+```
+
+If you want to specify the input image, put the image path after the `--input` option.
+If you want to specify the input audio, put the audio path after the `--audio` option.
+You can use `--savepath` option to change the name of the output file to save.
+```bash
+$ python3 sadtalker.py --input IMAGE_PATH --audio AUDIO_PATH --savepath SAVE_VIDEO_PATH
+```
+
+By adding the `--enhancer` option, you can enhance the generated face via gfpgan.
+```bash
+$ python3 sadtalker.py --enhancer
+```
+
+To run the full image animation, set the `--preprocess` option to `full`. For better results, also use `--still`.
+```bash
+$ python3 sadtalker.py -i input_full_body.png --enhancer --preprocess full --still
+```
+
+## Reference
+
+- [SadTalker](https://github.com/OpenTalker/SadTalker)
+- [retinaface](https://github.com/biubug6/Pytorch_Retinaface)
+- [GFPGAN](https://github.com/TencentARC/GFPGAN)
+
+## Framework
+
+Pytorch
+
+## Model Format
+
+ONNX opset=20
+
+## Netron
+
+[animation_generator.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/animation_generator.onnx.prototxt)
+[audio2exp.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/audio2exp.onnx.prototxt)
+[audio2pose.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/audio2pose.onnx.prototxt)
+[face_align.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/face_align.onnx.prototxt)
+[face3d_recon.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/face3d_recon.onnx.prototxt)
+[kp_detector.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/kp_detector.onnx.prototxt)
+[mappingnet_full.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/mappingnet_full.onnx.prototxt)
+[mappingnet_not_full.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/sadtalker/mappingnet_not_full.onnx.prototxt)
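+
+## Example: running an exported model directly
+
+The exported models can also be run directly with onnxruntime. The sketch below is only an illustration and is not part of `sadtalker.py`: the input name `input_image` is taken from `animation/make_animation.py`, while the 1x3x256x256 input shape is an assumption based on the 256x256 face renderer.
+
+```python
+import numpy as np
+import onnxruntime as ort
+
+# assumes kp_detector.onnx has already been downloaded into the current directory
+session = ort.InferenceSession("kp_detector.onnx")
+
+source_image = np.zeros((1, 3, 256, 256), dtype=np.float32)  # assumed NCHW input shape
+kp_canonical = session.run(None, {"input_image": source_image})[0]
+print(kp_canonical.shape)  # canonical 3D keypoints of the source face
+```
+
+The other exported models follow the same pattern; their input names can be read from `animation/make_animation.py`, `audio2coeff/audio2exp.py`, and `audio2coeff/audio2pose.py`.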
diff --git a/generative_adversarial_networks/sadtalker/animation/__init__.py b/generative_adversarial_networks/sadtalker/animation/__init__.py
new file mode 100644
index 000000000..9d12241e0
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/__init__.py
@@ -0,0 +1 @@
+from .animate import AnimateFromCoeff
diff --git a/generative_adversarial_networks/sadtalker/animation/animate.py b/generative_adversarial_networks/sadtalker/animation/animate.py
new file mode 100644
index 000000000..d4d311c74
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/animate.py
@@ -0,0 +1,113 @@
+import os
+import cv2
+import numpy as np
+import imageio
+from pydub import AudioSegment
+from skimage import img_as_ubyte
+
+from animation.make_animation import make_animation
+from animation.face_enhancer import enhancer_generator_with_len, enhancer_list
+from animation.paste_pic import paste_pic
+from animation.videoio import save_video_with_watermark
+
+class AnimateFromCoeff:
+ def __init__(self, generator_net, kp_detector_net, mapping_net, retinaface_net, gfpgan_net, use_onnx):
+ self.generator_net = generator_net
+ self.kp_detector_net = kp_detector_net
+ self.he_estimator_net = None
+ self.mapping_net = mapping_net
+ self.retinaface_net = retinaface_net
+ self.gfpgan_net = gfpgan_net
+ self.use_onnx = use_onnx
+
+ def generate(self, x, video_save_dir, pic_path, crop_info,
+ enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
+ source_image = x['source_image']
+ source_semantics = x['source_semantics']
+ target_semantics = x['target_semantics_list']
+ yaw_c_seq = x.get('yaw_c_seq', None)
+ pitch_c_seq = x.get('pitch_c_seq', None)
+ roll_c_seq = x.get('roll_c_seq', None)
+ frame_num = x['frame_num']
+
+ predictions_video = make_animation(
+ source_image, source_semantics, target_semantics,
+ self.generator_net, self.kp_detector_net, self.he_estimator_net, self.mapping_net,
+ yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp=True,
+ use_onnx=self.use_onnx
+ )
+
+ predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+ predictions_video = predictions_video[:frame_num]
+
+ video = []
+ for idx in range(predictions_video.shape[0]):
+ image = predictions_video[idx]
+ image = np.transpose(image.data, [1, 2, 0]).astype(np.float32)
+ video.append(image)
+ result = img_as_ubyte(video)
+
+ ### the generated video is 256x256, so we keep the aspect ratio,
+ original_size = crop_info[0]
+ if original_size:
+ result = [
+ cv2.resize(result_i, (img_size, int(img_size * original_size[1] / original_size[0])))
+ for result_i in result
+ ]
+
+ video_name = x['video_name'] + '.mp4'
+ path = os.path.join(video_save_dir, 'temp_'+video_name)
+ imageio.mimsave(path, result, fps=float(25), codec='libx264')
+
+ av_path = os.path.join(video_save_dir, video_name)
+ return_path = av_path
+
+ audio_path = x['audio_path']
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+ new_audio_path = os.path.join(video_save_dir, audio_name + '.wav')
+
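+        # Resample the driving audio to 16 kHz and trim it to the generated frame count (25 fps)
+        # before muxing it into the video.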
+ sound = AudioSegment.from_file(audio_path)
+ end_time = frame_num * 1000 / 25
+ word = sound.set_frame_rate(16000)[0:end_time]
+ word.export(new_audio_path, format="wav")
+
+ save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name}')
+
+ if 'full' in preprocess.lower():
+ # only add watermark to the full image.
+ video_name_full = x['video_name'] + '_full.mp4'
+ full_video_path = os.path.join(video_save_dir, video_name_full)
+ return_path = full_video_path
+ paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path,
+ extended_crop= True if 'ext' in preprocess.lower() else False)
+ print(f'The generated video is named {video_save_dir}/{video_name_full}')
+ else:
+ full_video_path = av_path
+
+        # paste back, then run the optional face enhancer
+ if enhancer:
+ video_name_enhancer = x['video_name'] + '_enhanced.mp4'
+ enhanced_path = os.path.join(video_save_dir, 'temp_' + video_name_enhancer)
+ av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
+ return_path = av_path_enhancer
+
+ try:
+ enhanced_images_gen_with_len = enhancer_generator_with_len(
+ full_video_path, method=enhancer, bg_upsampler=background_enhancer,
+ retinaface_net=self.retinaface_net, gfpgan_net=self.gfpgan_net
+ )
+ except:
+ enhanced_images_gen_with_len = enhancer_list(
+ full_video_path, method=enhancer, bg_upsampler=background_enhancer
+ )
+
+ imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
+ save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+ print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+ os.remove(enhanced_path)
+
+ os.remove(path)
+ os.remove(new_audio_path)
+
+ return return_path
diff --git a/generative_adversarial_networks/sadtalker/animation/face_enhancer.py b/generative_adversarial_networks/sadtalker/animation/face_enhancer.py
new file mode 100644
index 000000000..9d93d2568
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/face_enhancer.py
@@ -0,0 +1,95 @@
+import os
+import sys
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from animation.videoio import load_video_to_cv2
+
+sys.path.append('../../util')
+sys.path.append('../../face_restoration/gfpgan')
+from face_restoration import (
+ get_face_landmarks_5, align_warp_face,
+ get_inverse_affine, paste_faces_to_image
+)
+
+UPSCALE = 2
+
+class GeneratorWithLen(object):
+ """ From https://stackoverflow.com/a/7460929 """
+ def __init__(self, gen, length):
+ self.gen = gen
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen
+
+def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
+ gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
+ return list(gen)
+
+def enhancer_generator_with_len(
+ images, method='gfpgan', bg_upsampler='realesrgan',
+ retinaface_net=None, gfpgan_net=None
+ ):
+ """Provide a generator with a __len__ method."""
+
+ if os.path.isfile(images): # handle video to images
+ # TODO: Create a generator version of load_video_to_cv2
+ images = load_video_to_cv2(images)
+
+ gen = enhancer_generator_no_len(
+ images, method=method, bg_upsampler=bg_upsampler,
+ retinaface_net=retinaface_net, gfpgan_net=gfpgan_net
+ )
+ return GeneratorWithLen(gen, len(images))
+
+def enhancer_generator_no_len(
+ images, method='gfpgan', bg_upsampler='realesrgan',
+ retinaface_net=None, gfpgan_net=None
+ ):
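+    # For each frame: detect faces with RetinaFace, align and crop them, restore each crop with
+    # GFPGAN, resize the frame by UPSCALE, and paste the restored faces back in.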
+ for idx in tqdm(range(len(images)), desc='Face Enhancer'):
+ img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
+
+ det_faces, all_landmarks_5 = get_face_landmarks_5(retinaface_net, img, eye_dist_threshold=5)
+ cropped_faces, affine_matrices = align_warp_face(img, all_landmarks_5)
+
+ restored_faces = []
+ for cropped_face in cropped_faces:
+ x = preprocess(cropped_face)
+ output = gfpgan_net.predict([x])[0] # feedforward
+ restored_face = post_processing(output)
+ restored_faces.append(restored_face)
+
+ h, w = img.shape[:2]
+ h_up, w_up = int(h * UPSCALE), int(w * UPSCALE)
+ img = cv2.resize(img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ inverse_affine_matrices = get_inverse_affine(affine_matrices, upscale_factor=UPSCALE)
+ r_img = paste_faces_to_image(img, restored_faces, inverse_affine_matrices, upscale_factor=UPSCALE)
+
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
+ yield r_img
+
+def preprocess(img):
+ img = img[:, :, ::-1] # BGR -> RGB
+ img = img / 127.5 - 1.0
+ img = img.transpose(2, 0, 1) # HWC -> CHW
+ img = np.expand_dims(img, axis=0)
+ img = img.astype(np.float32)
+
+ return img
+
+def post_processing(pred):
+ img = pred[0]
+ img = img.transpose(1, 2, 0) # CHW -> HWC
+ img = img[:, :, ::-1] # RGB -> BGR
+
+ img = np.clip(img, -1, 1)
+ img = (img + 1) * 127.5
+ img = img.astype(np.uint8)
+
+ return img
diff --git a/generative_adversarial_networks/sadtalker/animation/make_animation.py b/generative_adversarial_networks/sadtalker/animation/make_animation.py
new file mode 100644
index 000000000..7db57062d
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/make_animation.py
@@ -0,0 +1,128 @@
+import numpy as np
+from scipy.special import softmax
+from tqdm import tqdm
+
+def headpose_pred_to_degree(pred):
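+    # The pose head outputs 66 classification bins; take the softmax expectation over the bin
+    # indices and map it to degrees in [-99, 99] (3 degrees per bin).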
+ pred = softmax(pred, axis=1)
+ idx_tensor = np.arange(66, dtype=np.float32)
+ degree = np.sum(pred * idx_tensor, axis=1) * 3 - 99
+ return degree
+
+def get_rotation_matrix(yaw, pitch, roll):
+ yaw = yaw / 180 * np.pi
+ pitch = pitch / 180 * np.pi
+ roll = roll / 180 * np.pi
+
+ pitch_mat = np.stack([
+ np.ones_like(pitch), np.zeros_like(pitch), np.zeros_like(pitch),
+ np.zeros_like(pitch), np.cos(pitch), -np.sin(pitch),
+ np.zeros_like(pitch), np.sin(pitch), np.cos(pitch)
+ ], axis=1).reshape(-1, 3, 3)
+
+ yaw_mat = np.stack([
+ np.cos(yaw), np.zeros_like(yaw), np.sin(yaw),
+ np.zeros_like(yaw), np.ones_like(yaw), np.zeros_like(yaw),
+ -np.sin(yaw), np.zeros_like(yaw), np.cos(yaw)
+ ], axis=1).reshape(-1, 3, 3)
+
+ roll_mat = np.stack([
+ np.cos(roll), -np.sin(roll), np.zeros_like(roll),
+ np.sin(roll), np.cos(roll), np.zeros_like(roll),
+ np.zeros_like(roll), np.zeros_like(roll), np.ones_like(roll)
+ ], axis=1).reshape(-1, 3, 3)
+
+ rot_mat = np.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
+ return rot_mat
+
+def keypoint_transformation(kp_canonical, he, wo_exp=False):
+ kp = kp_canonical # (bs, k, 3)
+ yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']
+ yaw = headpose_pred_to_degree(yaw)
+ pitch = headpose_pred_to_degree(pitch)
+ roll = headpose_pred_to_degree(roll)
+
+ if 'yaw_in' in he:
+ yaw = he['yaw_in']
+ if 'pitch_in' in he:
+ pitch = he['pitch_in']
+ if 'roll_in' in he:
+ roll = he['roll_in']
+
+ rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
+
+ t, exp = he['t'], he['exp']
+ if wo_exp:
+ exp = exp*0
+
+ # keypoint rotation
+ kp_rotated = np.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+ # keypoint translation
+ t[:, 0] = t[:, 0]*0
+ t[:, 2] = t[:, 2]*0
+ t = np.repeat(t[:, np.newaxis, :], kp.shape[1], axis=1)
+ kp_t = kp_rotated + t
+
+ # add expression deviation
+ exp = exp.reshape(exp.shape[0], -1, 3)
+ kp_transformed = kp_t + exp
+
+ return kp_transformed
+
+def make_animation(
+ source_image, source_semantics, target_semantics,
+ generator_net, kp_detector_net, he_estimator_net, mapping_net,
+ yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, use_exp=True, use_half=False,
+ use_onnx=False
+):
+ predictions = []
+
+ if use_onnx:
+ kp_canonical = kp_detector_net.run(None, {"input_image": source_image})[0]
+ he_source_tmp = mapping_net.run(None, {"input_3dmm": source_semantics})
+ else:
+ kp_canonical = kp_detector_net.run([source_image])[0]
+ he_source_tmp = mapping_net.run([source_semantics])
+ he_source = {
+ "yaw": he_source_tmp[0],
+ "pitch": he_source_tmp[1],
+ "roll": he_source_tmp[2],
+ "t": he_source_tmp[3],
+ "exp": he_source_tmp[4],
+ }
+
+ kp_source = keypoint_transformation(kp_canonical, he_source)
+
+ for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer'):
+ target_semantics_frame = target_semantics[:, frame_idx]
+ if use_onnx:
+ he_driving_tmp = mapping_net.run(None, {"input_3dmm": target_semantics_frame})
+ else:
+ he_driving_tmp = mapping_net.run([target_semantics_frame])
+ he_driving = {
+ "yaw": he_driving_tmp[0],
+ "pitch": he_driving_tmp[1],
+ "roll": he_driving_tmp[2],
+ "t": he_driving_tmp[3],
+ "exp": he_driving_tmp[4],
+ }
+
+ if yaw_c_seq is not None:
+ he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+ if pitch_c_seq is not None:
+ he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
+ if roll_c_seq is not None:
+ he_driving['roll_in'] = roll_c_seq[:, frame_idx]
+
+ kp_driving = keypoint_transformation(kp_canonical, he_driving)
+
+ if use_onnx:
+ out = generator_net.run(None, {
+ "source_image": source_image,
+ "kp_driving": kp_driving,
+ "kp_source": kp_source,
+ })[0]
+ else:
+ out = generator_net.run([source_image, kp_driving, kp_source])[0]
+ predictions.append(out)
+ return np.stack(predictions, axis=1)
diff --git a/generative_adversarial_networks/sadtalker/animation/paste_pic.py b/generative_adversarial_networks/sadtalker/animation/paste_pic.py
new file mode 100644
index 000000000..451103ded
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/paste_pic.py
@@ -0,0 +1,72 @@
+import cv2, os, sys
+import numpy as np
+from tqdm import tqdm
+import uuid
+
+from animation.videoio import save_video_with_watermark
+
+sys.path.append('../../util')
+from image_utils import imread # noqa: E402
+
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False):
+
+ if not os.path.isfile(pic_path):
+ raise ValueError('pic_path must be a valid path to video/image file')
+ elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_img = imread(pic_path)
+ else:
+        # loader for videos: only the first frame is needed as the full background image
+        video_stream = cv2.VideoCapture(pic_path)
+        while 1:
+            still_reading, frame = video_stream.read()
+            if not still_reading:
+                video_stream.release()
+                break
+            break
+        full_img = frame
+ frame_h = full_img.shape[0]
+ frame_w = full_img.shape[1]
+
+ video_stream = cv2.VideoCapture(video_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ crop_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ crop_frames.append(frame)
+
+ if len(crop_info) != 3:
+ print("you didn't crop the image")
+ return
+ else:
+ r_w, r_h = crop_info[0]
+ clx, cly, crx, cry = crop_info[1]
+ lx, ly, rx, ry = crop_info[2]
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+ # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ if extended_crop:
+ oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+ else:
+ oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+
+ tmp_path = str(uuid.uuid4())+'.mp4'
+ out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_w, frame_h))
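+    # Resize each generated face frame back to its crop region and blend it into the full
+    # image with seamlessClone, writing the composited frames to a temporary video.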
+ for crop_frame in tqdm(crop_frames, 'seamlessClone'):
+ p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1))
+
+ mask = 255*np.ones(p.shape, p.dtype)
+ location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
+ gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
+ out_tmp.write(gen_img)
+
+ out_tmp.release()
+
+ save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)
+ os.remove(tmp_path)
diff --git a/generative_adversarial_networks/sadtalker/animation/videoio.py b/generative_adversarial_networks/sadtalker/animation/videoio.py
new file mode 100644
index 000000000..1d583faa3
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/animation/videoio.py
@@ -0,0 +1,39 @@
+import shutil
+import uuid
+import os
+import cv2
+
+def load_video_to_cv2(input_path):
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ return full_frames
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+ temp_file = str(uuid.uuid4())+'.mp4'
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
+ os.system(cmd)
+
+ if watermark is False:
+ shutil.move(temp_file, save_path)
+ else:
+ # watermark
+ try:
+ ##### check if stable-diffusion-webui
+ import webui
+ from modules import paths
+ watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+ except:
+ # get the root path of sadtalker.
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+ cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+ os.system(cmd)
+ os.remove(temp_file)
\ No newline at end of file
diff --git a/generative_adversarial_networks/sadtalker/audio2coeff/__init__.py b/generative_adversarial_networks/sadtalker/audio2coeff/__init__.py
new file mode 100644
index 000000000..6199e8d50
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/audio2coeff/__init__.py
@@ -0,0 +1 @@
+from .audio2coeff import Audio2Coeff
diff --git a/generative_adversarial_networks/sadtalker/audio2coeff/audio2coeff.py b/generative_adversarial_networks/sadtalker/audio2coeff/audio2coeff.py
new file mode 100644
index 000000000..c49f9ebb7
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/audio2coeff/audio2coeff.py
@@ -0,0 +1,58 @@
+import os
+import numpy as np
+from scipy.io import savemat, loadmat
+from scipy.signal import savgol_filter
+
+from audio2coeff.audio2exp import Audio2Exp
+from audio2coeff.audio2pose import Audio2Pose
+
+class Audio2Coeff:
+ def __init__(self, audio2exp_net, audio2pose_net, use_onnx):
+ self.audio2exp_model = Audio2Exp(audio2exp_net, use_onnx)
+ self.audio2pose_model = Audio2Pose(audio2pose_net, use_onnx)
+
+ def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
+ results_dict_exp= self.audio2exp_model.test(batch)
+ exp_pred = results_dict_exp['exp_coeff_pred'] # bs T 64
+
+ batch['class'] = np.array([pose_style], dtype=np.int64)
+ results_dict_pose = self.audio2pose_model.test(batch)
+ pose_pred = results_dict_pose['pose_pred'] # bs T 6
+
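+        # Temporally smooth the predicted head pose with a Savitzky-Golay filter
+        # (window 13, order 2; the window is shrunk for very short clips).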
+ pose_len = pose_pred.shape[1]
+ if pose_len < 13:
+ pose_len = int((pose_len - 1) / 2) * 2 + 1
+ pose_pred = savgol_filter(pose_pred, pose_len, 2, axis=1)
+ else:
+ pose_pred = savgol_filter(pose_pred, 13, 2, axis=1)
+
+ coeffs_pred_numpy = np.concatenate((exp_pred, pose_pred), axis=-1)[0]
+
+ if ref_pose_coeff_path is not None:
+ coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
+
+ savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),
+ {'coeff_3dmm': coeffs_pred_numpy})
+
+ return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
+
+ def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
+ num_frames = coeffs_pred_numpy.shape[0]
+ refpose_coeff_dict = loadmat(ref_pose_coeff_path)
+ refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:, 64:70]
+
+ refpose_num_frames = refpose_coeff.shape[0]
+ if refpose_num_frames < num_frames:
+ div = num_frames // refpose_num_frames
+ re = num_frames % refpose_num_frames
+
+ refpose_coeff_list = [refpose_coeff for _ in range(div)]
+ refpose_coeff_list.append(refpose_coeff[:re, :])
+ refpose_coeff = np.concatenate(refpose_coeff_list, axis=0)
+
+ # Adjust relative head pose
+ coeffs_pred_numpy[:, 64:70] += refpose_coeff[:num_frames, :] - refpose_coeff[0:1, :]
+
+ return coeffs_pred_numpy
+
+
diff --git a/generative_adversarial_networks/sadtalker/audio2coeff/audio2exp.py b/generative_adversarial_networks/sadtalker/audio2coeff/audio2exp.py
new file mode 100644
index 000000000..e8a69b618
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/audio2coeff/audio2exp.py
@@ -0,0 +1,40 @@
+from tqdm import tqdm
+import numpy as np
+
+class Audio2Exp:
+ def __init__(self, audio2exp_net, use_onnx):
+ self.audio2exp_net = audio2exp_net
+ self.use_onnx = use_onnx
+
+ def test(self, batch):
+ mel_input = batch['indiv_mels'] # (bs, T, 1, 80, 16)
+ ref = batch['ref'] # (bs, T, 70)
+ ratio = batch['ratio_gt'] # (bs, T, 1)
+
+ bs, T, _, _, _ = mel_input.shape
+ exp_coeff_pred = []
+
+        for i in tqdm(range(0, T, 10), 'audio2exp'):  # process 10 frames at a time
+ current_mel_input = mel_input[:, i:i+10] # (bs, 10, 1, 80, 16)
+ current_ref = ref[:, i:i+10, :64] # (bs, 10, 64)
+ current_ratio = ratio[:, i:i+10] # (bs, 10, 1)
+
+ audiox = current_mel_input.reshape(-1, 1, 80, 16) # (bs*T, 1, 80, 16)
+
+ if self.use_onnx:
+ curr_exp_coeff_pred = self.audio2exp_net.run(None, {
+ "audio": audiox,
+ "ref": current_ref,
+ "ratio": current_ratio,
+ })[0]
+ else:
+ curr_exp_coeff_pred = self.audio2exp_net.run([audiox, current_ref, current_ratio])[0]
+
+ exp_coeff_pred += [curr_exp_coeff_pred]
+
+ # BS x T x 64
+ results_dict = {
+ 'exp_coeff_pred': np.concatenate(exp_coeff_pred, axis=1)
+ }
+
+ return results_dict
diff --git a/generative_adversarial_networks/sadtalker/audio2coeff/audio2pose.py b/generative_adversarial_networks/sadtalker/audio2coeff/audio2pose.py
new file mode 100644
index 000000000..d259a0691
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/audio2coeff/audio2pose.py
@@ -0,0 +1,63 @@
+import numpy as np
+
+class Audio2Pose:
+ def __init__(self, audio2pose_net, use_onnx):
+ self.audio2pose_net = audio2pose_net
+ self.seq_len = 32
+ self.latent_dim = 64
+ self.use_onnx = use_onnx
+
+ def test(self, x):
+ batch = {}
+
+ ref = x['ref'] # [BS, 1, 70]
+ class_id = x['class'] # [BS]
+ bs = ref.shape[0]
+ pose_ref = ref[:, 0, -6:] # [BS, 6]
+
+ indiv_mels = x['indiv_mels'] # [BS, T, 1, 80, 16]
+ T_total = int(x['num_frames']) - 1
+
+ chunk_count = (T_total + self.seq_len - 1) // self.seq_len
+
+ pose_motion_pred_list = [
+ np.zeros((bs, 1, pose_ref.shape[1]), dtype=np.float32)
+ ]
+
+ start_idx = 0
+ for _ in range(chunk_count):
+ end_idx = min(start_idx + self.seq_len, T_total)
+ chunk_len = end_idx - start_idx
+
+ chunk_mels = indiv_mels[:, 1 + start_idx : 1 + end_idx]
+
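+            # A short final chunk is front-padded by repeating its first mel frame so the network
+            # always receives seq_len frames; the padded predictions are trimmed again below.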
+ if chunk_len < self.seq_len:
+ pad_len = self.seq_len - chunk_len
+ pad_chunk = np.repeat(chunk_mels[:, :1], pad_len, axis=1)
+ chunk_mels = np.concatenate([pad_chunk, chunk_mels], axis=1)
+
+ z = np.random.randn(bs, self.latent_dim).astype(np.float32)
+
+ # Inference using a single model for AudioEncoder and netG.
+ if self.use_onnx:
+ motion_pred = self.audio2pose_net.run(None, {
+ "chunk_mels": chunk_mels,
+ "z": z,
+ "pose_ref": pose_ref,
+ "class": class_id,
+ })[0]
+ else:
+ motion_pred = self.audio2pose_net.run([chunk_mels, z, pose_ref, class_id])[0]
+
+ if chunk_len < self.seq_len:
+ motion_pred = motion_pred[:, -chunk_len:, :]
+
+ pose_motion_pred_list.append(motion_pred)
+ start_idx += chunk_len
+
+ pose_motion_pred = np.concatenate(pose_motion_pred_list, axis=1) # [BS, T_total, 6]
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # [BS, T_total+1, 6]
+
+ batch['pose_motion_pred'] = pose_motion_pred
+ batch['pose_pred'] = pose_pred
+ return batch
diff --git a/generative_adversarial_networks/sadtalker/batch_generation/__init__.py b/generative_adversarial_networks/sadtalker/batch_generation/__init__.py
new file mode 100644
index 000000000..16fb97ccf
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/batch_generation/__init__.py
@@ -0,0 +1,2 @@
+from .generate_batch import get_data
+from .generate_facerender_batch import get_facerender_data
diff --git a/generative_adversarial_networks/sadtalker/batch_generation/audio.py b/generative_adversarial_networks/sadtalker/batch_generation/audio.py
new file mode 100644
index 000000000..174857158
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/batch_generation/audio.py
@@ -0,0 +1,136 @@
+import librosa
+import librosa.filters
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+
+from batch_generation.hparams import hparams as hp
+
+def load_wav(path, sr):
+ return librosa.core.load(path, sr=sr)[0]
+
+def save_wav(wav, path, sr):
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+ #proposed by @dsmiller
+ wavfile.write(path, sr, wav.astype(np.int16))
+
+def save_wavenet_wav(wav, path, sr):
+ librosa.output.write_wav(path, wav, sr=sr)
+
+def preemphasis(wav, k, preemphasize=True):
+ if preemphasize:
+ return signal.lfilter([1, -k], [1], wav)
+ return wav
+
+def inv_preemphasis(wav, k, inv_preemphasize=True):
+ if inv_preemphasize:
+ return signal.lfilter([1], [1, -k], wav)
+ return wav
+
+def get_hop_size():
+ hop_size = hp.hop_size
+ if hop_size is None:
+ assert hp.frame_shift_ms is not None
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+ return hop_size
+
+def linearspectrogram(wav):
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+ if hp.signal_normalization:
+ return _normalize(S)
+ return S
+
+def melspectrogram(wav):
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+ if hp.signal_normalization:
+ return _normalize(S)
+ return S
+
+def _lws_processor():
+ import lws
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+def _stft(y):
+ if hp.use_lws:
+        return _lws_processor().stft(y).T
+ else:
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
+
+##########################################################
+#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+def num_frames(length, fsize, fshift):
+ """Compute number of time frames of spectrogram
+ """
+ pad = (fsize - fshift)
+ if length % fshift == 0:
+ M = (length + pad * 2 - fsize) // fshift + 1
+ else:
+ M = (length + pad * 2 - fsize) // fshift + 2
+ return M
+
+
+def pad_lr(x, fsize, fshift):
+ """Compute left and right padding
+ """
+ M = num_frames(len(x), fsize, fshift)
+ pad = (fsize - fshift)
+ T = len(x) + 2 * pad
+ r = (M - 1) * fshift + fsize - T
+ return pad, pad + r
+##########################################################
+#Librosa correct padding
+def librosa_pad_lr(x, fsize, fshift):
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+# Conversions
+_mel_basis = None
+
+def _linear_to_mel(spectogram):
+ global _mel_basis
+ if _mel_basis is None:
+ _mel_basis = _build_mel_basis()
+ return np.dot(_mel_basis, spectogram)
+
+def _build_mel_basis():
+ assert hp.fmax <= hp.sample_rate // 2
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
+ fmin=hp.fmin, fmax=hp.fmax)
+
+def _amp_to_db(x):
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
+ return 20 * np.log10(np.maximum(min_level, x))
+
+def _db_to_amp(x):
+ return np.power(10.0, (x) * 0.05)
+
+def _normalize(S):
+ if hp.allow_clipping_in_normalization:
+ if hp.symmetric_mels:
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
+ -hp.max_abs_value, hp.max_abs_value)
+ else:
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
+
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+ if hp.symmetric_mels:
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+ else:
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+def _denormalize(D):
+ if hp.allow_clipping_in_normalization:
+ if hp.symmetric_mels:
+ return (((np.clip(D, -hp.max_abs_value,
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ + hp.min_level_db)
+ else:
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+
+ if hp.symmetric_mels:
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+ else:
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
diff --git a/generative_adversarial_networks/sadtalker/batch_generation/generate_batch.py b/generative_adversarial_networks/sadtalker/batch_generation/generate_batch.py
new file mode 100644
index 000000000..967028ad9
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/batch_generation/generate_batch.py
@@ -0,0 +1,115 @@
+import os
+import numpy as np
+import random
+from tqdm import tqdm
+import scipy.io as scio
+
+import batch_generation.audio as audio
+
+def crop_pad_audio(wav, audio_length):
+ if len(wav) > audio_length:
+ wav = wav[:audio_length]
+ elif len(wav) < audio_length:
+ wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
+ return wav
+
+def parse_audio_length(audio_length, sr, fps):
+ bit_per_frames = sr / fps
+
+ num_frames = int(audio_length / bit_per_frames)
+ audio_length = int(num_frames * bit_per_frames)
+
+ return audio_length, num_frames
+
+def generate_blink_seq(num_frames):
+ ratio = np.zeros((num_frames,1))
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = 80
+ if frame_id+start+9<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
+ frame_id = frame_id+start+9
+ else:
+ break
+ return ratio
+
+def generate_blink_seq_randomly(num_frames):
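+    # Insert eye-blink ratio ramps (open -> closed -> open over 5 frames) at random
+    # intervals so the generated face blinks naturally.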
+ ratio = np.zeros((num_frames,1))
+ if num_frames<=20:
+ return ratio
+ frame_id = 0
+ while frame_id in range(num_frames):
+ start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
+ if frame_id+start+5<=num_frames - 1:
+ ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
+ frame_id = frame_id+start+5
+ else:
+ break
+ return ratio
+
+def get_data(first_coeff_path, audio_path, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
+
+ syncnet_mel_step_size = 16
+ fps = 25
+
+ pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
+ audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+
+
+ if idlemode:
+ num_frames = int(length_of_audio * 25)
+ indiv_mels = np.zeros((num_frames, 80, 16))
+ else:
+ wav = audio.load_wav(audio_path, 16000)
+ wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+ wav = crop_pad_audio(wav, wav_length)
+ orig_mel = audio.melspectrogram(wav).T
+ spec = orig_mel.copy() # nframes 80
+ indiv_mels = []
+
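+        # For every output frame, take a 16-step mel window starting two frames earlier,
+        # clamping indices to the valid range (window size follows syncnet_mel_step_size).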
+ for i in tqdm(range(num_frames), 'mel'):
+ start_frame_num = i-2
+ start_idx = int(80. * (start_frame_num / float(fps)))
+ end_idx = start_idx + syncnet_mel_step_size
+ seq = list(range(start_idx, end_idx))
+ seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+ m = spec[seq, :]
+ indiv_mels.append(m.T)
+ indiv_mels = np.asarray(indiv_mels) # T 80 16
+
+ ratio = generate_blink_seq_randomly(num_frames) # T
+ source_semantics_path = first_coeff_path
+ source_semantics_dict = scio.loadmat(source_semantics_path)
+ ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
+ ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
+
+ if ref_eyeblink_coeff_path is not None:
+ ratio[:num_frames] = 0
+ refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
+ refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
+ refeyeblink_num_frames = refeyeblink_coeff.shape[0]
+    if len(new_degree_list) > frame_num:
+ new_degree_list = new_degree_list[:frame_num]
+ elif len(new_degree_list) < frame_num:
+ for _ in range(frame_num-len(new_degree_list)):
+ new_degree_list.append(new_degree_list[-1])
+
+ remainder = frame_num%batch_size
+ if remainder!=0:
+ for _ in range(batch_size-remainder):
+ new_degree_list.append(new_degree_list[-1])
+ new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
+ return new_degree_np
+
diff --git a/generative_adversarial_networks/sadtalker/batch_generation/hparams.py b/generative_adversarial_networks/sadtalker/batch_generation/hparams.py
new file mode 100644
index 000000000..743c5c7d5
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/batch_generation/hparams.py
@@ -0,0 +1,160 @@
+from glob import glob
+import os
+
+class HParams:
+ def __init__(self, **kwargs):
+ self.data = {}
+
+ for key, value in kwargs.items():
+ self.data[key] = value
+
+ def __getattr__(self, key):
+ if key not in self.data:
+ raise AttributeError("'HParams' object has no attribute %s" % key)
+ return self.data[key]
+
+ def set_hparam(self, key, value):
+ self.data[key] = value
+
+
+# Default hyperparameters
+hparams = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=16,
+ initial_learning_rate=1e-4,
+ nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=20,
+ checkpoint_interval=3000,
+ eval_interval=3000,
+ writer_interval=300,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=1000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+
+# Default hyperparameters
+hparamsdebug = HParams(
+ num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
+ # network
+ rescale=True, # Whether to rescale audio prior to preprocessing
+ rescaling_max=0.9, # Rescaling value
+
+ # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+ # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+ # Does not work if n_ffit is not multiple of hop_size!!
+ use_lws=False,
+
+ n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
+ hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+ win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+ sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i )
+
+ frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
+
+ # Mel and Linear spectrograms normalization/scaling and clipping
+ signal_normalization=True,
+ # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+ allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
+ symmetric_mels=True,
+ # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+ # faster and cleaner convergence)
+ max_abs_value=4.,
+ # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+ # be too big to avoid gradient explosion,
+ # not too small for fast convergence)
+ # Contribution by @begeekmyfriend
+ # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+ # levels. Also allows for better G&L phase reconstruction)
+ preemphasize=True, # whether to apply filter
+ preemphasis=0.97, # filter coefficient.
+
+ # Limits
+ min_level_db=-100,
+ ref_level_db=20,
+ fmin=55,
+ # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
+ # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+ fmax=7600, # To be increased/reduced depending on data.
+
+ ###################### Our training parameters #################################
+ img_size=96,
+ fps=25,
+
+ batch_size=2,
+ initial_learning_rate=1e-3,
+ nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+ num_workers=0,
+ checkpoint_interval=10000,
+ eval_interval=10,
+ writer_interval=5,
+ save_optimizer_state=True,
+
+ syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+ syncnet_batch_size=64,
+ syncnet_lr=1e-4,
+ syncnet_eval_interval=10000,
+ syncnet_checkpoint_interval=10000,
+
+ disc_wt=0.07,
+ disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+    values = hparams.data
+ hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+ return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/generative_adversarial_networks/sadtalker/input.png b/generative_adversarial_networks/sadtalker/input.png
new file mode 100644
index 000000000..4388abe02
Binary files /dev/null and b/generative_adversarial_networks/sadtalker/input.png differ
diff --git a/generative_adversarial_networks/sadtalker/input.wav b/generative_adversarial_networks/sadtalker/input.wav
new file mode 100644
index 000000000..888647738
Binary files /dev/null and b/generative_adversarial_networks/sadtalker/input.wav differ
diff --git a/generative_adversarial_networks/sadtalker/input_full_body.png b/generative_adversarial_networks/sadtalker/input_full_body.png
new file mode 100644
index 000000000..4fca65c94
Binary files /dev/null and b/generative_adversarial_networks/sadtalker/input_full_body.png differ
diff --git a/generative_adversarial_networks/sadtalker/input_japanese.wav b/generative_adversarial_networks/sadtalker/input_japanese.wav
new file mode 100644
index 000000000..0b6146fd3
Binary files /dev/null and b/generative_adversarial_networks/sadtalker/input_japanese.wav differ
diff --git a/generative_adversarial_networks/sadtalker/preprocess/__init__.py b/generative_adversarial_networks/sadtalker/preprocess/__init__.py
new file mode 100644
index 000000000..616a585d3
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/__init__.py
@@ -0,0 +1 @@
+from .preprocess import CropAndExtract
\ No newline at end of file
diff --git a/generative_adversarial_networks/sadtalker/preprocess/croper.py b/generative_adversarial_networks/sadtalker/preprocess/croper.py
new file mode 100644
index 000000000..7943b57a7
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/croper.py
@@ -0,0 +1,109 @@
+import os
+import cv2
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from preprocess.extract_kp_videos_safe import KeypointExtractor
+from preprocess.face_detection import face_detect
+from preprocess.facexlib_alignment import landmark_98_to_68
+
+class Preprocesser:
+ def __init__(self, face_align_net, face_det_net, use_onnx):
+ self.predictor = KeypointExtractor(face_align_net, face_det_net, use_onnx)
+
+ def get_landmark(self, img_np):
+ dets = face_detect(img_np, self.predictor.face_det_net)
+ if len(dets) == 0:
+ return None
+ det = dets[0]
+
+ img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
+ lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img))
+ lm[:,0] += int(det[0])
+ lm[:,1] += int(det[1])
+ return lm
+
+ def align_face(self, img, lm, output_size=1024):
+ lm_chin = lm[0: 17] # left-right
+ lm_eyebrow_left = lm[17: 22] # left-right
+ lm_eyebrow_right = lm[22: 27] # left-right
+ lm_nose = lm[27: 31] # top-down
+ lm_nostrils = lm[31: 36] # top-down
+ lm_eye_left = lm[36: 42] # left-clockwise
+ lm_eye_right = lm[42: 48] # left-clockwise
+ lm_mouth_outer = lm[48: 60] # left-clockwise
+ lm_mouth_inner = lm[60: 68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+            img = img.resize(rsize, Image.LANCZOS)
+ quad /= shrink
+ qsize /= shrink
+ else:
+ rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+
+ # Transform.
+ quad = (quad + 0.5).flatten()
+ lx = max(min(quad[0], quad[2]), 0)
+ ly = max(min(quad[1], quad[7]), 0)
+ rx = min(max(quad[4], quad[6]), img.size[0])
+ ry = min(max(quad[3], quad[5]), img.size[0])
+
+ # Save aligned image.
+ return rsize, crop, [lx, ly, rx, ry]
+
+ def crop(self, img_np_list, still=False, xsize=512):
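+        # landmarks are detected on the first frame only; the resulting crop is
+        # then applied to every frame in the list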
+ img_np = img_np_list[0]
+ lm = self.get_landmark(img_np)
+ if lm is None:
+            raise ValueError('cannot detect a facial landmark in the source image')
+
+ rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+
+ clx, cly, crx, cry = crop
+ lx, ly, rx, ry = quad
+ lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+ for i, img in enumerate(img_np_list):
+ img = cv2.resize(img, (rsize[0], rsize[1]))[cly:cry, clx:crx]
+ if not still:
+ img = img[ly:ry, lx:rx]
+ img_np_list[i] = img
+ return img_np_list, crop, quad
+
diff --git a/generative_adversarial_networks/sadtalker/preprocess/extract_kp_videos_safe.py b/generative_adversarial_networks/sadtalker/preprocess/extract_kp_videos_safe.py
new file mode 100644
index 000000000..54cc4266a
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/extract_kp_videos_safe.py
@@ -0,0 +1,59 @@
+import os
+import cv2
+import time
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+from preprocess.face_detection import face_detect
+from preprocess.fan import FAN
+from preprocess.facexlib_alignment import landmark_98_to_68
+
+class KeypointExtractor():
+ def __init__(self, face_align_net, face_det_net, use_onnx):
+ self.detector = FAN(face_align_net, use_onnx)
+ self.face_det_net = face_det_net
+
+ def extract_keypoint(self, images, name=None):
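+        # list input: extract keypoints frame by frame, reusing the previous
+        # frame's keypoints when detection fails, and save them to <name>.txt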
+ if isinstance(images, list):
+ keypoints = []
+ for image in tqdm(images, desc='landmark Det'):
+ current_kp = self.extract_keypoint(image)
+ if np.mean(current_kp) == -1 and keypoints:
+ keypoints.append(keypoints[-1])
+ else:
+ keypoints.append(current_kp[None])
+
+ keypoints = np.concatenate(keypoints, 0)
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
+ else:
+ while True:
+ try:
+ # face detection -> face alignment.
+ img = np.array(images)
+ bboxes = face_detect(img, self.face_det_net)[0]
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img))
+
+ # keypoints to the original location
+ keypoints[:,0] += int(bboxes[0])
+ keypoints[:,1] += int(bboxes[1])
+
+ break
+ except RuntimeError as e:
+ if str(e).startswith('CUDA'):
+ print("Warning: out of memory, sleep for 1s")
+ time.sleep(1)
+ else:
+ print(e)
+ break
+ except TypeError:
+ print('No face detected in this image')
+ shape = [68, 2]
+ keypoints = -1. * np.ones(shape)
+ break
+ if name is not None:
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+ return keypoints
diff --git a/generative_adversarial_networks/sadtalker/preprocess/face3d_utils.py b/generative_adversarial_networks/sadtalker/preprocess/face3d_utils.py
new file mode 100644
index 000000000..2612d97ef
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/face3d_utils.py
@@ -0,0 +1,102 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+from skimage import transform as trans
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# least-squares (POS) fit of the scale and 2D translation that align the
+# standard 3D landmarks with the detected image landmarks
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2*npts, 8])
+
+ A[0:2*npts-1:2, 0:3] = x.transpose()
+ A[0:2*npts-1:2, 3] = 1
+
+ A[1:2*npts:2, 4:7] = x.transpose()
+ A[1:2*npts:2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2*npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+ w0, h0 = img.size
+ w = (w0*s).astype(np.int32)
+ h = (h0*s).astype(np.int32)
+ left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+ right = left + target_size
+ up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.BICUBIC)
+ img = img.crop((left, up, right, below))
+
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+ t[1] + h0/2], axis=1)*s
+ lm = lm - np.reshape(
+ np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+ return img, lm, mask
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+ lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+
+ w0, h0 = img.size
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor/s
+
+ # processing the image
+ img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ trans_params = np.array([w0, h0, s, float(t[0]), float(t[1])])
+
+ return trans_params, img_new, lm_new, mask_new
diff --git a/generative_adversarial_networks/sadtalker/preprocess/face_detection.py b/generative_adversarial_networks/sadtalker/preprocess/face_detection.py
new file mode 100644
index 000000000..c78129395
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/face_detection.py
@@ -0,0 +1,77 @@
+"""
+reference: ailia-models/face_detection/retinaface
+"""
+
+import sys
+import numpy as np
+
+sys.path.append('../../face_detection/retinaface')
+import retinaface_utils as rut
+from retinaface_utils import PriorBox
+
+CONFIDENCE_THRES = 0.02
+TOP_K = 5000
+NMS_THRES = 0.4
+KEEP_TOP_K = 750
+
+def face_detect(image, retinaface_net):
+ """
+ Args:
+ image (numpy.ndarray): Input image (H, W, C) in BGR format.
+ retinaface_net (ailia.Net): Ailia RetinaFace model.
+
+ Returns:
+ numpy.ndarray: Bounding boxes of detected faces (N, 4) in (x1, y1, x2, y2) format.
+ """
+ cfg = rut.cfg_re50
+
+ dim = (image.shape[1], image.shape[0])
+ image = image - (104, 117, 123)
+ image = np.expand_dims(image.transpose(2, 0, 1), axis=0)
+
+ preds = retinaface_net.predict([image])
+
+ detections = postprocessing(preds, image, cfg=cfg, dim=dim)
+ bboxes = detections[:, :4].astype(int)
+ return bboxes
+
+def postprocessing(preds_ailia, input_data, cfg, dim):
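+    # decode the RetinaFace outputs (box offsets, scores, landmark offsets)
+    # against the prior boxes, then filter by confidence and apply NMS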
+ IMAGE_WIDTH, IMAGE_HEIGHT = dim
+ scale = np.array([IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_HEIGHT])
+ loc, conf, landms = preds_ailia
+ priorbox = PriorBox(cfg, image_size=(IMAGE_HEIGHT, IMAGE_WIDTH))
+ priors = priorbox.forward()
+ boxes = rut.decode(np.squeeze(loc, axis=0), priors, cfg['variance'])
+ boxes = boxes * scale
+ scores = np.squeeze(conf, axis=0)[:, 1]
+ landms = rut.decode_landm(np.squeeze(landms, axis=0), priors, cfg['variance'])
+ scale1 = np.array([input_data.shape[3], input_data.shape[2], input_data.shape[3], input_data.shape[2],
+ input_data.shape[3], input_data.shape[2], input_data.shape[3], input_data.shape[2],
+ input_data.shape[3], input_data.shape[2]])
+ landms = landms * scale1
+
+ # ignore low scores
+ inds = np.where(scores > CONFIDENCE_THRES)[0]
+ boxes = boxes[inds]
+ landms = landms[inds]
+ scores = scores[inds]
+
+ # keep top-K before NMS
+ order = scores.argsort()[::-1][:TOP_K]
+ boxes = boxes[order]
+ landms = landms[order]
+ scores = scores[order]
+
+ # do NMS
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = rut.py_cpu_nms(dets, NMS_THRES)
+ dets = dets[keep, :]
+ landms = landms[keep]
+
+    # keep top-K after NMS
+ dets = dets[:KEEP_TOP_K, :]
+ landms = landms[:KEEP_TOP_K, :]
+
+ detections = np.concatenate((dets, landms), axis=1)
+
+ return detections
diff --git a/generative_adversarial_networks/sadtalker/preprocess/facexlib_alignment.py b/generative_adversarial_networks/sadtalker/preprocess/facexlib_alignment.py
new file mode 100644
index 000000000..75833fd15
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/facexlib_alignment.py
@@ -0,0 +1,48 @@
+"""
+reference: facexlib.alignment
+"""
+import numpy as np
+
+def landmark_98_to_68(landmark_98):
+ """Transfer 98 landmark positions to 68 landmark positions.
+ Args:
+        landmark_98(numpy array): Pixel coordinates of 98 landmarks, (98, 2)
+    Returns:
+        landmark_68(numpy array): Pixel coordinates of 68 landmarks, (68, 2)
+ """
+
+ landmark_68 = np.zeros((68, 2), dtype='float32')
+ # cheek
+ for i in range(0, 33):
+ if i % 2 == 0:
+ landmark_68[int(i / 2), :] = landmark_98[i, :]
+ # nose
+ for i in range(51, 60):
+ landmark_68[i - 24, :] = landmark_98[i, :]
+ # mouth
+ for i in range(76, 96):
+ landmark_68[i - 28, :] = landmark_98[i, :]
+ # left eyebrow
+ landmark_68[17, :] = landmark_98[33, :]
+ landmark_68[18, :] = (landmark_98[34, :] + landmark_98[41, :]) / 2
+ landmark_68[19, :] = (landmark_98[35, :] + landmark_98[40, :]) / 2
+ landmark_68[20, :] = (landmark_98[36, :] + landmark_98[39, :]) / 2
+ landmark_68[21, :] = (landmark_98[37, :] + landmark_98[38, :]) / 2
+ # right eyebrow
+ landmark_68[22, :] = (landmark_98[42, :] + landmark_98[50, :]) / 2
+ landmark_68[23, :] = (landmark_98[43, :] + landmark_98[49, :]) / 2
+ landmark_68[24, :] = (landmark_98[44, :] + landmark_98[48, :]) / 2
+ landmark_68[25, :] = (landmark_98[45, :] + landmark_98[47, :]) / 2
+ landmark_68[26, :] = landmark_98[46, :]
+ # left eye
+ LUT_landmark_68_left_eye = [36, 37, 38, 39, 40, 41]
+ LUT_landmark_98_left_eye = [60, 61, 63, 64, 65, 67]
+ for idx, landmark_98_index in enumerate(LUT_landmark_98_left_eye):
+ landmark_68[LUT_landmark_68_left_eye[idx], :] = landmark_98[landmark_98_index, :]
+ # right eye
+ LUT_landmark_68_right_eye = [42, 43, 44, 45, 46, 47]
+ LUT_landmark_98_right_eye = [68, 69, 71, 72, 73, 75]
+ for idx, landmark_98_index in enumerate(LUT_landmark_98_right_eye):
+ landmark_68[LUT_landmark_68_right_eye[idx], :] = landmark_98[landmark_98_index, :]
+
+ return landmark_68
diff --git a/generative_adversarial_networks/sadtalker/preprocess/fan.py b/generative_adversarial_networks/sadtalker/preprocess/fan.py
new file mode 100644
index 000000000..b72ea9e11
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/fan.py
@@ -0,0 +1,64 @@
+import cv2
+import numpy as np
+
+def calculate_points(heatmaps):
+    # convert heatmaps (B, N, H, W) to landmark coordinates: take the argmax of
+    # each flattened heatmap, then shift a quarter pixel towards the larger neighbour
+ B, N, H, W = heatmaps.shape
+ HW = H * W
+ BN_range = np.arange(B * N)
+
+ heatline = heatmaps.reshape(B, N, HW)
+ indexes = np.argmax(heatline, axis=2)
+
+ preds = np.stack((indexes % W, indexes // W), axis=2)
+ preds = preds.astype(np.float64, copy=False)
+
+ inr = indexes.ravel()
+
+ heatline = heatline.reshape(B * N, HW)
+    # neighbouring heatmap values used for the sub-pixel refinement; if a
+    # vertical neighbour index would leave the heatmap, fall back to the border bin
+    x_up = heatline[BN_range, inr + 1]
+    x_down = heatline[BN_range, inr - 1]
+
+    if any((inr + W) >= HW):
+        y_up = heatline[BN_range, HW - 1]
+    else:
+        y_up = heatline[BN_range, inr + W]
+    if any((inr - W) <= 0):
+        y_down = heatline[BN_range, 0]
+    else:
+        y_down = heatline[BN_range, inr - W]
+
+ think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))
+ think_diff *= .25
+
+ preds += think_diff.reshape(B, N, 2)
+ preds += .5
+ return preds
+
+class FAN():
+ def __init__(self, face_align_net, use_onnx):
+ self.face_align_net = face_align_net
+ self.use_onnx = use_onnx
+
+ def get_landmarks(self, img):
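+        # resize to the 256x256 network input, run the alignment network and map
+        # the heatmap peaks (computed on 64x64 heatmaps) back to the original scale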
+ H, W, _ = img.shape
+ offset = W / 64, H / 64, 0, 0
+
+ img = cv2.resize(img, (256, 256))
+ inp = img[..., ::-1]
+ inp = np.transpose(inp.astype(np.float32), (2, 0, 1))
+ inp = np.expand_dims(inp / 255.0, axis=0)
+
+ if self.use_onnx:
+ outputs = self.face_align_net.run(None, {"input_image": inp})[0]
+ else:
+ outputs = self.face_align_net.run([inp])[0]
+ out = outputs[:, :-1, :, :]
+
+ pred = calculate_points(out).reshape(-1, 2)
+
+ pred *= offset[:2]
+ pred += offset[-2:]
+
+ return pred
diff --git a/generative_adversarial_networks/sadtalker/preprocess/preprocess.py b/generative_adversarial_networks/sadtalker/preprocess/preprocess.py
new file mode 100644
index 000000000..1e6ae277e
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/preprocess/preprocess.py
@@ -0,0 +1,153 @@
+import os, sys
+import numpy as np
+import cv2
+from tqdm import tqdm
+from PIL import Image
+from scipy.io import loadmat, savemat
+
+from preprocess.face3d_utils import align_img
+from preprocess.croper import Preprocesser
+
+sys.path.append('../../util')
+from image_utils import imread # noqa: E402
+
+def split_coeff(coeffs):
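+    # Deep3DFaceRecon packs the 257-dim 3DMM vector as
+    # identity(80) | expression(64) | texture(80) | angles(3) | gamma(27) | translation(3)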
+ return {
+ 'id': coeffs[:, :80],
+ 'exp': coeffs[:, 80:144],
+ 'tex': coeffs[:, 144:224],
+ 'angle': coeffs[:, 224:227],
+ 'gamma': coeffs[:, 227:254],
+ 'trans': coeffs[:, 254:]
+ }
+
+def load_lm3d(lm3d_path):
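+    # load the 68-point standard 3D landmarks and keep only the 5 used for
+    # alignment (eye centres, nose tip, mouth corners); indices are 1-based, hence the "- 1"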
+ Lm3D = loadmat(lm3d_path)['lm']
+
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ selected_lms = [
+ Lm3D[lm_idx[0], :],
+ np.mean(Lm3D[lm_idx[[1, 2]], :], axis=0),
+ np.mean(Lm3D[lm_idx[[3, 4]], :], axis=0),
+ Lm3D[lm_idx[5], :],
+ Lm3D[lm_idx[6], :]
+ ]
+
+ Lm3D = np.stack(selected_lms, axis=0)
+ return Lm3D[[1, 2, 0, 3, 4], :]
+
+class CropAndExtract():
+ def __init__(self, face3d_recon_net, face_align_net, face_det_net, lm3d_path, use_onnx):
+ self.propress = Preprocesser(face_align_net, face_det_net, use_onnx)
+ self.face3d_recon_net = face3d_recon_net
+ self.lm3d_std = load_lm3d(lm3d_path)
+ self.use_onnx = use_onnx
+
+ def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256):
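+        # crop the input, detect the 68-point landmarks and regress the 3DMM
+        # coefficients; returns (coeff .mat path, cropped .png path, crop info)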
+ pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]
+ landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt')
+ coeff_path = os.path.join(save_dir, pic_name+'.mat')
+ png_path = os.path.join(save_dir, pic_name+'.png')
+
+ #load input
+ if not os.path.isfile(input_path):
+ raise ValueError('input_path must be a valid path to video/image file')
+ elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+ # loader for first frame
+ full_frames = [imread(input_path)]
+ fps = 25
+ else:
+ # loader for videos
+ video_stream = cv2.VideoCapture(input_path)
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
+ full_frames = []
+ while 1:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ full_frames.append(frame)
+ if source_image_flag:
+ break
+
+ x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames]
+
+        # crop the frames according to the selected preprocess mode
+        # ('crop' and 'full' share the same cropping; the 'ext*' variants skip the tight quad crop)
+        if 'crop' in crop_or_resize.lower() or 'full' in crop_or_resize.lower():
+            x_full_frames, crop, quad = self.propress.crop(x_full_frames, still='ext' in crop_or_resize.lower(), xsize=512)
+            clx, cly, crx, cry = crop
+            lx, ly, rx, ry = quad
+            lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+            crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+        else:  # resize mode
+            oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1]
+            crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+ frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+ if len(frames_pil) == 0:
+ print('No face is detected in the input file')
+            return None, None, None  # keep the return arity consistent with the success path
+
+        # save the cropped and resized source frame as a PNG
+ for frame in frames_pil:
+ cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+        # get the 68-point landmarks for the detected face (reuse saved landmarks when available)
+ if not os.path.isfile(landmarks_path):
+ lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path)
+ else:
+ print(' Using saved landmarks.')
+ lm = np.loadtxt(landmarks_path).astype(np.float32)
+ lm = lm.reshape([len(x_full_frames), -1, 2])
+
+ if not os.path.isfile(coeff_path):
+            # regress the 3DMM parameters with the Deep3DFaceRecon_pytorch network
+ video_coeffs, full_coeffs = [], []
+ for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video'):
+ frame = frames_pil[idx]
+ W,H = frame.size
+ lm1 = lm[idx].reshape([-1, 2])
+
+ if np.mean(lm1) == -1:
+ lm1 = (self.lm3d_std[:, :2]+1)/2.
+ lm1 = np.concatenate(
+ [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+ )
+ else:
+ lm1[:, -1] = H - 1 - lm1[:, -1]
+
+ trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+ trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+ im_np = np.transpose(np.array(im1, dtype=np.float32) / 255.0, (2, 0, 1))[np.newaxis, ...]
+
+ if self.use_onnx:
+ full_coeff = self.face3d_recon_net.run(None, {"input_image": im_np})[0]
+ else:
+ full_coeff = self.face3d_recon_net.run([im_np])[0]
+ coeffs = split_coeff(full_coeff)
+
+ pred_coeff = {key:coeffs[key] for key in coeffs}
+
+ pred_coeff = np.concatenate([
+ pred_coeff['exp'],
+ pred_coeff['angle'],
+ pred_coeff['trans'],
+ trans_params[2:][None],
+ ], 1)
+ video_coeffs.append(pred_coeff)
+ full_coeffs.append(full_coeff)
+
+ semantic_npy = np.array(video_coeffs)[:,0]
+
+ savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]})
+
+ return coeff_path, png_path, crop_info
diff --git a/generative_adversarial_networks/sadtalker/preprocess/similarity_Lm3D_all.mat b/generative_adversarial_networks/sadtalker/preprocess/similarity_Lm3D_all.mat
new file mode 100644
index 000000000..a0e235883
Binary files /dev/null and b/generative_adversarial_networks/sadtalker/preprocess/similarity_Lm3D_all.mat differ
diff --git a/generative_adversarial_networks/sadtalker/requirements.txt b/generative_adversarial_networks/sadtalker/requirements.txt
new file mode 100644
index 000000000..d17544097
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/requirements.txt
@@ -0,0 +1,5 @@
+pydub==0.25.1
+librosa==0.9.2
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+scikit-image==0.19.3
diff --git a/generative_adversarial_networks/sadtalker/sadtalker.py b/generative_adversarial_networks/sadtalker/sadtalker.py
new file mode 100644
index 000000000..8613ee646
--- /dev/null
+++ b/generative_adversarial_networks/sadtalker/sadtalker.py
@@ -0,0 +1,298 @@
+import os, sys, time
+from time import strftime
+import shutil
+import numpy as np
+import random
+from argparse import ArgumentParser
+
+import ailia
+
+# import original modules
+sys.path.append('../../util')
+import webcamera_utils # noqa: E402
+from image_utils import imread, load_image # noqa: E402
+from model_utils import check_and_download_models # noqa: E402
+from arg_utils import get_base_parser, get_savepath, update_parser # noqa: E402
+
+from preprocess import CropAndExtract
+from audio2coeff import Audio2Coeff
+from animation import AnimateFromCoeff
+from batch_generation import get_data, get_facerender_data
+
+# logger
+from logging import getLogger # noqa: E402
+logger = getLogger(__name__)
+
+# ======================
+# Parameters
+# ======================
+IMAGE_PATH = 'input.png'
+INPUT_AUDIO_PATH = "input.wav"
+SAVE_IMAGE_PATH = 'output.mp4'
+LM3D_PATH = "./preprocess/similarity_Lm3D_all.mat"
+PREPROCESS_LIST = ['crop', 'extcrop', 'resize', 'full', 'extfull']
+
+# ======================
+# Argument Parser Config
+# ======================
+parser = get_base_parser("sadtalker", IMAGE_PATH, SAVE_IMAGE_PATH)
+parser.add_argument("-a", "--audio", default=INPUT_AUDIO_PATH, help="Path to input audio")
+parser.add_argument("--result_dir", default='./results', help="path to output")
+parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+parser.add_argument("--expression_scale", type=float, default=1.0, help="the value of expression intensity")
+parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
+parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
+parser.add_argument('--enhancer', action="store_true", help="Face enhancer with gfpgan")
+parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
+parser.add_argument("--preprocess", default='crop', choices=PREPROCESS_LIST, help="how to preprocess the images")
+parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
+parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
+parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user")
+parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
+parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
+parser.add_argument("--verbose", action="store_true", help="saving the intermedia output or not")
+parser.add_argument("--seed", type=int, default=42, help="ramdom seed")
+parser.add_argument('-o', '--onnx', action='store_true', help="Option to use onnxrutime to run or not.")
+args = update_parser(parser)
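+# Example invocation (assuming the standard ailia-models base flags -i/--input and -s/--savepath):
+#   python3 sadtalker.py -i input.png -a input.wav -s output.mp4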
+
+# ======================
+# Parameters 2
+# ======================
+WEIGHT_FACE3D_RECON_PATH = "face3d_recon.onnx"
+MODEL_FACE3D_RECON_PATH = "face3d_recon.onnx.prototxt"
+WEIGHT_FACE_ALIGN_PATH = "face_align.onnx"
+MODEL_FACE_ALIGN_PATH = "face_align.onnx.prototxt"
+WEIGHT_AUDIO2EXP_PATH = "audio2exp.onnx"
+MODEL_AUDIO2EXP_PATH = "audio2exp.onnx.prototxt"
+WEIGHT_AUDIO2POSE_PATH = "audio2pose.onnx"
+MODEL_AUDIO2POSE_PATH = "audio2pose.onnx.prototxt"
+WEIGHT_ANIMATION_GENERATOR_PATH = "animation_generator.onnx"
+MODEL_ANIMATION_GENERATOR_PATH = "animation_generator.onnx.prototxt"
+WEIGHT_KP_DETECTOR_PATH = "kp_detector.onnx"
+MODEL_KP_DETECTOR_PATH = "kp_detector.onnx.prototxt"
+WEIGHT_MAPPING_NET = "mappingnet_full.onnx" if "full" in args.preprocess else "mappingnet_not_full.onnx"
+MODEL_MAPPING_NET = WEIGHT_MAPPING_NET + ".prototxt"
+REMOTE_PATH = "https://storage.googleapis.com/ailia-models/sadtalker/"
+
+WEIGHT_FACE_DET_PATH = "retinaface_resnet50.onnx"
+MODEL_FACE_DET_PATH = "retinaface_resnet50.onnx.prototxt"
+REMOTE_FACE_DET_PATH = "https://storage.googleapis.com/ailia-models/retinaface/"
+
+WEIGHT_GFPGAN_PATH = "GFPGANv1.4.onnx"
+MODEL_GFPGAN_PATH = "GFPGANv1.4.onnx.prototxt"
+REMOTE_GFPGAN_PATH = "https://storage.googleapis.com/ailia-models/gfpgan/"
+
+# ======================
+# Utils
+# ======================
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+
+def load_model(model_path, weight_path, env_id=args.env_id, use_onnx=args.onnx):
+ if use_onnx:
+ import onnxruntime
+ cuda = 0 < ailia.get_gpu_environment_id()
+ providers = (
+ ["CUDAExecutionProvider", "CPUExecutionProvider"]
+ if cuda
+ else ["CPUExecutionProvider"]
+ )
+ return onnxruntime.InferenceSession(weight_path, providers=providers)
+ else:
+ return ailia.Net(model_path, weight_path, env_id=env_id)
+
+def generate_ref_coeff(preprocess_model, video_path, save_dir):
+ if not video_path:
+ return None
+
+ videoname = os.path.splitext(os.path.split(video_path)[-1])[0]
+ frame_dir = os.path.join(save_dir, videoname)
+ os.makedirs(frame_dir, exist_ok=True)
+
+ print(f'3DMM Extraction for reference video: {videoname}')
+ coeff_path, _, _ = preprocess_model.generate(
+ video_path,
+ frame_dir,
+ args.preprocess,
+ source_image_flag=False
+ )
+ return coeff_path
+
+# ======================
+# Main functions
+# ======================
+def download_and_load_models():
+ check_and_download_models(WEIGHT_FACE3D_RECON_PATH, MODEL_FACE3D_RECON_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_FACE_ALIGN_PATH, MODEL_FACE_ALIGN_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_AUDIO2EXP_PATH, MODEL_AUDIO2EXP_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_AUDIO2POSE_PATH, MODEL_AUDIO2POSE_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_ANIMATION_GENERATOR_PATH, MODEL_ANIMATION_GENERATOR_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_KP_DETECTOR_PATH, MODEL_KP_DETECTOR_PATH, REMOTE_PATH)
+ check_and_download_models(WEIGHT_MAPPING_NET, MODEL_MAPPING_NET, REMOTE_PATH)
+ check_and_download_models(WEIGHT_FACE_DET_PATH, MODEL_FACE_DET_PATH, REMOTE_FACE_DET_PATH)
+ if args.enhancer:
+ check_and_download_models(WEIGHT_GFPGAN_PATH, MODEL_GFPGAN_PATH, REMOTE_GFPGAN_PATH)
+
+ models = {
+ "face3d_recon_net": load_model(MODEL_FACE3D_RECON_PATH, WEIGHT_FACE3D_RECON_PATH),
+ "face_align_net": load_model(MODEL_FACE_ALIGN_PATH, WEIGHT_FACE_ALIGN_PATH),
+ "audio2exp_net": load_model(MODEL_AUDIO2EXP_PATH, WEIGHT_AUDIO2EXP_PATH),
+ "audio2pose_net": load_model(MODEL_AUDIO2POSE_PATH, WEIGHT_AUDIO2POSE_PATH),
+ "generator_net": load_model(MODEL_ANIMATION_GENERATOR_PATH, WEIGHT_ANIMATION_GENERATOR_PATH),
+ "kp_detector_net": load_model(MODEL_KP_DETECTOR_PATH, WEIGHT_KP_DETECTOR_PATH),
+ "mapping_net": load_model(MODEL_MAPPING_NET, WEIGHT_MAPPING_NET),
+ "retinaface_net": ailia.Net(MODEL_FACE_DET_PATH, WEIGHT_FACE_DET_PATH, env_id=args.env_id),
+ "gfpgan_net": ailia.Net(MODEL_GFPGAN_PATH, WEIGHT_GFPGAN_PATH, env_id=args.env_id) if args.enhancer else None
+ }
+ return models
+
+def preprocess_image(preprocess_model, pic_path, save_dir):
+ first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+ os.makedirs(first_frame_dir, exist_ok=True)
+
+ first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
+ pic_path,
+ first_frame_dir,
+ args.preprocess,
+ source_image_flag=True,
+ pic_size=args.size
+ )
+ if first_coeff_path is None:
+ raise ValueError("Error: Can't get the coeffs of the input.")
+
+ return first_coeff_path, crop_pic_path, crop_info
+
+def extract_reference_coeffs(preprocess_model, ref_eyeblink, ref_pose, save_dir):
+ ref_eyeblink_coeff_path = generate_ref_coeff(preprocess_model, ref_eyeblink, save_dir)
+
+ if ref_pose == ref_eyeblink:
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
+ else:
+ ref_pose_coeff_path = generate_ref_coeff(preprocess_model, ref_pose, save_dir)
+
+ return ref_eyeblink_coeff_path, ref_pose_coeff_path
+
+def generate_audio_to_coeff(
+ audio_to_coeff,
+ first_coeff_path,
+ audio_path,
+ ref_eyeblink_coeff_path,
+ save_dir,
+ ref_pose_coeff_path
+):
+ batch = get_data(first_coeff_path, audio_path, ref_eyeblink_coeff_path, still=args.still)
+ coeff_path = audio_to_coeff.generate(batch, save_dir, args.pose_style, ref_pose_coeff_path)
+ return coeff_path
+
+def generate_animation(
+ animate_from_coeff,
+ coeff_path,
+ crop_pic_path,
+ first_coeff_path,
+ crop_info,
+ save_dir,
+ pic_path
+):
+ data = get_facerender_data(
+ coeff_path,
+ crop_pic_path,
+ first_coeff_path,
+ args.audio,
+ args.batch_size,
+ args.input_yaw,
+ args.input_pitch,
+ args.input_roll,
+ expression_scale=args.expression_scale,
+ still_mode=args.still,
+ preprocess=args.preprocess,
+ size=args.size,
+ )
+
+ result = animate_from_coeff.generate(
+ data,
+ save_dir,
+ pic_path,
+ crop_info,
+ enhancer=args.enhancer,
+ background_enhancer=None,
+ preprocess=args.preprocess,
+ img_size=args.size,
+ )
+
+ return result
+
+def main():
+ set_seed(args.seed)
+ save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
+ os.makedirs(save_dir, exist_ok=True)
+
+ models = download_and_load_models()
+
+ # init model
+ preprocess_model = CropAndExtract(
+ models["face3d_recon_net"],
+ models["face_align_net"],
+ models["retinaface_net"],
+ LM3D_PATH,
+ use_onnx=args.onnx
+ )
+ audio_to_coeff = Audio2Coeff(
+ models["audio2exp_net"],
+ models["audio2pose_net"],
+ use_onnx=args.onnx
+ )
+ animate_from_coeff = AnimateFromCoeff(
+ models["generator_net"],
+ models["kp_detector_net"],
+ models["mapping_net"],
+ models["retinaface_net"],
+ models["gfpgan_net"],
+ use_onnx=args.onnx
+ )
+
+ # crop image and extract 3dmm coefficients
+ print('3DMM Extraction for source image')
+ first_coeff_path, crop_pic_path, crop_info = preprocess_image(
+ preprocess_model,
+ args.input[0],
+ save_dir
+ )
+
+ # extract 3dmm coefficients of the reference video (for eye-blink and pose).
+ ref_eyeblink_coeff_path, ref_pose_coeff_path = extract_reference_coeffs(
+ preprocess_model,
+ args.ref_eyeblink,
+ args.ref_pose,
+ save_dir
+ )
+
+ # Generate coefficients for animation from audio data
+ coeff_path = generate_audio_to_coeff(
+ audio_to_coeff,
+ first_coeff_path,
+ args.audio,
+ ref_eyeblink_coeff_path,
+ save_dir,
+ ref_pose_coeff_path
+ )
+
+ # generate animation
+ result = generate_animation(
+ animate_from_coeff,
+ coeff_path,
+ crop_pic_path,
+ first_coeff_path,
+ crop_info,
+ save_dir,
+ args.input[0]
+ )
+
+ shutil.move(result, args.savepath)
+    print('The generated video was saved to:', args.savepath)
+
+ if not args.verbose:
+ shutil.rmtree(args.result_dir)
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/download_all_models.sh b/scripts/download_all_models.sh
index 8cb2c88b6..e63e1ded2 100755
--- a/scripts/download_all_models.sh
+++ b/scripts/download_all_models.sh
@@ -142,6 +142,7 @@ cd ../../generative_adversarial_networks/psgan; python3 psgan.py ${OPTION}
cd ../../generative_adversarial_networks/encoder4editing; python3 encoder4editing.py ${OPTION}
cd ../../generative_adversarial_networks/lipgan; python3 lipgan.py ${OPTION}
cd ../../generative_adversarial_networks/live_portrait; python3 live_portrait.py ${OPTION}
+cd ../../generative_adversarial_networks/sadtalker; python3 sadtalker.py ${OPTION}
cd ../../hand_detection/yolov3-hand; python3 yolov3-hand.py ${OPTION}
cd ../../hand_detection/hand_detection_pytorch; python3 hand_detection_pytorch.py ${OPTION}
cd ../../hand_detection/blazepalm; python3 blazepalm.py ${OPTION}