
Use GAN to calculate style loss #441


Open · wants to merge 11 commits into base: master
39 changes: 39 additions & 0 deletions README.md
@@ -100,6 +100,45 @@ Clockwise from upper left: "The Starry Night" + "The Scream", "The Scream" + "Co
<img src="https://raw.githubusercontent.com/jcjohnson/neural-style/master/examples/outputs/tubingen_seated_nude_composition_vii.png" height="250px">
</div>

### Learn from multiple styles with GAN
When using hundreds of pictures as style images, a discriminator can be used to calculate the style loss. The discriminator takes a Gram matrix as input and is trained to tell whether a generated image belongs to the target style.
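
For reference, the Gram matrix handed to the discriminator is the usual feature-correlation statistic: for layer activations $F \in \mathbb{R}^{C \times N}$ ($C$ channels, $N$ spatial positions),

$$G_{ij} = \sum_{k=1}^{N} F_{ik} F_{jk},$$

and with `-gan` the style loss at a layer is a binary cross-entropy pushing the discriminator's verdict toward the "real style" label, roughly $\ell_{\text{style}} = -w \log D(G)$ for style weight $w$.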

**The traditional way of calculating style loss**:![](https://github.com/citymonkeymao/neural-style/blob/gan/data/style-gan.png?raw=true)

**The new way of calculating style loss**:
![](https://github.com/citymonkeymao/neural-style/blob/gan/data/style-gan2.png?raw=true)
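
A minimal Torch sketch of the idea (the layer size `C` and the random features here are illustrative, while the flatten-then-MLP discriminator mirrors `nn.Discriminator` in the `neural_style.lua` diff below):

```lua
require 'torch'
require 'nn'

local C = 64                               -- channels at one style layer (illustrative)
local D = nn.Sequential()                  -- discriminator over Gram matrices
D:add(nn.View(C * C))                      -- flatten the C x C Gram matrix
D:add(nn.Linear(C * C, C)):add(nn.ReLU())  -- hidden layer 1
D:add(nn.Linear(C, C)):add(nn.ReLU())      -- hidden layer 2
D:add(nn.Linear(C, 1)):add(nn.Sigmoid())   -- P(Gram comes from the target style)

local crit = nn.BCECriterion()
local F = torch.rand(C, 100)               -- stand-in activations (C x positions)
local G = F * F:t()                        -- Gram matrix of the generated image
-- style loss: distance of D's verdict from the "real style" label 1
print(crit:forward(D:forward(G), torch.Tensor{1}))
```

The discriminators themselves are trained on real (style image) versus fake (generated image) Gram matrices; see `get_discriminator_loss` in the diff below.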

#### Results
##### Imitate Shinkai Makoto Style
Transferred using ~160 high-quality style images.
![](https://github.com/citymonkeymao/neural-style/blob/gan/data/cmp_manual.png?raw=true)
##### Imitate Monet (compared with [CycleGAN](https://github.com/junyanz/CycleGAN))
![](https://github.com/citymonkeymao/neural-style/blob/gan/data/monet.png?raw=true)
##### Imitate Van Gogh (compared with [CycleGAN](https://github.com/junyanz/CycleGAN))
![](https://github.com/citymonkeymao/neural-style/blob/gan/data/vangogh.png?raw=true)


#### Usage
1. Download a style image set (borrowed from CycleGAN):
`bash ./datasets/download_dataset.sh <dataset name>`

`<dataset name>` can be `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, or `cezanne2photo`.
2. Run the style transfer:
```
th neural_style.lua -style_image `./list_images.sh <style_image_dir>` -content_image <content_image> -gan -content_weight 2 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.000001
```

The `-gan` flag tells the script to use discriminators to calculate the style losses, and `-d_learning_rate` sets the learning rate for those discriminators. `list_images.sh` lists all images in a directory as a comma-separated string (e.g. `a.jpg,b.jpg`) suitable for `-style_image`; file names in that directory must not contain spaces, and `<style_image_dir>` must not contain `~`. You will need to tune the parameters for different styles and image sizes.
#### Example
Transfer `data/fj.jpg` to Van Gogh's style:
1. Download Van Gogh's paintings: `bash ./datasets/download_dataset.sh vangogh2photo`
2. Add the style to the image:
```
th neural_style.lua -style_image `./list_images.sh datasets/vangogh2photo/trainA` -content_image data/fj.jpg -gan -content_weight 1 -style_weight 50000 -image_size 256 -backend cudnn -num_iterations 10000 -d_learning_rate 0.0000001
```


### Style Interpolation
When using multiple style images, you can control the degree to which they are blended:
Binary file added data/cmp_manual.png
Binary file added data/fj.jpg
Binary file added data/monet.png
Binary file added data/style-gan.png
Binary file added data/style-gan2.png
Binary file added data/vangogh.png
14 changes: 14 additions & 0 deletions datasets/download_dataset.sh
@@ -0,0 +1,14 @@
FILE=$1

if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
wget -N $URL -O $ZIP_FILE
mkdir $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE
10 changes: 10 additions & 0 deletions list_images.sh
@@ -0,0 +1,10 @@
#!/bin/bash
# Print every file under directory $1 (optionally restricted to extension $2)
# as one comma-separated list, e.g. `a.jpg,b.jpg`, for use with -style_image.
if [ $# -eq 1 ];
then
a=`find $1 -type f`
else
a=`find $1 -type f -name "*.$2"`
fi
# Join the newline-separated paths with commas, then drop the trailing comma.
b=$(echo "$a" | tr '\n' ,)
b=${b::-1}
echo $b
164 changes: 153 additions & 11 deletions neural_style.lua
@@ -2,7 +2,7 @@ require 'torch'
require 'nn'
require 'image'
require 'optim'

require 'io'
require 'loadcaffe'


@@ -48,10 +48,16 @@ cmd:option('-seed', -1)
cmd:option('-content_layers', 'relu4_2', 'layers for content')
cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style')

--use a GAN-style discriminator for the style loss
cmd:option('-gan', false, 'use discriminators to calculate style losses')
--discriminator learning rate
cmd:option('-d_learning_rate', 0.0000001)

local function main(params)
io.stdout:setvbuf('no')
--dtype and multigpu are globals now: StyleLoss and Discriminator read dtype
dtype, multigpu = setup_gpu(params)
local loadcaffe_backend = params.backend
if params.backend == 'clnn' then loadcaffe_backend = 'nn' end
local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):type(dtype)
@@ -222,7 +228,7 @@ local function main(params)
local optim_state = nil
if params.optimizer == 'lbfgs' then
optim_state = {
maxIter = 1, -- one L-BFGS step at a time, so discriminators can be updated between steps
verbose=true,
tolX=-1,
tolFun=-1,
@@ -297,14 +303,72 @@ local function main(params)
return loss, grad:view(grad:nElement())
end

--generate optimize function for discriminator for one style layer's loss
local function get_discriminator_loss(style_layer)
--get params and gradients of discriminator
local dis_params, gradParams = style_layer.D.model:getParameters()
--create input for discriminator: one real gram, one fake gram
local batchInputs = Tensor(2 , style_layer.target[1]:size(1), style_layer.target[1]:size(2))
--labels: the real Gram matrix gets 1, the fake one gets 0
local batchLabels = torch.zeros(2) + 1
batchLabels[-1] = 0
batchLabels = batchLabels:type(dtype)
local crit = nn.BCECriterion():type(dtype)
d_loss = function (params)
--randomly choose a real image as positive input
batchInputs[1]:copy(style_layer.target[math.random(#style_layer.target)])
--copy fake images to training set
batchInputs[2]:copy(style_layer.G)
--change dtype
batchInputs = batchInputs:type(dtype)
--create loss
local output = style_layer.D:forward(batchInputs)
local loss = crit:forward(output, batchLabels)
local verbose = (num_calls % 50 == 0)
if verbose then
print(string.format('discriminator loss %f',loss))
end

--refresh gradients of discriminator
local d_output = crit:backward(output, batchLabels)
gradParams:zero()
style_layer.D:backward(batchInputs, d_output)
return loss, gradParams
end
return d_loss, dis_params
end

function train_discriminator()
if params.gan then
for i = 1, #style_losses do --one Adam step per style layer's discriminator
local x, losses = optim.adam(style_losses[i].f, style_losses[i].dis_params, style_losses[i].optim_state)
end
end
end
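
-- Note: each outer iteration of the optimization below alternates one image
-- update (feval) with one Adam step for every style layer's discriminator.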

--init gan optimization eval functions
if params.gan then
for i = 1, #style_losses do
f, dis_params = get_discriminator_loss(style_losses[i])
style_losses[i].f = f
style_losses[i].dis_params = dis_params
end
end

-- Run optimization.
if params.optimizer == 'lbfgs' then
print('Running optimization with L-BFGS')
for t = 1, params.num_iterations do
local x, losses = optim.lbfgs(feval, img, optim_state)
train_discriminator()
end
elseif params.optimizer == 'adam' then
print('Running optimization with ADAM')
for t = 1, params.num_iterations do
local x, losses = optim.adam(feval, img, optim_state)
train_discriminator()
end
end
end
@@ -321,6 +385,7 @@ function setup_gpu(params)
else
params.gpu = tonumber(params.gpu) + 1
end
Tensor = torch.FloatTensor -- global tensor constructor, used by get_discriminator_loss
local dtype = 'torch.FloatTensor'
if multigpu or params.gpu > 0 then
if params.backend ~= 'clnn' then
@@ -331,6 +396,7 @@
else
cutorch.setDevice(params.gpu)
end
Tensor = torch.CudaTensor
dtype = 'torch.CudaTensor'
else
require 'clnn'
@@ -340,6 +406,7 @@
else
cltorch.setDevice(params.gpu)
end
Tensor = torch.Tensor():cl()
dtype = torch.Tensor():cl():type()
end
else
@@ -519,14 +586,25 @@ function StyleLoss:__init(strength, normalize)
parent.__init(self)
self.normalize = normalize or false
self.strength = strength
--in gan mode, save the Gram matrix of every style image
if params.gan then
self.target = {}
self.optim_state = {
learningRate = params.d_learning_rate,
}
else
self.target = torch.Tensor()
end
self.mode = 'none'
self.loss = 0

self.gram = nn.GramMatrix()
self.blend_weight = nil
self.G = nil
self.crit = nn.MSECriterion()
if params.gan then
self.D = nil
end
end

function StyleLoss:updateOutput(input)
@@ -535,21 +613,45 @@ function StyleLoss:updateOutput(input)
if self.mode == 'capture' then
if self.blend_weight == nil then
self.target:resizeAs(self.G):copy(self.G)
--if gan mode, store every image gram
elseif params.gan then
gram_of_this_style = torch.Tensor():type(dtype):resizeAs(self.G):copy(self.G)
table.insert(self.target, gram_of_this_style)
elseif self.target:nElement() == 0 then
self.target:resizeAs(self.G):copy(self.G):mul(self.blend_weight)
else
self.target:add(self.blend_weight, self.G)
end
elseif self.mode == 'loss' then
if params.gan then --gan mode
--create D lazily, once the dimensions of the Gram matrix are known
if self.D == nil then
self.D = nn.Discriminator(self.G:size(1), self.G:size(2))
self.D:type(dtype)
end
--classify the Gram matrix of the generated image
self.classified = self.D:forward(self.G)
--the loss pushes the classification toward the "real style" label (1)
self.loss = self.strength * self.crit:forward(self.classified, torch.Tensor({1}):type(dtype))
else --if not gan mode
self.loss = self.strength * self.crit:forward(self.G, self.target)
end
end
self.output = input
return self.output
end

function StyleLoss:updateGradInput(input, gradOutput)
if self.mode == 'loss' then
local dG
if params.gan then
d_classified = self.crit:backward(self.classified,torch.Tensor({1}):type(dtype))
dG = self.D:backward(self.G, d_classified)
else
dG = self.crit:backward(self.G, self.target)
end
dG:div(input:nElement())
self.gradInput = self.gram:backward(input, dG)
if self.normalize then
@@ -596,6 +698,46 @@ function TVLoss:updateGradInput(input, gradOutput)
return self.gradInput
end

-- Define an nn Module implementing a small MLP discriminator over Gram matrices
local Discriminator, parent = torch.class('nn.Discriminator', 'nn.Module')

function Discriminator:__init(input_H, input_W)
--define a simple discriminator
self.model = nn.Sequential()
--flatten the gram matrix
self.model:add(nn.View(input_H * input_W))
--hidden layer 1
self.model:add(nn.Linear(input_H * input_W, input_W))
self.model:add(nn.ReLU())
--hidden layer 2
self.model:add(nn.Linear(input_W, input_W))
self.model:add(nn.ReLU())

--output layer: probability that the input Gram matrix has the target style
self.model:add(nn.Linear(input_W, 1))
self.model:add(nn.Sigmoid())
self.model = self.model:type(dtype)
end

--forward of discriminator
function Discriminator:updateOutput(input)
--input is the Gram Matrix
input = input:type(dtype)
self.output = self.model:forward(input)
return self.output
end

--backward of discriminator
function Discriminator:updateGradInput(input, gradOutput)
self.gradInput = self.model:backward(input,gradOutput)
return self.gradInput
end
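
--[[ Usage sketch (illustrative sizes, not from the PR): with -gan enabled,
each style layer constructs one of these lazily once its Gram size is known:
  dtype = 'torch.FloatTensor'
  local D = nn.Discriminator(64, 64)   -- for a 64x64 Gram matrix
  print(D:forward(torch.rand(64, 64))) -- probability of "target style"
--]]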



--params is global so that StyleLoss and Discriminator can read params.gan
params = cmd:parse(arg)
main(params)