diff --git a/templates/Style Transfer_Pytorch/code-template.py.jinja b/templates/Style Transfer_Pytorch/code-template.py.jinja
new file mode 100644
index 0000000..a45c541
--- /dev/null
+++ b/templates/Style Transfer_Pytorch/code-template.py.jinja
@@ -0,0 +1,85 @@
+import torch
+from torch import nn
+from torch import optim
+from torchvision import transforms
+from PIL import Image
+import torchvision.models as models
+from torchvision.utils import save_image
+
+# Pick the training device from the sidebar choice, falling back to CPU if CUDA is unavailable.
+device = torch.device("cuda" if "{{ Device }}" == "GPU" and torch.cuda.is_available() else "cpu")
+
+class VGG(nn.Module):
+    """VGG-19 feature extractor returning the activations of the chosen conv layers."""
+    def __init__(self):
+        super().__init__()
+        # Layer indices used for the content and style representations.
+        self.chosen_features = ["0", "5", "10", "19", "28"]
+        self.model = models.vgg19(pretrained={{ pretrained }}).features[:29]
+
+    def forward(self, x):
+        features = []
+        for layer_num, layer in enumerate(self.model):
+            x = layer(x)
+            if str(layer_num) in self.chosen_features:
+                features.append(x)
+        return features
+
+model = VGG().to(device).eval()
+print(model)
+
+image_size = 540
+loader = transforms.Compose([
+    transforms.Resize((image_size, image_size)),
+    transforms.ToTensor(),
+])
+
+def load_image(image_name):
+    img = Image.open(image_name).convert("RGB")  # force 3 channels
+    img = loader(img).unsqueeze(0)
+    return img.to(device)
+
+original_img = load_image("./content.jpg")
+style_img = load_image("./style.jpg")
+
+# The generated image starts as a copy of the content image and is optimized directly.
+generated = original_img.clone().requires_grad_(True)
+
+total_steps = {{ num_epochs }}
+learning_rate = {{ lr }}
+alpha = 1      # content loss weight
+beta = 0.01    # style loss weight
+optimizer = optim.{{ optimizer }}([generated], lr=learning_rate)
+
+for step in range(total_steps):
+    generated_features = model(generated)
+    original_img_features = model(original_img)
+    style_img_features = model(style_img)
+
+    style_loss = original_loss = 0
+
+    for gen_fea, ori_fea, style_fea in zip(generated_features, original_img_features, style_img_features):
+        batch_size, channel, height, width = gen_fea.shape
+
+        # Content loss: mean squared error between generated and content features.
+        original_loss += torch.mean((gen_fea - ori_fea) ** 2)
+
+        # Gram matrices of the generated and style features.
+        G = gen_fea.view(channel, height * width).mm(
+            gen_fea.view(channel, height * width).t()
+        )
+        A = style_fea.view(channel, height * width).mm(
+            style_fea.view(channel, height * width).t()
+        )
+
+        # Style loss: mean squared error between the Gram matrices.
+        style_loss += torch.mean((G - A) ** 2)
+
+    total_loss = alpha * original_loss + beta * style_loss
+    optimizer.zero_grad()
+    total_loss.backward()
+    optimizer.step()
+
+    if step % {{ visualize_per_epoch }} == 0:
+        print(f"Total_Loss: {total_loss.item()}")
+        save_image(generated, "Generated.png")
diff --git a/templates/Style Transfer_Pytorch/sidebar.py b/templates/Style Transfer_Pytorch/sidebar.py
new file mode 100644
index 0000000..674d559
--- /dev/null
+++ b/templates/Style Transfer_Pytorch/sidebar.py
@@ -0,0 +1,83 @@
+import streamlit as st
+
+MODELS = {
+    "VGG": "vgg19",
+}
+
+
+# Define possible optimizers in a dict.
+# Format: optimizer -> default learning rate
+OPTIMIZERS = {
+    "Adam": 0.001,
+    "Adadelta": 1.0,
+    "Adagrad": 0.01,
+    "Adamax": 0.002,
+    "RMSprop": 0.01,
+    "SGD": 0.1,
+}
+
+
+def show():
+    """Shows the sidebar components for the template and returns user inputs as dict."""
+
+    # `show()` is the only method required in this module. You can add any other code
+    # you like above or below.
+
+    inputs = {}  # dict to store all user inputs until return
+
+    with st.sidebar:
+
+        # Render all template-specific sidebar components here.
+
+        # Use ## to denote sections. Common sections for training templates:
+        # Model, Input data, Preprocessing, Training, Visualizations
+        st.write("## Model")
+
+        # Store all user inputs in the `inputs` dict. This will be passed to the code
+        # template later.
+        inputs["model"] = st.selectbox("Which model?", list(MODELS.keys()))
+
+        inputs["Device"] = st.selectbox(
+            "Which device would you like to train on?",
+            ("GPU", "CPU"),
+        )
+
+        inputs["pretrained"] = st.checkbox("Use pre-trained weights (recommended)", True)
+
+        st.write("## Input data")
+        inputs["data_format"] = st.selectbox(
+            "Which data do you want to use?",
+            ("Custom Image files",),
+        )
+
+        if inputs["data_format"] == "Custom Image files":
+            st.write("""
+            ```
+            Place style.jpg and content.jpg in the working directory.
+            ```
+            """)
+
+        inputs["loss"] = st.selectbox(
+            "Loss function", ("Style Gram Loss",)
+        )
+
+        inputs["optimizer"] = st.selectbox("Optimizer", list(OPTIMIZERS.keys()))
+
+        default_lr = OPTIMIZERS[inputs["optimizer"]]
+        inputs["lr"] = st.number_input(
+            "Learning rate", 0.000, None, default_lr, format="%f"
+        )
+
+        inputs["num_epochs"] = st.number_input("Epochs", 1, None, 5000)
+
+        inputs["visualize_per_epoch"] = st.number_input(
+            "Save generated image every ... epochs", 1, None, 200
+        )
+
+    return inputs
+
+
+# To test the sidebar independent of the app or template, just run
+# `streamlit run sidebar.py` from within this folder.
+if __name__ == "__main__":
+    show()
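Note on how the two files fit together: every `{{ ... }}` placeholder in `code-template.py.jinja` (`Device`, `pretrained`, `optimizer`, `lr`, `num_epochs`, `visualize_per_epoch`) must have a matching key in the dict returned by `sidebar.show()`. Below is a minimal sketch of that rendering step, assuming plain `jinja2` and illustrative file paths; the surrounding app may load and render templates differently.

```python
# Minimal sketch: render the style-transfer template with values shaped like
# the sidebar's `inputs` dict. The paths and the direct Template(...) call are
# assumptions for illustration, not necessarily the app's actual wiring.
from jinja2 import Template

inputs = {
    "Device": "GPU",
    "pretrained": True,
    "optimizer": "Adam",
    "lr": 0.001,
    "num_epochs": 5000,
    "visualize_per_epoch": 200,
}

with open("templates/Style Transfer_Pytorch/code-template.py.jinja") as f:
    template = Template(f.read())

# Each {{ placeholder }} is substituted with the value stored under the same key.
code = template.render(**inputs)

with open("style_transfer_generated.py", "w") as f:
    f.write(code)
```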