Skip to content

Style Transfer Template #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions templates/Style Transfer_Pytorch/code-template.py.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
from torch import nn
from torch import optim
from torchvision import transforms
from PIL import Image
import torchvision.models as models
from torchvision.utils import save_image
import numpy as np

device = {{Device}}
model = models.vgg19(pretrained={{pretrained}}).features
print(model)

class VGG(nn.Module):
def __init__(self):
super(VGG,self).__init__()
self.choosen_features = ['0','5','10','19','28']
self.model = models.vgg19(pretrained=True).features[:29]

def forward(self,x):
features = []

for layer_num , layer in enumerate(self.model):
x = layer(x)

if str(layer_num) in self.choosen_features:
features.append(x)

return features

model = VGG().to(device).eval()
loader = transforms.Compose([
transforms.Resize((540,540)) ,
transforms.ToTensor(),
])
def load_image(image_name):
img = Image.open(image_name)
img = loader(img).unsqueeze(0)
return img.to(device)

image_size = 540
original_img = load_image('./content.jpg')
style_img = load_image('./style.jpg')

generated = original_img.clone().requires_grad_(True)

total_steps = {{ epochs }}
learning_rate = {{ lr }}
alpha = 1
beta = 0.01
optimizer = optim.{{optimizer}}([generated],lr=learning_rate)

for step in range(total_steps):
generated_features = model(generated)
original_img_features = model(original_img)
style_img_features = model(style_img)

style_loss = original_loss = 0

for gen_fea,ori_fea,style_fea in zip(generated_features,original_img_features,style_img_features):
batch_size , channel,height,width = gen_fea.shape
original_loss += torch.mean((gen_fea - ori_fea) ** 2)

G = gen_fea.view(channel,height*width).mm(
gen_fea.view(channel,height*width).t()
)

A = style_fea.view(channel,height*width).mm(
style_fea.view(channel,height*width).t()
)

style_loss += torch.mean((G - A)**2)

total_loss = alpha*original_loss + beta * style_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()

if step % {{ visualize_per_epoch }} == 0:
print(f'Total_Loss: {total_loss}')
save_image(generated,'Generated.png')
84 changes: 84 additions & 0 deletions templates/Style Transfer_Pytorch/sidebar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import streamlit as st

MODELS = {
"VGG": "vgg19",
}


# Define possible optimizers in a dict.
# Format: optimizer -> default learning rate
OPTIMIZERS = {
"Adam": 0.001,
"Adadelta": 1.0,
"Adagrad": 0.01,
"Adamax": 0.002,
"RMSprop": 0.01,
"SGD": 0.1,
}



def show():
"""Shows the sidebar components for the template and returns user inputs as dict."""

# `show()` is the only method required in this module. You can add any other code
# you like above or below.

inputs = {} # dict to store all user inputs until return

with st.sidebar:

# Render all template-specific sidebar components here.

# Use ## to denote sections. Common sections for training templates:
# Model, Input data, Preprocessing, Training, Visualizations
st.write("## Model")

# Store all user inputs in the `inputs` dict. This will be passed to the code
# template later.
inputs["model"] = st.selectbox("Which model?", list(MODELS.keys()))

inputs["Device"] = st.selectbox(
"Which device would you like to train on?",
("GPU", "CPU"),
)

inputs["pretrained"] = st.checkbox("Use pre-trained model (Suggested Use is with a pretrained model)")

st.write("## Input data")
inputs["data_format"] = st.selectbox(
"Which data do you want to use?",
("Custom Image files"),
)

if input["data_format"]== "Custom Image files":
st.write("""
```
Make sure you have style.jpg and content.jpg .
```
""")


inputs["loss"] = st.selectbox(
"Loss function", ("Style Gram Loss")
)

inputs["optimizer"] = st.selectbox("Optimizer", list(OPTIMIZERS.keys()))

default_lr = OPTIMIZERS[inputs["optimizer"]]
inputs["lr"] = st.number_input(
"Learning rate", 0.000, None, default_lr, format="%f"
)

inputs["num_epochs"] = st.number_input("Epochs", 1, None, 5000)

inputs["visualize_per_epoch"] = st.number_input("Epochs", 1, None, 200)


return inputs


# To test the sidebar independent of the app or template, just run
# `streamlit run sidebar.py` from within this folder.
if __name__ == "__main__":
show()