Helper Module for Deep Learning.
Generation of 3D brain MRI using VAE Generative Adversarial Networks
Credit: A Grigis
Based on:
This tutorial builds intuition for simple Generative Adversarial Networks (GANs) that generate realistic MRI images. Here, we propose a model that can successfully generate 3D brain MRI data from random vectors by learning the data distribution. After reading this tutorial, you will understand the technical details needed to implement a VAE-GAN.
Let’s begin by importing the required modules:
import os
import sys
if "CI_MODE" in os.environ:
sys.exit()
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.autograd import Variable
from pynet.datasets import DataManager, fetch_brats
from pynet.interfaces import DeepLearningInterface
from pynet.plotting import Board, update_board
from pynet.utils import setup_logging
from pynet.preprocessing.spatial import downsample
from pynet.models import BGDiscriminator, BGEncoder, BGGenerator
# Global parameters
# Logger named "pynet"; it is used in the training loop to emit the
# per-iteration tensor shape debug messages.
logger = logging.getLogger("pynet")
# Request debug-level output through pynet's logging helper (handlers and
# message format are defined in pynet.utils, not in this script).
setup_logging(level="debug")
The model will be trained on BRATS
We will train the model to synthesize brain disorder MRI data (Glioma).
# Build the BRATS dataset description with pynet's fetcher. The returned
# object exposes the `input_path` and `metadata_path` attributes that are
# handed to the DataManager below.
# NOTE(review): `datasetdir` is a site-specific absolute path - adapt it to
# the local installation.
data = fetch_brats(
datasetdir="/neurospin/nsap/processed/deepbrain/tumor/data/brats")
# Number of samples per mini-batch; it is passed to the DataManager and also
# sizes the constant real/fake label tensors.
batch_size = 4
def transformer(data, imgtype="flair"):
    """Keep the requested MRI modalities of a sample and downsample them.

    Parameters
    ----------
    data: sequence of arrays
        one array per channel, indexed in the order t1, t1ce, t2, flair.
    imgtype: str, list of str or None, default 'flair'
        the modality name(s) to keep ('t1', 't1ce', 't2', 'flair'). None
        selects the channels 0 to 3.

    Returns
    -------
    np.ndarray
        the selected channels stacked in channel order (not in the order
        given in 'imgtype'), each one passed through
        'downsample(arr, scale=3)'.

    Raises
    ------
    KeyError
        if a requested modality name is unknown.
    """
    modality_to_channel = {"t1": 0, "t1ce": 1, "t2": 2, "flair": 3}
    if imgtype is None:
        selected = range(4)
    else:
        # Only a list is treated as a collection of names; any other value
        # is considered to be a single modality name.
        names = imgtype if isinstance(imgtype, list) else [imgtype]
        selected = [modality_to_channel[name] for name in names]
    # NOTE(review): 'downsample' comes from pynet.preprocessing.spatial;
    # scale=3 presumably divides each spatial dimension by 3 - confirm.
    return np.asarray([
        downsample(volume, scale=3)
        for index, volume in enumerate(data)
        if index in selected])
# Data manager feeding the training loop: it receives the array/metadata
# locations returned by `fetch_brats`, registers `transformer` as the input
# transform and produces mini-batches of `batch_size` samples.
# NOTE(review): the exact meaning of `stratify_label`, `number_of_folds`,
# `test_size=0` and `sample_size=0.2` is defined by pynet's DataManager
# (presumably: folds stratified on the tumor grade, no held-out test set and
# a 20% subsample of the data) - confirm against its documentation.
manager = DataManager(
input_path=data.input_path,
metadata_path=data.metadata_path,
stratify_label="grade",
number_of_folds=10,
batch_size=batch_size,
test_size=0,
input_transforms=[transformer],
sample_size=0.2)
Loss
# Binary cross-entropy applied to the discriminator outputs against the
# constant real (ones) / fake (zeros) targets.
criterion_bce = nn.BCELoss()
# L1 criterion. NOTE(review): it is not referenced anywhere in the visible
# script; the reconstruction term of the training loop is a mean squared
# error computed inline.
criterion_l1 = nn.L1Loss()
Training
We’ll train the encoder, generator and discriminator to optimize the losses using the Adam optimizer.
# Number of passes over the training data.
n_epochs = 100
# Size of the latent code given to the generator.
latent_dim = 1000
# Switch selecting the torch device below ("cuda" when True, "cpu" otherwise).
use_cuda = False
# Number of image channels handled by the networks (the transformer keeps a
# single modality by default).
channels = 1
# Spatial shape given to the three networks.
# NOTE(review): the trailing tuple is presumably the full-resolution shape
# before the x3 downsampling - confirm.
in_shape = (50, 64, 45) # (150, 190, 135)
# Weight of the reconstruction term in the generator/decoder loss.
gamma = 20
# Weight of the reconstruction term in the encoder loss.
beta = 10
device = torch.device("cuda" if use_cuda else "cpu")
# Generator (decoder): called in the training loop on latent codes, either
# produced by the encoder or sampled from a standard normal distribution.
generator = BGGenerator(
in_shape=in_shape, out_channels=channels, start_filts=64,
latent_dim=latent_dim, mode="trilinear", with_code=False).to(device)
# Discriminator: its output is compared with the (batch, channels) real/fake
# targets using a BCE loss.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; this assumes that
# `with_logit=True` makes the network end with a sigmoid - confirm in pynet.
discriminator = BGDiscriminator(
in_shape=in_shape, in_channels=channels, out_channels=channels,
start_filts=64, with_logit=True).to(device)
# Encoder: called in the training loop as `mean, logvar, code = encoder(x)`.
encoder = BGEncoder(
in_shape=in_shape, in_channels=channels, start_filts=64,
latent_dim=latent_dim).to(device)
# One Adam optimizer per network (generator, discriminator, encoder), all of
# them sharing the same learning rate.
learning_rate = 0.0001
g_optimizer, d_optimizer, e_optimizer = [
    torch.optim.Adam(network.parameters(), lr=learning_rate)
    for network in (generator, discriminator, encoder)]
# Constant BCE targets of shape (batch_size, channels): ones for real images
# and zeros for generated ones. They are sliced to the current batch size in
# the training loop.
real_y = Variable(torch.ones((batch_size, channels)).to(device))
fake_y = Variable(torch.zeros((batch_size, channels)).to(device))
# Board used to display image slices during training through
# `board.viewer.images`.
# NOTE(review): presumably backed by a Visdom server that must be listening
# on localhost:8097 - confirm.
board = Board(port=8097, host="http://localhost", env="vae")
# Directory receiving the per-epoch checkpoints. `exist_ok=True` makes the
# creation idempotent and avoids the check-then-create race of an explicit
# `os.path.isdir` test followed by `os.makedirs`.
outdir = "/tmp/vae-gan/checkpoint"
os.makedirs(outdir, exist_ok=True)
# VAE-GAN training loop.
#
# For every mini-batch:
#   1. the encoder/generator forward pass builds a reconstruction `x_rec` and
#      a sample `x_rand` generated from a random latent code,
#   2. the discriminator is updated to separate real from generated images,
#   3. the generator (decoder) and encoder gradients are computed on the same
#      forward graph, then both networks are updated.
# The three networks are checkpointed at the end of every epoch.
#
# NOTE(review): the extracted source lost its indentation; the image display
# is assumed to belong to the periodic logging branch and the checkpointing
# to the epoch level (the file name only depends on the epoch).
for epoch in range(n_epochs):
    loaders = manager.get_dataloader(
        train=True, validation=False, fold_index=0)
    for iteration, item in enumerate(loaders.train):
        real_images = item.inputs.to(device)
        # Use a dedicated name: the last batch may be smaller, and rebinding
        # the module-level `batch_size` would silently change a global
        # setting.
        n_samples = real_images.size(0)
        real_labels = real_y[:n_samples]
        fake_labels = fake_y[:n_samples]
        z_rand = torch.randn((n_samples, latent_dim)).to(device)

        # Shared forward pass
        mean, logvar, code = encoder(real_images)
        x_rec = generator(code)
        x_rand = generator(z_rand)
        logger.debug("X_real: {0}".format(real_images.shape))
        logger.debug("X_rand: {0}".format(x_rand.shape))
        logger.debug("X_rec: {0}".format(x_rec.shape))

        # Train discriminator
        # The generated images are detached: only the discriminator is
        # updated here, so there is no need to back-propagate through the
        # generator and the encoder (nor to retain their graph).
        d_optimizer.zero_grad()
        d_real_loss = criterion_bce(discriminator(real_images), real_labels)
        d_recon_loss = criterion_bce(
            discriminator(x_rec.detach()), fake_labels)
        d_fake_loss = criterion_bce(
            discriminator(x_rand.detach()), fake_labels)
        dis_loss = d_recon_loss + d_real_loss + d_fake_loss
        dis_loss.backward()
        d_optimizer.step()

        # Generator (decoder) gradients
        # The discriminator has just been updated, hence the fresh forward
        # passes. The graph is retained because the encoder loss below
        # shares `rec_loss`.
        g_optimizer.zero_grad()
        d_real_loss = criterion_bce(discriminator(real_images), real_labels)
        d_recon_loss = criterion_bce(discriminator(x_rec), fake_labels)
        d_fake_loss = criterion_bce(discriminator(x_rand), fake_labels)
        d_img_loss = d_real_loss + d_recon_loss + d_fake_loss
        gen_img_loss = -d_img_loss
        rec_loss = ((x_rec - real_images) ** 2).mean()
        err_dec = gamma * rec_loss + gen_img_loss
        err_dec.backward(retain_graph=True)

        # Encoder gradients
        # They must be computed BEFORE `g_optimizer.step()`: the step updates
        # the generator weights in place, and back-propagating afterwards
        # through the graph that recorded them either raises a RuntimeError
        # (recent PyTorch) or silently yields wrong gradients (old PyTorch).
        # `torch.autograd.grad` restricts the computation to the encoder
        # parameters, so the generator gradients computed above stay
        # untouched.
        prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()
        prior_loss = (-0.5 * torch.sum(prior_loss)) / torch.numel(mean.data)
        err_enc = prior_loss + beta * rec_loss
        # Drop the encoder gradients accumulated by `err_dec.backward`.
        e_optimizer.zero_grad()
        encoder_params = [
            param for param in encoder.parameters() if param.requires_grad]
        encoder_grads = torch.autograd.grad(
            err_enc, encoder_params, allow_unused=True)
        for param, grad in zip(encoder_params, encoder_grads):
            param.grad = grad

        # Update the generator and the encoder once all the gradients are
        # available.
        g_optimizer.step()
        e_optimizer.step()

        # Visualization
        if iteration % 4 == 0:
            print("[{0}/{1}]".format(epoch, n_epochs),
                  "D: {:<8.3}".format(dis_loss.item()),
                  "En: {:<8.3}".format(err_enc.item()),
                  "De: {:<8.3}".format(err_dec.item()))
            # `title`/`images` are used instead of `name`/`data` so that the
            # module-level `data` (the fetched dataset) is not overwritten.
            for title, images in [("X_real", real_images), ("X_dec", x_rec),
                                  ("X_rand", x_rand)]:
                # First sample of the batch, mapped with 0.5 * x + 0.5.
                featmask = (0.5 * images[0] + 0.5).detach().cpu().numpy()
                # Middle slice along the last axis.
                img = featmask[..., featmask.shape[-1] // 2]
                img = np.expand_dims(img, axis=1)
                img = (img / img.max()) * 255
                board.viewer.images(
                    img,
                    opts={
                        "title": title,
                        "caption": title},
                    win=title)

    # Save model
    for name, model in [("generator", generator),
                        ("discriminator", discriminator),
                        ("encoder", encoder)]:
        fname = os.path.join(
            outdir, name + "_epoch_" + str(epoch + 1) + ".pth")
        torch.save(model.state_dict(), fname)
Conclusion
Variational Auto-Encoder (VAE) GANs are free from mode collapse, but their outputs are characterized by blurriness. To effectively address both the mode collapse of GANs and the blurriness of VAEs, we will use α-GAN, a solution that combines both models, in the next tutorial.
Total running time of the script: ( 0 minutes 0.000 seconds)
Gallery generated by Sphinx-Gallery
Follow us
Inspired by AZMIND template.