Toward Multimodal Image-to-Image Translation
Code Walkthrough for Bi-Cycle GAN
List of files
- process_data.py
- bcgan.py
- network.py
- layers.py
- test.py
- train.py
Libraries
import tensorflow as tf
The code is written in Python using the TensorFlow library.
Other dependencies: numpy, scipy, os, argparse, tqdm, h5py, time, random.
################################################################################################################
process_data.py [Data preprocessing]
- modules : get_data
- downloaded data is loaded using this module
- augmentation is done on the fly - not part of this step.
def get_data(image_size=256, dataset='edges2shoes', is_train=True, debug=False):
    '''Get the training or validation data.

    dataset is given as a string, image_size as an int, and is_train as a
    bool that selects the train or validation data.'''
    ...
    return return_data
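Because augmentation happens on the fly rather than in get_data, a sample can be transformed just before it is fed to the model. A minimal sketch, assuming paired samples stored as nested lists (rows of pixels) and a random horizontal flip as the only augmentation; augment_pair is a hypothetical helper, not a function from this repository:

```python
import random


def augment_pair(edge, photo, rng=random):
    """Randomly flip an (input, target) pair left-right.

    Hypothetical helper: the same flip is applied to both images so the
    pair stays pixel-aligned, which paired translation requires.
    """
    if rng.random() < 0.5:
        edge = [row[::-1] for row in edge]
        photo = [row[::-1] for row in photo]
    return edge, photo


# Toy 2x3 "images" stored as rows of pixel values.
edge = [[1, 2, 3], [4, 5, 6]]
photo = [[10, 20, 30], [40, 50, 60]]
aug_edge, aug_photo = augment_pair(edge, photo, rng=random.Random(0))
```

Drawing the flip decision once per pair, rather than once per image, is what keeps the edge map and the photo consistent with each other.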
################################################################################################################
bcgan.py [BC-GAN Model definitions]
- modules : Bicycle_GAN
- dependencies
from network import Generator, Discriminator, Encoder
module : Bicycle_GAN
the module is a class definition for the Bicycle GAN
functions:
- constructor
- summary_create
- train
- test
function: constructor
- creates all the necessary variables in the class object.
- uses the modules generator, discriminator and encoder to create cVAE-GAN and cLR-GAN
- formulates the loss functions
- optimizers
- update ops(taking care of batchnorm updates)
function: summary_create
- creates the TensorBoard summaries for all costs and images
- merges all summaries
function: train
- runs the main training loop
- loss minimization and gradient updates
- learning rate is periodically decayed
- summaries are periodically written
function: test
- loads the pretrained weights
- generates the images by random sampling
- saves the images
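The loss formulation in the constructor and the random sampling in test both depend on how a latent code is obtained. A minimal sketch of the two pieces commonly involved in a cVAE-GAN branch, the reparameterized sample and the KL term, written in plain Python for clarity. It assumes the encoder predicts a mean and a log-variance per latent dimension; the names sample_latent and kl_divergence and the latent size are illustrative, not taken from bcgan.py:

```python
import math
import random


def sample_latent(mu, log_var, rng=random):
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def kl_divergence(mu, log_var):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))


# At test time the code can be drawn from the prior N(0, I) instead of
# from the encoder: zero mean and zero log-variance (unit variance).
z_dim = 8  # assumed latent size
z_prior = sample_latent([0.0] * z_dim, [0.0] * z_dim, rng=random.Random(0))
kl_at_prior = kl_divergence([0.0] * z_dim, [0.0] * z_dim)
```

The KL term is zero when the encoder output matches the prior exactly, which is why sampling from the prior at test time is consistent with what the model saw during training.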
class Bicycle_GAN(object):
    def __init__(self, ...):
        ...

    def summary_create(self):
        ...

    def train(self, sess, data, saver, summary_writer):
        ...

    def test(self, sess, data, write_dir):
        ...
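The train function decays the learning rate periodically. One common way to do that is step decay; the sketch below is an assumption about the general shape of such a schedule, and the function name and all numeric values are illustrative rather than the ones used in Bicycle_GAN.train:

```python
def decayed_learning_rate(base_lr, step, decay_every, decay_factor):
    """Step decay: multiply the rate by decay_factor every decay_every steps."""
    return base_lr * decay_factor ** (step // decay_every)


# Illustrative values only.
schedule = [decayed_learning_rate(2e-4, s, decay_every=1000, decay_factor=0.5)
            for s in (0, 999, 1000, 2500)]
```

The rate stays constant within each window of decay_every steps and drops only at the window boundary.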
################################################################################################################
network.py [GEN, DISC, ENC Model definitions]
- modules - generator, discriminator, encoder
- dependencies
from layers import * ( wrapper functions for all the layers)
module : generator
- creates the generator graph definition with conv layers, normalizations, and activations; uses deconv layers in addition to increase the spatial size.
- returns the final layer output
module : discriminator
- creates the discriminator graph definition with conv layers, normalizations, and activations.
- returns the final layer output
module : encoder
- creates the encoder graph definition; uses residual skip connections along with other layers.
- returns the final layer output
class Generator(object):
    def __init__(self, ...):
        ...

    def __call__(self, ...):
        ...

class Discriminator(object):
    def __init__(self, ...):
        ...

    def __call__(self, ...):
        ...

class Encoder(object):
    def __init__(self, ...):
        ...

    def __call__(self, ...):
        ...
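The network descriptions mention conv layers and deconv layers that change the spatial size. A small sketch of the size arithmetic, assuming 'same' padding so that a strided convolution divides the size by the stride (rounding up) and a transposed convolution multiplies it. The helper names and the three-down, three-up layout are assumptions for illustration, not the layout in network.py:

```python
def conv_out_size(size, stride):
    """Spatial size after a strided convolution with 'same' padding."""
    return -(-size // stride)  # ceiling division on integers


def deconv_out_size(size, stride):
    """Spatial size after a transposed convolution with 'same' padding."""
    return size * stride


# Assumed example: three stride-2 convolutions followed by three stride-2
# transposed convolutions on a 256-pixel-wide input.
size = 256
down = []
for _ in range(3):
    size = conv_out_size(size, 2)
    down.append(size)
up = []
for _ in range(3):
    size = deconv_out_size(size, 2)
    up.append(size)
```

Tracking sizes this way is a quick check that an encoder-decoder stack returns to the input resolution.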
################################################################################################################
layers.py [Wrappers for tf.layers]
- modules - conv2d, flatten, residual, etc.
Wrapper functions on top of the TensorFlow implementations of these layers.
def normalization(input, is_train, norm=None):
    ...
    return output

def conv2d(input, is_train, norm=None):
    ...
    return output

def residual(input, is_train, norm=None):
    ...
    return output
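To make the roles of the normalization and residual wrappers concrete, here is a plain-Python sketch of the arithmetic they stand for, operating on a flat list of floats instead of a tensor. It mirrors the signatures above, but the supported norm names ('instance' or None) and the stand-in residual branch are assumptions; the real wrappers call TensorFlow layers:

```python
import math


def instance_norm(values, eps=1e-5):
    """Normalize one channel's values to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def normalization(values, is_train, norm=None):
    """Dispatch on the norm name, mirroring the wrapper's signature.

    Only 'instance' and None are sketched; is_train is unused because
    instance norm behaves the same in training and inference.
    """
    if norm is None:
        return list(values)
    if norm == 'instance':
        return instance_norm(values)
    raise ValueError('unsupported norm: %r' % (norm,))


def residual(values, is_train, norm=None):
    """Skip connection: output = input + F(input), with F a stand-in here."""
    branch = normalization(values, is_train, norm)
    return [v + b for v, b in zip(values, branch)]


normed = normalization([1.0, 2.0, 3.0], is_train=True, norm='instance')
skipped = residual([1.0, 2.0, 3.0], is_train=True, norm=None)
```

Passing norm as a string keeps every layer wrapper on the same signature, so the network definitions can switch normalization schemes through a single argument.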
################################################################################################################
train.py
- modules - collect_args, validate_args, train
- dependencies - Bicycle_GAN, get_data
function: collect_args
- collects the model parameters and training parameters using argparse
function: validate_args
- validates that the collected arguments have allowable values
function: train
- sets up the GPU environment and variables
- loads the training data
- creates the BiCycle GAN model definition
- loads the pretrained weights if they exist
- calls the train function in Bicycle_GAN
def validate_args(args):
    """Validating the arguments"""
    ...

def collect_args():
    """Collecting the arguments"""
    ...

def train(args):
    """Training the Model"""

if __name__ == "__main__":
    args = collect_args()
    print('Collected the arguments')
    validate_args(args)
    train(args)
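The collect_args and validate_args skeletons can be filled in along these lines. This is a sketch: the flag names come from the Usage section, while the defaults for --batch_size and --gpu, the optional argv parameter, and the specific validation checks are assumptions:

```python
import argparse


def collect_args(argv=None):
    """Collect the arguments; flag names follow the Usage section."""
    parser = argparse.ArgumentParser(description='Train Bicycle GAN')
    parser.add_argument('--dataset', type=str, default='edges2shoes')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--gpu', type=int, default=0)
    return parser.parse_args(argv)


def validate_args(args):
    """Validate the arguments; these particular checks are assumptions."""
    if args.batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    if args.img_size < 1:
        raise ValueError('img_size must be positive')
    if args.gpu < 0:
        raise ValueError('gpu index must be non-negative')
    return args


args = validate_args(collect_args(['--dataset', 'edges2shoes', '--gpu', '1']))
```

Accepting an explicit argv list makes the parser easy to exercise without touching the real command line; passing None falls back to sys.argv.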
################################################################################################################
test.py
- modules - collect_args, validate_args, test
- dependencies - Bicycle_GAN, get_data
function: collect_args
- collects the model parameters and test parameters using argparse
function: validate_args
- validates that the collected arguments have allowable values
function: test
- sets up the GPU environment and variables
- loads the testing data
- creates the BiCycle GAN model definition
- loads the pretrained weights
- calls the test function in Bicycle_GAN
def validate_args(args):
    """Validating the arguments"""
    ...

def collect_args():
    """Collecting the arguments"""
    ...

def test(args):
    """Testing the Model"""

if __name__ == "__main__":
    args = collect_args()
    print('Collected the arguments')
    validate_args(args)
    test(args)
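Both train.py and test.py set up the GPU environment before building the model. One common way to do this, assuming the --gpu flag is a device index, is to restrict the visible devices through the CUDA_VISIBLE_DEVICES environment variable; select_gpu is a hypothetical helper, not a function from this repository:

```python
import os


def select_gpu(gpu_index):
    """Expose only one GPU to the process via CUDA_VISIBLE_DEVICES.

    This has to run before the framework initializes CUDA to take effect.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_index)
    return os.environ['CUDA_VISIBLE_DEVICES']


visible = select_gpu(1)
```

With this in place the selected card appears to the process as the only available device.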
################################################################################################################
Usage
- Training - default [edges2shoes, size=256]
python train.py --dataset edges2shoes --batch_size 1 --img_size 256 --gpu 1
- Testing
python test.py --pretrained_weights 'weights/location/go/here'
- Tensorboard
tensorboard --logdir=./logs
- in browser - localhost:6006/
- Data - download the data and store it in hdf5 format in a data folder (create the data folder in the current directory)
Authors
Jun-Yan Zhu, Richard Zhang, Deepak Pathak, Trevor Darrell, Alexei A. Efros, Oliver Wang, Eli Shechtman