diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..18a1acc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+.idea
+.idea/*
+**/__pycache__
+.env
+gf/gf/weights/*
+gf/weights/*
+app/logs
+!app/logs/.gitignore
+app/files
+!app/files/.gitignore
diff --git a/app/aws_processor.py b/app/aws_processor.py
new file mode 100644
index 0000000..6acdfa9
--- /dev/null
+++ b/app/aws_processor.py
@@ -0,0 +1,110 @@
+import json
+import boto3
+from typing import Tuple, Any
+from app.settings import AWS_CONFIG
+import os
+from app.utilities import now, generate_video_path, generate_s3_media_arn
+from botocore.exceptions import ClientError
+import time
+# from balancer import Balancer
+# import shutil
+
+class AWSProcessor:
+    def __init__(self):
+        self.sqs_client = boto3.client(
+            'sqs', aws_access_key_id=AWS_CONFIG['key'], aws_secret_access_key=AWS_CONFIG['secret'], region_name=AWS_CONFIG['region'])
+        self.s3_client = boto3.client(
+            's3', aws_access_key_id=AWS_CONFIG['key'], aws_secret_access_key=AWS_CONFIG['secret'])
+
+        self.bucket = AWS_CONFIG['bucket']
+        self.sqs_url = AWS_CONFIG['sqs']
+        self.upload_bucket = AWS_CONFIG['final_upload_bucket']
+
+    def get_sqs_client(self):
+        return self.sqs_client
+
+    def get_s3_client(self):
+        return self.s3_client
+
+    def upload_logs(self, uid=None, instance_id=None):
+        log_general = 'app/logs/{}_{}.log'.format(now(True), instance_id)
+        job_log = None
+        if uid is not None:
+            job_log = "app/logs/{}.log".format(uid)
+        try:
+            self.s3_client.upload_file(log_general, self.bucket, "logs/bg-removal/{}".format(os.path.basename(log_general)))
+            if job_log is not None:
+                self.s3_client.upload_file(job_log, self.bucket, "logs/bg-removal/{}".format(os.path.basename(job_log)))
+        except ClientError as E:
+            raise Exception(E)
+
+    def upload_final_video(self, uid, final_video_local):
+        try:
+            _, extension = os.path.splitext(final_video_local)
+            self.s3_client.upload_file(final_video_local, self.upload_bucket, "avatars/users/{}{}".format(uid, extension), ExtraArgs={'ACL': 'public-read'})
+            return "https://{}.s3.amazonaws.com/avatars/users/{}{}".format(self.upload_bucket, uid, extension)
+        except ClientError as E:
+            raise Exception(E)
+
+    def delete_sqs_message(self, handler):
+        self.sqs_client.delete_message(
+            QueueUrl=self.sqs_url,
+            ReceiptHandle=handler
+        )
+
+    def get_sqs(self, process_name) -> Tuple[Any, Any]:
+        """
+        Reads a single message from the AWS SQS queue.
+        Returns:
+            Tuple[dict, str]: the decoded message body and its receipt handle,
+            or (False, False) when the queue is empty.
+        """
+        # balancer = Balancer('lipsync')
+        # if balancer.is_main():
+        #     balancer.create_blocker(process_name)
+        # elif balancer.main_running():
+        #     time.sleep(30)
+        #     return False, False
+        # elif not balancer.can_run():
+        #     time.sleep(30)
+        #     return False, False
+
+        response = self.sqs_client.receive_message(QueueUrl=self.sqs_url, MaxNumberOfMessages=1, WaitTimeSeconds=2)
+
+        for message in response.get('Messages', []):
+            message_body = message['Body']
+            sqs_message_handler = message['ReceiptHandle']
+            # while not balancer.can_run():
+            #     time.sleep(10)
+            #     continue
+
+            return json.loads(message_body), sqs_message_handler
+
+        # balancer.remove_process(process_name)
+        return False, False
+
+    def generate_full_url(self, arn):
+        return "https://{}.s3.amazonaws.com/tts/{}.wav".format(self.bucket, arn)
+
+    def download_video(self, video, uid):
+        bucket, arn = generate_s3_media_arn(video)
+        _, extension = os.path.splitext(video)
+        video_local = generate_video_path(uid, extension)
+
+        try:
+            self.s3_client.download_file(bucket, arn, video_local)
+        except Exception as E:
+            print(E)
+        return video_local, extension
+
+    def file_exists(self, arn):
+        try:
+            self.s3_client.head_object(Bucket='mltts', Key=arn)
+        except ClientError as e:
+            # a 404 means the object is missing; any other error
+            # (permissions, throttling) is also treated as "not found"
+            if e.response['Error']['Code'] == "404":
+                return False
+            return False
+        else:
+            return True
\ No newline at end of file
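A minimal sketch of how AWSProcessor is meant to be driven: poll once, process, then delete the message only after success. The payload keys ('uid', 'video') mirror the commented-out sample in remove_background.py; the blocker file name is illustrative, and the env vars read by app/settings.py are assumed to be set.

    from app.aws_processor import AWSProcessor

    aws = AWSProcessor()
    message, receipt_handle = aws.get_sqs('worker-demo.txt')  # (False, False) when the queue is empty
    if message:
        print(message['uid'], message['video'])
        # ... download, matting, and upload would happen here ...
        aws.delete_sqs_message(receipt_handle)  # ack only after the job succeeded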
diff --git a/app/balancer.py b/app/balancer.py
new file mode 100644
index 0000000..c72ce82
--- /dev/null
+++ b/app/balancer.py
@@ -0,0 +1,104 @@
+import os
+
+from app.settings import BALANCER, BALANCER_TIMES, BALANCER_SERVER_TYPES, BALANCER_RESOURCES, SERVER_RESOURCES_TOTAL
+from app.utilities import now
+from os.path import exists
+from datetime import datetime
+import nvidia_smi
+import psutil
+import time
+
+
+class Balancer:
+    def __init__(self, current_server):
+        self.current_server = current_server
+
+    def is_main(self):
+        return self.current_server == BALANCER['main_for']
+
+    def create_blocker(self, process_name):
+        with open('/home/ubuntu/processes/{}/{}'.format(self.current_server, process_name), 'w+') as f:
+            f.write(now())
+
+    def main_running(self):
+        main_server_processes = len(os.listdir('/home/ubuntu/processes/{}'.format(BALANCER['main_for'])))
+        return main_server_processes != 0
+
+    def can_run_main(self):
+        # count blocker files registered by every other server type
+        processes = 0
+        for proc in BALANCER_SERVER_TYPES:
+            if proc == self.current_server:
+                continue
+            processes += len(os.listdir('/home/ubuntu/processes/{}'.format(proc)))
+
+        print("processes ", processes)
+        return processes == 0
+
+    def can_run(self):
+        if self.is_main():
+            print('main check')
+            return self.can_run_main()
+        print('has time {}'.format(self.has_time()))
+        print('has resource {}'.format(self.has_resource()))
+        return self.has_time() and self.has_resource()
+
+    def remove_process(self, process_name):
+        process_filename = '/home/ubuntu/processes/{}/{}'.format(self.current_server, process_name)
+        if exists(process_filename):
+            os.remove(process_filename)
+
+    def has_time(self):
+        time_info = os.popen('who -b').read()
+        boot_minute = int(time_info.split(':')[-1])  # minute of the hour the instance booted
+        time_now = int(datetime.now().strftime("%M"))
+        remaining_time = boot_minute + 55 - time_now
+        return remaining_time >= BALANCER_TIMES[self.current_server]
+
+    def has_resource(self):
+        # sample free CPU/RAM/GPU for ~30s (60 samples, 0.5s apart) and average
+        cpu_use = 0
+        ram_use = 0
+        gpu_use = 0
+        nvidia_smi.nvmlInit()
+
+        for i in range(60):
+            cpu_use += psutil.cpu_percent()
+
+            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
+            info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+
+            gpu_use += info.free / (1024 ** 2)
+            ram_use += psutil.virtual_memory().free / (1024 ** 2)
+
+            time.sleep(0.5)
+
+        gpu = int(gpu_use / 60)
+        ram = int(ram_use / 60)
+        cpu = int(cpu_use / 60)
+
+        # the averaged free memory must cover this server type's requirements
+        if gpu < BALANCER_RESOURCES[self.current_server]['gpu'] or ram < BALANCER_RESOURCES[self.current_server]['ram']:
+            return False
+
+        gpu, ram, cpu = self.other_processes_status()
+        if gpu < BALANCER_RESOURCES[self.current_server]['gpu'] or ram < BALANCER_RESOURCES[self.current_server]['ram']:
+            return False
+
+        return True
+
+    def other_processes_status(self):
+        # estimate the headroom left after subtracting what running processes claim
+        gpu_free_proc = int(SERVER_RESOURCES_TOTAL['gpu'])
+        ram_free_proc = int(SERVER_RESOURCES_TOTAL['ram'])
+        cpu_free_proc = 100
+
+        for proc in BALANCER_SERVER_TYPES:
+            if proc == self.current_server:
+                continue
+            running = len(os.listdir('/home/ubuntu/processes/{}'.format(proc)))
+            gpu_free_proc -= running * 1.2 * int(BALANCER_RESOURCES[proc]['gpu'])
+            ram_free_proc -= running * 1.2 * int(BALANCER_RESOURCES[proc]['ram'])
+            # cpu_free_proc -= running * 1.2 * BALANCER_RESOURCES[proc]['cpu']
+
+        return gpu_free_proc, ram_free_proc, cpu_free_proc
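The commented-out calls in aws_processor.get_sqs() suggest the intended gating pattern around Balancer. A hedged sketch, assuming the /home/ubuntu/processes/<server>/ directories exist; 'lipsync' and the blocker file name are illustrative:

    import time
    from app.balancer import Balancer

    balancer = Balancer('lipsync')
    balancer.create_blocker('job-demo.txt')
    try:
        while not balancer.can_run():  # waits on the time budget plus GPU/RAM headroom
            time.sleep(10)
        # ... run the GPU-bound job here ...
    finally:
        balancer.remove_process('job-demo.txt')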
diff --git a/app/bg.py b/app/bg.py
new file mode 100644
index 0000000..4b4fa20
--- /dev/null
+++ b/app/bg.py
@@ -0,0 +1,27 @@
+import torch
+from model import MattingNetwork
+from inference import convert_video
+from app.utilities import generate_final_video
+
+
+def removal(uid, extension, video_local, seq_chunk):
+    model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
+    model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
+
+    final_video = generate_final_video(uid, extension)
+
+    convert_video(
+        model,                           # The model, can be on any device (cpu or cuda).
+        input_source=video_local,        # A video file or an image sequence directory.
+        output_type='video',             # Choose "video" or "png_sequence"
+        output_composition=final_video,  # File path if video; directory path if png sequence.
+        # output_alpha="pha.mp4",        # [Optional] Output the raw alpha prediction.
+        # output_foreground="fgr.mp4",   # [Optional] Output the raw foreground prediction.
+        output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
+        downsample_ratio=None,           # A hyperparameter to adjust or use None for auto.
+        seq_chunk=seq_chunk,             # Process n frames at once for better parallelism.
+    )
+
+    return final_video
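A minimal driver for removal(), assuming rvm_mobilenetv3.pth sits in the working directory and a CUDA device is available (the model is moved with .cuda()); the uid and input path are placeholders, and seq_chunk=12 matches the sample in test_bg.py:

    from app.bg import removal

    # expects the input already downloaded to app/files/<uid>/
    final_path = removal('demo-uid', extension='.mp4',
                         video_local='app/files/demo-uid/video.mp4', seq_chunk=12)
    print(final_path)  # -> app/files/demo-uid/final.mp4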
diff --git a/app/settings.py b/app/settings.py
new file mode 100644
index 0000000..7e4a6d8
--- /dev/null
+++ b/app/settings.py
@@ -0,0 +1,75 @@
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+WEBHOOK_CONFIG = {
+    'key': os.getenv('API_TOKEN'),
+    'url': os.getenv('WEBHOOK_URL')
+}
+
+RESPONSE_CODES = {}
+
+AWS_CONFIG = {
+    'key': os.getenv('AWS_KEY'),
+    'secret': os.getenv('AWS_SECRET'),
+    'bucket': os.getenv('AWS_BUCKET'),
+    'region': os.getenv('AWS_REGION'),
+    'sqs': os.getenv('AWS_SQS_URL'),
+    'sqs_handler': 'handler.json',
+    'final_upload_bucket': os.getenv('UPLOAD_BUCKET')
+}
+
+SERVER = {
+    'main_id': os.getenv('MAIN_ID')
+}
+
+SHUTDOWN_TIMINGS = {
+    'minutes': os.getenv('MINUTES'),
+    'intermediate': os.getenv('INTERMEDIATE')
+}
+
+BALANCER = {
+    'main_for': os.getenv('MAIN_FOR')
+}
+
+BALANCER_SERVER_TYPES = ['cloning', 'tts', 'lipsync', 'tortoise']
+
+BALANCER_TIMES = {
+    'cloning': 420,
+    'tts': 3,
+    'lipsync': 20,
+    'tortoise': 10
+}
+
+BALANCER_RESOURCES = {
+    'cloning': {
+        'cpu': 100,
+        'ram': 3550,
+        'gpu': 3658
+    },
+    'tts': {
+        'cpu': 17,
+        'ram': 3335,
+        'gpu': 2060
+    },
+    'lipsync': {
+        'cpu': 100,
+        'ram': 22000,
+        'gpu': 14000
+    },
+    'tortoise': {
+        'cpu': 77,
+        'ram': 9582,
+        'gpu': 15200
+    }
+}
+
+SERVER_RESOURCES_TOTAL = {
+    'gpu': os.getenv('GPU_TOTAL'),
+    'ram': os.getenv('RAM_TOTAL')
+}
+
+VIDEO_CONFIG = {
+    'seq_chunk': os.getenv('SEQ_CHUNK')
+}
\ No newline at end of file
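Everything above is env-driven via python-dotenv. A sketch of the .env keys this module reads, written as a fail-fast check (the key list is taken from the os.getenv calls above; values themselves are deployment-specific):

    import os
    from dotenv import load_dotenv

    load_dotenv()
    # keys consumed by app/settings.py
    required = ['API_TOKEN', 'WEBHOOK_URL', 'AWS_KEY', 'AWS_SECRET', 'AWS_BUCKET',
                'AWS_REGION', 'AWS_SQS_URL', 'UPLOAD_BUCKET', 'MAIN_ID', 'MINUTES',
                'INTERMEDIATE', 'MAIN_FOR', 'GPU_TOTAL', 'RAM_TOTAL', 'SEQ_CHUNK']
    missing = [k for k in required if os.getenv(k) is None]
    assert not missing, 'missing .env keys: {}'.format(missing)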
diff --git a/app/utilities.py b/app/utilities.py
new file mode 100644
index 0000000..36fc995
--- /dev/null
+++ b/app/utilities.py
@@ -0,0 +1,220 @@
+import json
+import logging
+import os
+from datetime import datetime
+import requests
+from dotenv import load_dotenv
+from app.settings import WEBHOOK_CONFIG, SHUTDOWN_TIMINGS, SERVER
+import glob
+import time
+import platform
+load_dotenv()
+
+
+def send_webhook(url=None, headers=None, data=None) -> bool:
+    if url is None:
+        url = WEBHOOK_CONFIG['url']
+
+    if headers is None:
+        headers = {
+            'Authorization': WEBHOOK_CONFIG['key']
+        }
+
+    try:
+        if data is None:
+            requests.get(url, headers=headers)
+        else:
+            requests.post(url, data=data, headers=headers, timeout=10)
+        return True
+    except Exception as E:
+        logger().error(E)
+        return False
+
+
+def logger(file=None):
+    file_name = file
+    if file is None:
+        # per-day, per-instance log for general (non-job) messages
+        instance_id = os.popen('wget -q -O - http://169.254.169.254/latest/meta-data/instance-id').read()
+        daily_log = "{}_{}.log".format(now(True), instance_id)
+        file = "app/logs/{}".format(daily_log)
+        level = 'DEBUG'
+    else:
+        file = "app/logs/{}.log".format(file)
+        level = 'INFO'
+
+    log_format = logging.Formatter("%(levelname)s %(asctime)s - %(message)s")
+
+    handler = logging.FileHandler(file)
+    handler.setFormatter(log_format)
+    logger = logging.getLogger(file_name)
+    logger.setLevel(level)
+
+    if logger.hasHandlers():
+        logger.handlers.clear()
+
+    logger.addHandler(handler)
+
+    return logger
+
+
+def clear_old_data(file_or_dir, is_folder=False):
+    if is_folder:
+        files = glob.glob(file_or_dir)
+        for f in files:
+            os.remove(f)
+    else:
+        if os.path.exists(file_or_dir):
+            os.remove(file_or_dir)
+
+
+def now(is_only_date=False):
+    return datetime.now().strftime("%Y-%m-%d") if is_only_date else datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+
+def generate_s3_media_arn(media):
+    # "https://<bucket>.s3.amazonaws.com/<key>" -> ("<bucket>", "<key>")
+    bucket_full, arn = media.split('/', 2)[-1].split('/', 1)
+    return bucket_full.split('.')[0], arn
+
+
+def generate_s3_video_arn(avatar_id):
+    return 'avatars/{}.mp4'.format(avatar_id)
+
+
+def generate_video_path(uid, extension):
+    return "app/files/{}/video{}".format(uid, extension)
+
+
+def generate_final_video(uid, extension):
+    return 'app/files/{}/final{}'.format(uid, extension)
+
+
+def shutdown_single_process():
+    instance_id = os.popen('wget -q -O - http://169.254.169.254/latest/meta-data/instance-id').read()
+
+    if instance_id == SERVER['main_id']:
+        return False
+
+    if not check_runtime():
+        return False
+
+    if not other_servers_are_running():
+        return False
+
+    with open('/home/ubuntu/terminate.txt', 'w') as f:
+        f.write(now())
+
+    terminate_command = "aws autoscaling terminate-instance-in-auto-scaling-group --instance-id {} --should-decrement-desired-capacity 2>&1".format(
+        instance_id)
+    os.system(terminate_command)
+
+
+def shutdown():
+    instance_id = os.popen('wget -q -O - http://169.254.169.254/latest/meta-data/instance-id').read()
+    if instance_id == SERVER['main_id']:
+        return False
+
+    if not check_runtime():
+        return False
+
+    if not check_all_processes_statuses():
+        return False
+
+    with open('/home/ubuntu/terminate.txt', 'w') as f:
+        f.write(now())
+
+    terminate_command = "aws autoscaling terminate-instance-in-auto-scaling-group --instance-id {} --should-decrement-desired-capacity 2>&1".format(
+        instance_id)
+    os.system(terminate_command)
+
+
+def other_servers_are_running():
+    command = 'aws autoscaling describe-auto-scaling-groups --auto-scaling-group-name ASG-ML-BACKGROUND-REMOVAL'
+    output = os.popen(command).read()
+    asg_info = json.loads(output)
+
+    min_size = asg_info['AutoScalingGroups'][0]['MinSize']
+    if min_size == 0:
+        return True
+
+    instances_info = asg_info['AutoScalingGroups'][0]['Instances']
+    if len(instances_info) <= min_size:
+        return False
+
+    return True
+
+
+def check_runtime():
+    # terminate only inside a three-minute window near the end of the billing hour
+    time_info = os.popen('who -b').read()
+    boot_minute = int(time_info.split(':')[-1])
+    terminate_time = boot_minute + int(SHUTDOWN_TIMINGS['minutes']) - int(SHUTDOWN_TIMINGS['intermediate'])
+    terminate_time = terminate_time if terminate_time < 60 else terminate_time - 60
+
+    time_now = int(datetime.now().strftime("%M"))
+    if terminate_time - 1 <= time_now <= terminate_time + 1:
+        return True
+
+    return False
+
+
+def check_all_processes_statuses():
+    from app.settings import BALANCER_SERVER_TYPES
+    time_now = time.time()
+    for proc in BALANCER_SERVER_TYPES:
+        processes = os.listdir('/home/ubuntu/processes/{}'.format(proc))
+        for process in processes:
+            created = creation_date('/home/ubuntu/processes/{}/{}'.format(proc, process))
+            minutes_pass = (time_now - created) / 60
+            if minutes_pass > 50:
+                # stale blocker: the job most likely died, clean it up
+                os.remove('/home/ubuntu/processes/{}/{}'.format(proc, process))
+            else:
+                return False
+
+    return True
+
+
+def creation_date(file):
+    if platform.system() == 'Windows':
+        return os.path.getctime(file)
+    else:
+        stat = os.stat(file)
+        try:
+            return stat.st_ctime
+        except AttributeError:
+            return stat.st_mtime
+
+
+def check_processes():
+    time_now = time.time()
+    processes = os.listdir('process')
+    for process in processes:
+        created = creation_date('process/{}'.format(process))
+        minutes_pass = (time_now - created) / 60
+        if minutes_pass > 15:
+            os.remove('process/{}'.format(process))
+        else:
+            return False
+
+    return True
+
+
+def video_dimension_unifier(video, uid, extension):
+    import cv2
+    video_dimensions = os.popen('ffprobe -v error -select_streams v -show_entries stream=width,height -of json {}'.format(video)).read()
+    stream = json.loads(video_dimensions)['streams'][0]
+    video_width, video_height = stream['width'], stream['height']
+
+    first_frame_path = "app/files/{}/first_frame.jpg".format(uid)
+    os.system('ffmpeg -i {} -y -vf "select=eq(n\,0)" -q:v 3 {}'.format(video, first_frame_path))
+    first_frame = cv2.imread(first_frame_path)
+    frame_h, frame_w, _ = first_frame.shape
+
+    video_dimension_factor = frame_h * frame_w / (1920 * 1080)
+    video_dimension_factor = 1 if video_dimension_factor < 1 else video_dimension_factor
+
+    if video_width != frame_w or video_height != frame_h:
+        # re-encode so the container dimensions match the decoded frames
+        video_new_dimensions = "app/files/{}/video_n_dim{}".format(uid, extension)
+        os.system('ffmpeg -i {} -y -vf scale={}:{} {}'.format(video, frame_w, frame_h, video_new_dimensions))
+        return video_new_dimensions, video_dimension_factor
+
+    return video, video_dimension_factor
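A worked example of generate_s3_media_arn(): it strips the scheme, takes the virtual-hosted bucket name from the host, and keeps the rest as the object key. The URL is the same sample used in remove_background.py:

    from app.utilities import generate_s3_media_arn

    bucket, arn = generate_s3_media_arn('https://mltts.s3.amazonaws.com/tmp/test.mp4')
    assert (bucket, arn) == ('mltts', 'tmp/test.mp4')

The check_runtime() arithmetic works the same way by example: assuming MINUTES=55 and INTERMEDIATE=5 (both env-driven, so these values are illustrative), an instance whose 'who -b' minute is 22 gets terminate_time = 22 + 55 - 5 = 72 -> 12, so it may only terminate when the wall-clock minute is 11, 12, or 13.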
diff --git a/remove_background.py b/remove_background.py
new file mode 100644
index 0000000..9aa50de
--- /dev/null
+++ b/remove_background.py
@@ -0,0 +1,108 @@
+from app.aws_processor import AWSProcessor
+import os
+import platform
+from app.utilities import now, creation_date, logger, send_webhook, shutdown_single_process, video_dimension_unifier
+import logging
+from os.path import exists
+import string
+import time
+import random
+from app.bg import removal
+import shutil
+import traceback
+from app.settings import VIDEO_CONFIG
+
+
+if __name__ == '__main__':
+    if platform.system() != 'Windows':
+        os.chdir('/home/ubuntu/videomatting')
+
+    instance_id = os.popen('wget -q -O - http://169.254.169.254/latest/meta-data/instance-id').read()
+    if not exists('app/logs/{}_{}.log'.format(now(True), instance_id)):
+        with open('app/logs/{}_{}.log'.format(now(True), instance_id), 'w') as f:
+            f.write(now())
+
+    # general config
+    letters = string.ascii_lowercase
+    aws_client = AWSProcessor()
+
+    # check server status for ASGs
+    if exists('/home/ubuntu/terminate.txt'):
+        time_now = time.time()
+        created_termination = creation_date('/home/ubuntu/terminate.txt')
+        minutes_pass = (time_now - created_termination) / 60
+        if minutes_pass > 15:
+            os.remove('/home/ubuntu/terminate.txt')
+        else:
+            time.sleep(20)
+            exit(0)
+
+    process_name = ''.join(random.choice(letters) for _ in range(10))
+    process_name = "{}.txt".format(process_name)
+
+    # read sqs
+    try:
+        sqs, handler = aws_client.get_sqs(process_name)
+        # sqs = {
+        #     'uid': 'sdasda',
+        #     'video': 'https://mltts.s3.amazonaws.com/tmp/test.mp4'
+        # }
+
+        if not sqs:
+            shutdown_single_process()
+            # todo: check if there is a need of a balancer
+            time.sleep(10)
+            exit(0)
+
+        uid = sqs['uid']
+        logger().info('Bg removal process started for {}'.format(uid))
+        # todo: check if there is a need of a balancer create process file
+
+        logger(uid).info('{} task running on {} server'.format(uid, instance_id))
+
+        # create unique folder
+        os.makedirs('app/files/{}'.format(uid), exist_ok=True)
+        logger(uid).info('created folder: files/{}'.format(uid))
+
+        # download video
+        video, extension = aws_client.download_video(sqs['video'], uid)
+        logger(uid).info('downloaded video into: files/{}/video{}'.format(uid, extension))
+
+        # unify container and frame dimensions
+        video, video_dimension_factor = video_dimension_unifier(video, uid, extension)
+
+        # scale the frame chunk size down for larger-than-1080p inputs
+        seq_chunk = int(int(VIDEO_CONFIG['seq_chunk']) / video_dimension_factor)
+
+        # bg removal
+        local_file = removal(uid, extension=extension, video_local=video, seq_chunk=seq_chunk)
+
+        # upload final video
+        logger(uid).info('uploading final video')
+        final_video_url = aws_client.upload_final_video(uid, local_file)
+
+        # webhook
+        send_webhook(data={'video': final_video_url, 'uid': uid})
+
+        # delete sqs message
+        aws_client.delete_sqs_message(handler)
+
+        # upload logs
+        aws_client.upload_logs(uid=uid, instance_id=instance_id)
+
+        # close and remove the per-job log
+        log_handlers = logging.getLogger(uid).handlers[:]
+        for log_handler in log_handlers:
+            log_handler.close()
+        os.remove("app/logs/{}.log".format(uid))
+
+        # remove files
+        time.sleep(10)
+        shutil.rmtree('app/files/{}'.format(uid))
+
+    except Exception as E:
+        print(E)
+        print(traceback.format_exc())
\ No newline at end of file
diff --git a/requirements_inference.txt b/requirements_inference.txt
index 4b24a22..bdc0977 100644
--- a/requirements_inference.txt
+++ b/requirements_inference.txt
@@ -1,5 +1,8 @@
 av==8.0.3
-torch==1.9.0
-torchvision==0.10.0
+#torch==1.9.0
+#torchvision==0.10.0
 tqdm==4.61.1
-pims==0.5
\ No newline at end of file
+pims==0.5
+boto3
+python-dotenv
+opencv-python
\ No newline at end of file
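The queue message this worker consumes has the shape shown in the commented-out sample above. A sketch of publishing one for testing; send_message is standard boto3, the queue URL comes from AWS_CONFIG['sqs'], and the uid/video values are placeholders:

    import json
    from app.aws_processor import AWSProcessor
    from app.settings import AWS_CONFIG

    body = json.dumps({'uid': 'demo-uid', 'video': 'https://mltts.s3.amazonaws.com/tmp/test.mp4'})
    AWSProcessor().get_sqs_client().send_message(QueueUrl=AWS_CONFIG['sqs'], MessageBody=body)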
diff --git a/test_bg.py b/test_bg.py
new file mode 100644
index 0000000..3ad986b
--- /dev/null
+++ b/test_bg.py
@@ -0,0 +1,22 @@
+# import torch
+# from model import MattingNetwork
+# from inference import convert_video
+#
+# model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
+# model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
+#
+# convert_video(
+#     model,                         # The model, can be on any device (cpu or cuda).
+#     input_source='test.mp4',       # A video file or an image sequence directory.
+#     output_type='video',           # Choose "video" or "png_sequence"
+#     output_composition='com.mp4',  # File path if video; directory path if png sequence.
+#     output_alpha="pha.mp4",        # [Optional] Output the raw alpha prediction.
+#     output_foreground="fgr.mp4",   # [Optional] Output the raw foreground prediction.
+#     output_video_mbps=4,           # Output video mbps. Not needed for png sequence.
+#     downsample_ratio=None,         # A hyperparameter to adjust or use None for auto.
+#     seq_chunk=12,                  # Process n frames at once for better parallelism.
+# )
+
+from app.utilities import video_dimension_unifier
+
+video_dimension_unifier('app/files/sdasda/video1.mp4', 'sdasda', '.mp4')
\ No newline at end of file
diff --git a/train.py b/train.py
index 462bd1f..34b041f 100644
--- a/train.py
+++ b/train.py
@@ -403,7 +403,7 @@ def train_seg(self, true_img, true_seg, log_label):
         true_seg = true_seg.to(self.rank, non_blocking=True)
 
         true_img, true_seg = self.random_crop(true_img, true_seg)
-        
+
         with autocast(enabled=not self.args.disable_mixed_precision):
             pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
             loss = segmentation_loss(pred_seg, true_seg)