diff --git a/jupyters/ruDalle_pl.ipynb b/jupyters/ruDalle_pl.ipynb
new file mode 100644
index 0000000..dcd2731
--- /dev/null
+++ b/jupyters/ruDalle_pl.ipynb
@@ -0,0 +1,1160 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "k3EMS6NSieNf"
+   },
+   "source": [
+    "# ⚡️⚡️⚡️ ruDALL-E PyTorch Lightning finetuning notebook ⚡️⚡️⚡️"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "PnJMPgBhiao8"
+   },
+   "outputs": [],
+   "source": [
+    "%%capture\n",
+    "!pip install rudalle==1.1.0\n",
+    "!pip install bitsandbytes-cuda113 > /dev/null\n",
+    "!pip install wandb > /dev/null\n",
+    "!pip install pytorch-lightning > /dev/null\n",
+    "!pip install deepspeed > /dev/null\n",
+    "!pip install gdown > /dev/null"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 108
+    },
+    "id": "oZWBrXrDiao-",
+    "outputId": "67c7b524-0f15-48c0-cf7a-aff4fbced172"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Downloading...\n",
+      "From: http://drive.google.com/uc?id=17bPt7G3N_vGKCCxppIOPbPlhv1qUnv0o\n",
+      "To: /workspace/food.zip\n",
+      "100%|██████████| 34.8M/34.8M [00:00<00:00, 63.6MB/s]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'food.zip'"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "#@markdown Let's download the data\n",
+    "\n",
+    "import gdown\n",
+    "url = \"http://drive.google.com/uc?id=17bPt7G3N_vGKCCxppIOPbPlhv1qUnv0o\"\n",
+    "output = \"food.zip\"\n",
+    "gdown.download(url, output, quiet=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "6wadzpKAiao_"
+   },
+   "outputs": [],
+   "source": [
+    "!unzip food.zip > /dev/null"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "JhO29PfCiao_"
+   },
+   "source": [
+    "For better results, use data with good text descriptions that include the properties you want the model to learn.\n",
+    "For example, if you want to teach the model to generate shapes, pair a text input such as \"a thing in the shape of a circle\" with 'circle.png'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "id": "CWVD0dPNiao_"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import random\n",
+    "from collections import Counter\n",
+    "\n",
+    "import PIL\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import bitsandbytes as bnb\n",
+    "import torchvision.transforms as T\n",
+    "import torchvision.transforms.functional as TF\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from rudalle import get_tokenizer, get_vae\n",
+    "from rudalle.utils import seed_everything\n",
+    "import pytorch_lightning as pl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "bu36SjSxiao_",
+    "outputId": "213817f5-12ff-4a51-faa8-60c95bd5b7fc"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/conda/lib/python3.7/site-packages/huggingface_hub/file_download.py:563: FutureWarning: `cached_download` is the legacy way to download files from the HF hub, please consider upgrading to `hf_hub_download`\n",
+      "  FutureWarning,\n"
+     ]
+    }
+   ],
+   "source": [
+    "from rudalle.pipelines import generate_images, show, super_resolution\n",
+    "from rudalle import get_tokenizer, get_vae, get_realesrgan, get_rudalle_model\n",
+    "from rudalle.utils import seed_everything\n",
+    "\n",
+    "device = 'cuda'\n",
+    "# the fp16 weights stay on cpu here (hence the warning); the LightningModule below loads its own copy on the GPU\n",
+    "model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device='cpu')\n",
+    "vae = get_vae().to(device)\n",
+    "tokenizer = get_tokenizer()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "X8BUqLHtSJHC",
+    "outputId": "a6935e2b-b8d5-4645-b326-1ea864b810fc"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1152"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# total_seq_length = 128 text tokens + 1024 image tokens (the 32x32 VQGAN grid)\n",
+    "1024 + 128"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "id": "WF5t_0zdiao_"
+   },
+   "outputs": [],
+   "source": [
+    "class Args():\n",
+    "    def __init__(self):\n",
+    "        self.epochs = 2\n",
+    "        self.bs = 1\n",
+    "        self.save_every = 100\n",
+    "        self.lr = 4e-5\n",
+    "        self.text_seq_length = 128\n",
+    "        self.total_seq_length = 1152\n",
+    "        self.save_path = 'checkpoints'\n",
+    "        self.model_name = 'awesomemodel_'\n",
+    "        self.prefix_length = 10\n",
+    "        self.image_size = 256\n",
+    "        self.freeze = False\n",
+    "        self.clip = 0.24\n",
+    "        self.warmup_steps = 50\n",
+    "        self.wandb = True\n",
+    "\n",
+    "args = Args()\n",
+    "if not os.path.exists(args.save_path):\n",
+    "    os.makedirs(args.save_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "id": "SXcWtNnKiapA"
+   },
+   "outputs": [],
+   "source": [
+    "class FoodDataset(Dataset):\n",
+    "    def __init__(self, file_path, csv_path, tokenizer, shuffle=True):\n",
+    "        self.tokenizer = tokenizer\n",
+    "        self.samples = []\n",
+    "        self.image_transform = T.Compose([\n",
+    "            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n",
+    "            T.RandomResizedCrop(args.image_size, scale=(1., 1.), ratio=(1., 1.)),\n",
+    "            T.ToTensor()\n",
+    "        ])\n",
+    "\n",
+    "        df = pd.read_csv(csv_path)\n",
+    "        df.columns = ['index', 'belok', 'fats', 'uglevod', 'kkal', 'name', 'path']\n",
+    "\n",
+    "        for belok, fats, uglevod, kkal, caption, f_path in zip(\n",
+    "            df['belok'], df['fats'], df['uglevod'], df['kkal'], df['name'], df['path']\n",
+    "        ):\n",
+    "            # caption format: dish name plus protein/fat/carb/kcal values, in Russian\n",
+    "            caption = f'блюдо: {caption}; белков: {belok}; жиров: {fats}; углеводов: {uglevod}; ккал: {kkal};'\n",
+    "            if 10 < len(caption) < 100 and os.path.isfile(f'{file_path}/{f_path}'):\n",
+    "                self.samples.append([file_path, f_path, caption.lower()])\n",
+    "        if shuffle:\n",
+    "            np.random.shuffle(self.samples)\n",
+    "            print('Shuffled')\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.samples)\n",
+    "\n",
+    "    def load_image(self, file_path, img_name):\n",
+    "        return PIL.Image.open(f'{file_path}/{img_name}')\n",
+    "\n",
+    "    def __getitem__(self, item):\n",
+    "        item = item % len(self.samples)\n",
+    "        file_path, img_name, text = self.samples[item]\n",
+    "\n",
+    "        try:\n",
+    "            image = self.load_image(file_path, img_name)\n",
+    "            image = self.image_transform(image)\n",
+    "        except Exception as err:\n",
+    "            # fall back to a random sample if an image is missing or unreadable\n",
+    "            print(err)\n",
+    "            random_item = random.randint(0, len(self.samples) - 1)\n",
+    "            return self.__getitem__(random_item)\n",
+    "\n",
+    "        text = text.lower().strip()\n",
+    "        encoded = self.tokenizer.encode_text(text, text_seq_length=args.text_seq_length)\n",
+    "        return encoded, image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "HRiRjILWiapB",
+    "outputId": "f2c798d3-6fec-4c69-cbd1-cd0ed1b4e1d1"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Shuffled\n",
+      "474\n"
+     ]
+    }
+   ],
+   "source": [
+    "class FoodDataModule(pl.LightningDataModule):\n",
+    "\n",
+    "    def __init__(self, file_path, csv_path, tokenizer):\n",
+    "        super().__init__()\n",
+    "        self.file_path = file_path\n",
+    "        self.csv_path = csv_path\n",
+    "        self.tokenizer = tokenizer\n",
+    "\n",
+    "    def setup(self, stage=None):\n",
+    "        self.train_dataset = FoodDataset(file_path=self.file_path,\n",
+    "                                         csv_path=self.csv_path,\n",
+    "                                         tokenizer=self.tokenizer)\n",
+    "\n",
+    "    def train_dataloader(self):\n",
+    "        return DataLoader(\n",
+    "            self.train_dataset,\n",
+    "            batch_size=args.bs,\n",
+    "            shuffle=True,\n",
+    "        )\n",
+    "\n",
+    "dataset = FoodDataset(file_path='food', csv_path='food/food.csv', tokenizer=tokenizer)\n",
+    "args.train_steps = len(dataset)//args.bs\n",
+    "print(args.train_steps)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "id": "Q6F34xzDiapB"
+   },
+   "outputs": [],
+   "source": [
+    "from transformers import AdamW\n",
+    "\n",
+    "\n",
+    "def freeze(params,\n",
+    "           freeze_emb=False,\n",
+    "           freeze_ln=False,\n",
+    "           freeze_attn=True,\n",
+    "           freeze_ff=True,\n",
+    "           freeze_other=False):\n",
+    "    for name, p in params.named_parameters():\n",
+    "        name = name.lower()\n",
+    "        if 'ln' in name or 'norm' in name:\n",
+    "            p.requires_grad = not freeze_ln\n",
+    "        elif 'embeddings' in name:\n",
+    "            p.requires_grad = not freeze_emb\n",
+    "        elif 'mlp' in name:\n",
+    "            p.requires_grad = not freeze_ff\n",
+    "        elif 'attn' in name:\n",
+    "            p.requires_grad = not freeze_attn\n",
+    "        else:\n",
+    "            p.requires_grad = not freeze_other\n",
+    "    return params\n",
+    "\n",
+    "\n",
+    "class Rudalle_(pl.LightningModule):\n",
+    "\n",
+    "    def __init__(self, args, vae):\n",
+    "        super().__init__()\n",
+    "        # the checkpoint is loaded straight onto the GPU in fp16\n",
+    "        self.model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device='cuda')\n",
+    "        if args.freeze:\n",
+    "            self.model = freeze(self.model)\n",
+    "        #self.vae = get_vae(dwt=False).to(self.device)\n",
+    "\n",
+    "    def forward(self, input_ids, return_loss=True):\n",
+    "        input_ids = input_ids.to(self.device)\n",
+    "        attention_mask = torch.tril(torch.ones((args.bs, 1, args.total_seq_length, args.total_seq_length)))\n",
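+    "        # the lower-triangular mask above is causal: each of the 1152 positions\n",
+    "        # (128 text tokens followed by 1024 image tokens) attends only to itself\n",
+    "        # and earlier positions\n",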
+ " attention_mask = attention_mask.to(self.device)\n", + " \n", + " loss, loss_values = self.model.forward(input_ids, \n", + " attention_mask, \n", + " return_loss=return_loss)\n", + " \n", + " return loss\n", + "\n", + "\n", + "\n", + " def training_step(self, batch):\n", + " text, image = batch[0], batch[1]\n", + "\n", + " text = text.to(self.device)\n", + " image = image.to(self.device)\n", + "\n", + " image_input_ids = vae.get_codebook_indices(image)\n", + "\n", + " image_input_ids = image_input_ids.to(self.device)\n", + " \n", + " input_ids = torch.cat((text, image_input_ids), dim=1)\n", + " \n", + " loss = self.forward(input_ids, return_loss=True)\n", + " \n", + " \n", + " self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n", + "\n", + " return {\"loss\": loss}\n", + "\n", + " def training_epoch_end(self, outputs):\n", + " pass\n", + "\n", + " def configure_optimizers(self):\n", + " \n", + " optimizer = AdamW(self.parameters(), lr=args.lr, betas=(0.9, 0.995))\n", + "\n", + " \n", + "\n", + " \n", + " scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer, \n", + " max_lr=args.lr,\n", + " final_div_factor=500, \n", + " steps_per_epoch=args.train_steps,\n", + " epochs=args.epochs \n", + " )\n", + "\n", + " return dict(\n", + " optimizer=optimizer,\n", + " lr_scheduler=dict(\n", + " scheduler=scheduler,\n", + " interval='step'\n", + " )\n", + " )\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 88 + }, + "id": "K8eQZJ6diapB", + "outputId": "8eca2a72-b2aa-4dae-a394-077d4d7b1740" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33malexwortega\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.12.21" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /workspace/wandb/run-20220716_004709-1do8x8h3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run fearless-cherry-27 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from pytorch_lightning.loggers import WandbLogger\n", + "# я использую wandb в качестве логера, если надо замените на тенсорборду\n", + "# i prefer wandb, change to tensobord if u want\n", + "wandb_logger = WandbLogger(project=\"rudalle\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "VOCg8SCAiapB" + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "checkpoint_callback = ModelCheckpoint(\n", + " dirpath=args.save_path,\n", + " filename=args.model_name,\n", + " save_top_k=1,\n", + " verbose=True,\n", + " mode=\"min\"\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "naRy8C7tiapC", + "outputId": "7b24a6ae-656e-4f8f-c900-066ccc79a2be" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cpu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.\n", + "cpu\n" + ] + } + ], + "source": [ + "vae.to('cuda')\n", + "model = Rudalle_(args, vae)\n", + "\n", + "data_module = FoodDataModule(file_path='food', csv_path ='food/food.csv', tokenizer=tokenizer)\n", + "\n", + "trainer = pl.Trainer(\n", + " logger=wandb_logger,\n", + " checkpoint_callback=checkpoint_callback,\n", + " max_epochs=2,\n", + " \n", + " accelerator=\"gpu\",\n", + " accumulate_grad_batches=5,\n", + " devices=[0],\n", + " progress_bar_refresh_rate=30\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E8ddEUsRiapC" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params\n", + "-------------------------------------\n", + "0 | model | FP16Module | 1.3 B \n", + "-------------------------------------\n", + "1.3 B Trainable params\n", + "0 Non-trainable params\n", + "1.3 B Total params\n", + "5,241.644 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shuffled\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1f7daa17a6cb4dd0abbe9175df03b1cb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1366: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", + " \"Positional args are being deprecated, use kwargs instead. 
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py:1366: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n",
+      "  \"Positional args are being deprecated, use kwargs instead. Refer to \"\n"
+     ]
+    }
+   ],
+   "source": [
+    "trainer.fit(model, data_module)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.save_checkpoint('rudalle.ckpt')\n",
+    "\n",
+    "def _fix_pl(path):\n",
+    "    # Lightning prefixes every weight with 'model.'; strip it so plain rudalle can load the file\n",
+    "    d = torch.load(path)[\"state_dict\"]\n",
+    "    checkpoint = {}\n",
+    "    for key in d.keys():\n",
+    "        checkpoint[key.replace('model.', '', 1)] = d[key]\n",
+    "    torch.save(checkpoint, 'fixed.pt')\n",
+    "\n",
+    "_fix_pl('rudalle.ckpt')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "b60H9WQ0ZVbB"
+   },
+   "source": [
+    "# DeepSpeed (advanced)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "PT-lJNPhalOj"
+   },
+   "outputs": [],
+   "source": [
+    "%%writefile ds_config.json\n",
+    "\n",
+    "{\n",
+    "  \"train_batch_size\": 15,\n",
+    "  \"fp16\": {\n",
+    "    \"enabled\": true,\n",
+    "    \"min_loss_scale\": 1,\n",
+    "    \"opt_level\": \"O3\"\n",
+    "  },\n",
+    "  \"zero_optimization\": {\n",
+    "    \"stage\": 2,\n",
+    "    \"offload_param\": {\n",
+    "      \"device\": \"nvme\",\n",
+    "      \"nvme_path\": \"/home/asta/deepspeed\",\n",
+    "      \"buffer_count\": 5,\n",
+    "      \"buffer_size\": 1e8,\n",
+    "      \"reduce_bucket_size\": 2e7,\n",
+    "      \"max_in_cpu\": 1e9\n",
+    "    },\n",
+    "    \"offload_optimizer\": {\n",
+    "      \"device\": \"nvme\",\n",
+    "      \"nvme_path\": \"/home/asta/deepspeed\",\n",
+    "      \"buffer_count\": 4,\n",
+    "      \"pipeline_read\": false,\n",
+    "      \"pipeline_write\": false,\n",
+    "      \"pin_memory\": true\n",
+    "    },\n",
+    "    \"allgather_partitions\": true,\n",
+    "    \"allgather_bucket_size\": 5e8,\n",
+    "    \"contiguous_gradients\": true,\n",
+    "    \"overlap_comm\": true,\n",
+    "    \"aio\": {\n",
+    "      \"block_size\": 1048576,\n",
+    "      \"queue_depth\": 8,\n",
+    "      \"thread_count\": 1,\n",
+    "      \"single_submit\": false,\n",
+    "      \"overlap_events\": true\n",
+    "    }\n",
+    "  },\n",
+    "  \"optimizer\": {\n",
+    "    \"type\": \"AdamW\",\n",
+    "    \"params\": {\n",
+    "      \"lr\": 5e-05,\n",
+    "      \"betas\": [\n",
+    "        0.9,\n",
+    "        0.999\n",
+    "      ],\n",
+    "      \"eps\": 1e-08\n",
+    "    }\n",
+    "  },\n",
+    "  \"scheduler\": {\n",
+    "    \"type\": \"WarmupLR\",\n",
+    "    \"params\": {\n",
+    "      \"warmup_min_lr\": 0,\n",
+    "      \"warmup_max_lr\": 5e-05,\n",
+    "      \"warmup_num_steps\": 100\n",
+    "    }\n",
+    "  }\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "X8-ZB9fsZIx5"
+   },
+   "outputs": [],
+   "source": [
+    "%%writefile train.py\n",
+    "import os\n",
+    "import sys\n",
+    "import random\n",
+    "from collections import Counter\n",
+    "\n",
+    "import PIL\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import bitsandbytes as bnb\n",
+    "import torchvision.transforms as T\n",
+    "import torchvision.transforms.functional as TF\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from rudalle.pipelines import generate_images, show, super_resolution\n",
+    "from rudalle import get_tokenizer, get_vae, get_realesrgan, get_rudalle_model\n",
+    "from rudalle.utils import seed_everything\n",
+    "import pytorch_lightning as pl\n",
+    "from pytorch_lightning.plugins import DeepSpeedPlugin\n",
+    "\n",
+    "device = 'cuda'\n",
+    "# the DALL-E weights are loaded inside Rudalle_ below; only the vae and tokenizer are needed here\n",
+    "vae = get_vae().to(device)\n",
+    "tokenizer = get_tokenizer()\n",
+    "\n",
+    "class Args():\n",
+    "    def __init__(self):\n",
+    "        self.epochs = 2\n",
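+    "        # bs is the per-GPU micro-batch; DeepSpeed expects train_batch_size in\n",
+    "        # ds_config.json to equal micro-batch * gradient-accumulation steps * number of GPUs\n",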
+ " self.bs = 1\n", + " self.save_every = 100 \n", + " self.lr = 4e-5\n", + "\n", + "\n", + "\n", + " self.text_seq_length = 128\n", + " self.total_seq_length = 1152\n", + " self.save_path='checkpoints'\n", + " self.model_name = 'awesomemodel_'\n", + " self.prefix_length = 10\n", + " self.image_size = 256\n", + " self.freeze = False\n", + " self.clip = 0.24\n", + " self.warmup_steps =50\n", + " \n", + " self.wandb = True\n", + "args = Args()\n", + "if not os.path.exists(args.save_path):\n", + " os.makedirs(args.save_path)\n", + "\n", + "class FoodDataset(Dataset):\n", + " def __init__(self, file_path, csv_path, tokenizer, shuffle=True):\n", + " self.tokenizer = tokenizer\n", + " self.samples = []\n", + " self.image_transform = T.Compose([\n", + " T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n", + " T.RandomResizedCrop(args.image_size, scale=(1., 1.), ratio=(1., 1.)),\n", + " T.ToTensor()\n", + " ])\n", + "\n", + " df = pd.read_csv(csv_path)\n", + " df.columns = ['index', 'belok', 'fats', 'uglevod', 'kkal', 'name', 'path']\n", + "\n", + " for belok, fats, uglevod, kkal, caption, f_path in zip(\n", + " df['belok'],df['fats'], df['uglevod'], df['kkal'], df['name'], df['path']\n", + " ):\n", + " caption = f'блюдо: {caption}; белков: {belok}; жиров: {fats}; углеводов: {uglevod}; ккал: {kkal};'\n", + " if len(caption)>10 and len(caption)<100 and os.path.isfile(f'{file_path}/{f_path}'):\n", + " self.samples.append([file_path, f_path, caption.lower()])\n", + " if shuffle:\n", + " np.random.shuffle(self.samples)\n", + " print('Shuffled')\n", + "\n", + " def __len__(self):\n", + " return len(self.samples)\n", + "\n", + " def load_image(self, file_path, img_name):\n", + " return PIL.Image.open(f'{file_path}/{img_name}')\n", + "\n", + " def __getitem__(self, item):\n", + " item = item % len(self.samples)\n", + " file_path, img_name, text = self.samples[item]\n", + "\n", + " try:\n", + " image = self.load_image(file_path, img_name)\n", + " image = self.image_transform(image)\n", + " except Exception as err: \n", + " print(err)\n", + " random_item = random.randint(0, len(self.samples) - 1)\n", + " return self.__getitem__(random_item)\n", + " \n", + " text = text.lower().strip()\n", + " encoded = self.tokenizer.encode_text(text, text_seq_length=args.text_seq_length) \n", + " return encoded, image\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "class FoodDataModule(pl.LightningDataModule):\n", + " \n", + " def __init__(self, file_path, csv_path, tokenizer):\n", + " super().__init__()\n", + " self.file_path = file_path\n", + " self.csv_path = csv_path\n", + "\n", + " def setup(self, stage=None):\n", + " self.train_dataset = FoodDataset(file_path= self.file_path, \n", + " csv_path = self.csv_path, \n", + " tokenizer = tokenizer)\n", + "\n", + " \n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=args.bs,\n", + " shuffle=True,\n", + " \n", + " )\n", + "dataset = FoodDataset(file_path = 'food' ,csv_path = 'food/food.csv',tokenizer = tokenizer)\n", + "args.train_steps = len(dataset)//args.bs\n", + "print(args.train_steps )\n", + "\n", + "from transformers import AdamW, get_linear_schedule_with_warmup\n", + "import bitsandbytes as bnb\n", + "\n", + "\n", + "\n", + "def freeze(params,\n", + " freeze_emb=False,\n", + " freeze_ln=False,\n", + " freeze_attn=True,\n", + " freeze_ff=True,\n", + " freeze_other=False):\n", + " \n", + " for name, p in params.named_parameters(): \n", + " name = 
name.lower()\n", + " if 'ln' in name or 'norm' in name:\n", + " p.requires_grad = not freeze_ln\n", + " elif 'embeddings' in name:\n", + " p.requires_grad = not freeze_emb\n", + " elif 'mlp' in name:\n", + " p.requires_grad = not freeze_ff\n", + " elif 'attn' in name:\n", + " p.requires_grad = not freeze_attn\n", + " else:\n", + " p.requires_grad = not freeze_other\n", + " \n", + " return params\n", + "\n", + "\n", + "class Rudalle_(pl.LightningModule):\n", + " \n", + " def __init__(self, args, vae):\n", + " super().__init__()\n", + " \n", + " \n", + " print(self.device)\n", + " self.model = get_rudalle_model('Malevich', pretrained=True, fp16=True, device='cuda')\n", + "\n", + " if args.freeze:\n", + " self.model = freeze(self.model) \n", + " #self.vae = get_vae(dwt=False).to(self.device)\n", + "\n", + " \n", + " print(self.device)\n", + " \n", + " def forward(self, \n", + " input_ids, \n", + " return_loss=True):\n", + " \n", + " input_ids = input_ids.to(self.device)\n", + " \n", + " attention_mask = torch.tril(torch.ones((args.bs, 1, args.total_seq_length, args.total_seq_length)))\n", + " attention_mask = attention_mask.to(self.device)\n", + " \n", + " loss, loss_values = self.model.forward(input_ids, \n", + " attention_mask, \n", + " return_loss=return_loss)\n", + " \n", + " return loss\n", + "\n", + "\n", + "\n", + " def training_step(self, batch):\n", + " text, image = batch[0], batch[1]\n", + "\n", + " text = text.to(self.device)\n", + " image = image.to(self.device)\n", + "\n", + " image_input_ids = vae.get_codebook_indices(image)\n", + "\n", + " image_input_ids = image_input_ids.to(self.device)\n", + " \n", + " input_ids = torch.cat((text, image_input_ids), dim=1)\n", + " \n", + " loss = self.forward(input_ids, return_loss=True)\n", + " \n", + " \n", + " self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n", + "\n", + " return {\"loss\": loss}\n", + "\n", + " def training_epoch_end(self, outputs):\n", + " pass\n", + "\n", + " def configure_optimizers(self):\n", + " \n", + " optimizer = bnb.optim.Adam8bit(self.parameters(), lr=args.lr, betas=(0.9, 0.995))\n", + "\n", + " \n", + "\n", + " \n", + " scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer, \n", + " max_lr=args.lr,\n", + " final_div_factor=500, \n", + " steps_per_epoch=args.train_steps,\n", + " epochs=args.epochs \n", + " )\n", + "\n", + " return dict(\n", + " optimizer=optimizer,\n", + " lr_scheduler=dict(\n", + " scheduler=scheduler,\n", + " interval='step'\n", + " )\n", + " )\n", + " \n", + "\n", + "\n", + "from pytorch_lightning.loggers import WandbLogger\n", + "# я использую wandb в качестве логера, если надо замените на тенсорборду\n", + "# i prefer wandb, change to tensobord if u want\n", + "wandb_logger = WandbLogger(project=\"rudalle\")\n", + "\n", + "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n", + "from pytorch_lightning.loggers import TensorBoardLogger\n", + "checkpoint_callback = ModelCheckpoint(\n", + " dirpath=args.save_path,\n", + " filename=args.model_name,\n", + " save_top_k=1,\n", + " verbose=True,\n", + " mode=\"min\"\n", + ")\n", + "\n", + "vae.to('cuda')\n", + "model = Rudalle_(args, vae)\n", + "\n", + "data_module = FoodDataModule(file_path='food', csv_path ='food/food.csv', tokenizer=tokenizer)\n", + "\n", + "\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + "\n", + "\n", + " trainer = pl.Trainer(\n", + " logger=wandb_logger,\n", + " checkpoint_callback=checkpoint_callback,\n", + " max_epochs=2,\n", + " #precision=16,\n", + " \n", + " 
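+    "        # point DeepSpeed at the config written by the %%writefile cell above,\n",
+    "        # resolved relative to the working directory\n",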
+    "        strategy=DeepSpeedPlugin(config=\"ds_config.json\"),\n",
+    "        progress_bar_refresh_rate=30\n",
+    "    )\n",
+    "\n",
+    "    trainer.fit(model, data_module)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!python3 train.py"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [
+    "b60H9WQ0ZVbB"
+   ],
+   "machine_shape": "hm",
+   "name": "dalle_pl",
+   "provenance": []
+  },
+  "interpreter": {
+   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}