Menu

Helper Module for Deep Learning.

Generation of 3D brain MRI using VAE Generative Adversarial Networks

Credit: A Grigis

Based on:

This tutorial builds intuition for a simple Generative Adversarial Network (GAN) that generates realistic MRI images. Here, we propose a model that can successfully generate 3D brain MRI data from random vectors by learning the data distribution. After reading this tutorial, you’ll understand the technical details needed to implement a VAE-GAN.

Let’s begin by importing the required modules:

import os
import sys
# Stop the script immediately when the CI_MODE environment variable is set
# (presumably by continuous integration), before the remaining imports and
# the training below are executed.
if "CI_MODE" in os.environ:
    sys.exit()

import numpy as np
import logging
import torch
import torch.nn as nn
from torch.autograd import Variable
from pynet.datasets import DataManager, fetch_brats
from pynet.interfaces import DeepLearningInterface
from pynet.plotting import Board, update_board
from pynet.utils import setup_logging
from pynet.preprocessing.spatial import downsample
from pynet.models import BGDiscriminator, BGEncoder, BGGenerator


# Global parameters: fetch the "pynet" logger and configure logging at debug
# level (presumably so the shape messages emitted with `logger.debug` in the
# training loop are displayed -- confirm against `pynet.utils.setup_logging`).
logger = logging.getLogger("pynet")
setup_logging(level="debug")

The model will be trained on BRATS

We will train the model to synthesize brain disorder MRI data (Glioma).

# Fetch the BRATS dataset description; the returned object provides the
# `input_path` and `metadata_path` attributes handed to the DataManager.
# NOTE(review): the dataset directory is a hard-coded, site-specific path --
# adapt it before running this example on another machine.
data = fetch_brats(
    datasetdir="/neurospin/nsap/processed/deepbrain/tumor/data/brats")
# Batch size passed to the DataManager and used to size the label tensors.
batch_size = 4

def transformer(data, imgtype="flair"):
    """Keep the requested modality channels of a sample and downsample them.

    Parameters
    ----------
    data: sequence
        the sample, indexed by channel on its first axis.
    imgtype: str or list of str, default 'flair'
        the modality name(s) to keep, among 't1', 't1ce', 't2' and 'flair'
        (mapped to the channel indices 0, 1, 2 and 3). If None, the channel
        indices 0 to 3 are kept.

    Returns
    -------
    arr: numpy.ndarray
        the selected channels, each one passed through
        `downsample(..., scale=3)`, stacked in the order they appear in
        `data` (not in the order given in `imgtype`).

    Raises
    ------
    KeyError
        if a requested modality name is unknown.
    """
    modality_index = {"t1": 0, "t1ce": 1, "t2": 2, "flair": 3}
    if imgtype is None:
        selected = range(4)
    else:
        names = imgtype if isinstance(imgtype, list) else [imgtype]
        selected = [modality_index[name] for name in names]
    # Walk the channels in storage order and keep only the selected ones.
    return np.asarray([
        downsample(volume, scale=3)
        for index, volume in enumerate(data)
        if index in selected])

# Build the data manager that serves the training batches. Every sample is
# passed through `transformer`, which with its default `imgtype` keeps only
# channel index 3 (the one keyed "flair") and downsamples it.
# NOTE(review): the exact meaning of `sample_size=0.2`, `test_size=0` and
# `stratify_label="grade"` is defined by pynet's DataManager, which is not
# visible here -- presumably a 20% subsample, no held-out test set, and folds
# stratified on the tumor grade; confirm against the DataManager docs.
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    stratify_label="grade",
    number_of_folds=10,
    batch_size=batch_size,
    test_size=0,
    input_transforms=[transformer],
    sample_size=0.2)

Loss

# Binary cross-entropy, applied to the discriminator outputs against the
# real/fake label tensors.
criterion_bce = torch.nn.BCELoss()
# Mean absolute error. It is instantiated here, but the training loop of this
# example computes its reconstruction term as a mean squared error instead.
criterion_l1 = torch.nn.L1Loss()

Training

We’ll train the encoder, generator and discriminator to optimize the losses using the Adam optimizer.

# Training hyper-parameters.
n_epochs = 100
latent_dim = 1000
use_cuda = False
channels = 1
# Presumably the (150, 190, 135) volumes after the scale-3 downsampling done
# in `transformer` -- confirm against `downsample`.
in_shape = (50, 64, 45)
# Weights of the reconstruction term in the generator (gamma) and in the
# encoder (beta) objectives.
gamma = 20
beta = 10
device = torch.device("cuda" if use_cuda else "cpu")

# The three networks of the VAE-GAN, all sharing the same input shape and
# latent dimension.
generator = BGGenerator(
    in_shape=in_shape, out_channels=channels, start_filts=64,
    latent_dim=latent_dim, mode="trilinear", with_code=False).to(device)
discriminator = BGDiscriminator(
    in_shape=in_shape, in_channels=channels, out_channels=channels,
    start_filts=64, with_logit=True).to(device)
encoder = BGEncoder(
    in_shape=in_shape, in_channels=channels, start_filts=64,
    latent_dim=latent_dim).to(device)

# One Adam optimizer per network so each can be stepped independently.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
e_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.0001)

# Target labels for the binary cross-entropy: ones for real images, zeros for
# generated ones. Plain tensors are used: the `Variable` wrapper is a
# deprecated no-op on every PyTorch version that provides `torch.device`.
real_y = torch.ones((batch_size, channels), device=device)
fake_y = torch.zeros((batch_size, channels), device=device)

# Visualization board and checkpoint directory.
board = Board(port=8097, host="http://localhost", env="vae")
outdir = "/tmp/vae-gan/checkpoint"
# `exist_ok=True` avoids the check-then-create race of isdir() + makedirs().
os.makedirs(outdir, exist_ok=True)

# Main training loop: for every batch the discriminator, the generator and
# the encoder are updated in turn, then the three networks are checkpointed
# at the end of each epoch.
for epoch in range(n_epochs):
    loaders = manager.get_dataloader(train=True, validation=False,
                                     fold_index=0)
    for iteration, item in enumerate(loaders.train):
        real_images = item.inputs.to(device)
        # Use the actual number of samples in this batch (kept in a local
        # name so the module-level `batch_size` is not overwritten) and trim
        # the label tensors accordingly.
        n_samples = real_images.size(0)
        real_target = real_y[:n_samples]
        fake_target = fake_y[:n_samples]
        z_rand = torch.randn((n_samples, latent_dim)).to(device)
        mean, logvar, code = encoder(real_images)
        x_rec = generator(code)
        x_rand = generator(z_rand)
        logger.debug("X_real: %s", real_images.shape)
        logger.debug("X_rand: %s", x_rand.shape)
        logger.debug("X_rec: %s", x_rec.shape)

        # Train discriminator
        # The generated images are detached: only the discriminator is
        # updated here, so there is no need to back-propagate into the
        # generator/encoder (nor to retain their graph for this pass).
        d_optimizer.zero_grad()
        d_real_loss = criterion_bce(discriminator(real_images), real_target)
        d_recon_loss = criterion_bce(
            discriminator(x_rec.detach()), fake_target)
        d_fake_loss = criterion_bce(
            discriminator(x_rand.detach()), fake_target)
        dis_loss = d_recon_loss + d_real_loss + d_fake_loss
        dis_loss.backward()
        d_optimizer.step()

        # Generator and encoder objectives, evaluated with the freshly
        # updated discriminator.
        d_real_loss = criterion_bce(discriminator(real_images), real_target)
        d_recon_loss = criterion_bce(discriminator(x_rec), fake_target)
        d_fake_loss = criterion_bce(discriminator(x_rand), fake_target)
        d_img_loss = d_real_loss + d_recon_loss + d_fake_loss
        gen_img_loss = -d_img_loss
        rec_loss = ((x_rec - real_images) ** 2).mean()
        err_dec = gamma * rec_loss + gen_img_loss
        # KL divergence between the encoded distribution and a unit Gaussian,
        # averaged over the elements of `mean`.
        prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()
        prior_loss = (-0.5 * torch.sum(prior_loss)) / mean.numel()
        err_enc = prior_loss + beta * rec_loss

        # The encoder gradients flow through the generator, so they must be
        # computed BEFORE the generator weights are modified in place by its
        # optimizer: back-propagating afterwards raises a RuntimeError on
        # recent PyTorch releases. `autograd.grad` is used so that only the
        # encoder gradients are produced, without touching the `.grad` of
        # the generator.
        enc_params = [
            param for param in encoder.parameters() if param.requires_grad]
        enc_grads = torch.autograd.grad(
            err_enc, enc_params, retain_graph=True, allow_unused=True)

        # Train generator
        g_optimizer.zero_grad()
        err_dec.backward()
        g_optimizer.step()

        # Train encoder
        e_optimizer.zero_grad()
        for param, grad in zip(enc_params, enc_grads):
            param.grad = grad
        e_optimizer.step()

        # Visualization
        if iteration % 4 == 0:
            print("[{0}/{1}]".format(epoch, n_epochs),
                  "D: {:<8.3}".format(dis_loss.item()),
                  "En: {:<8.3}".format(err_enc.item()),
                  "De: {:<8.3}".format(err_dec.item()))

            # `images` (rather than `data`) keeps the module-level dataset
            # object from being overwritten by a tensor.
            for name, images in [("X_real", real_images), ("X_dec", x_rec),
                                 ("X_rand", x_rand)]:
                # First sample of the batch, rescaled with 0.5 * x + 0.5
                # (presumably from [-1, 1] to [0, 1] -- confirm against the
                # generator output range).
                featmask = (0.5 * images[0] + 0.5).detach().cpu().numpy()
                # Middle slice along the last axis.
                img = featmask[..., featmask.shape[-1] // 2]
                img = np.expand_dims(img, axis=1)
                # Scale to 255 relative to the maximum; an all-zero slice is
                # left untouched instead of being turned into NaNs.
                peak = img.max()
                if peak != 0:
                    img = (img / peak) * 255
                board.viewer.images(
                    img,
                    opts={
                        "title": name,
                        "caption": name},
                    win=name)

    # Save model
    # One state dict per network, suffixed with the 1-based epoch number.
    for name, model in [("generator", generator),
                        ("discriminator", discriminator),
                        ("encoder", encoder)]:
        fname = os.path.join(
            outdir, name + "_epoch_" + str(epoch + 1) + ".pth")
        torch.save(model.state_dict(), fname)

Conclusion

Variational Auto-Encoder (VAE) GANs are free from mode collapse, but their outputs are characterized by blurriness. In order to effectively address both the mode collapse of GANs and the blurriness of VAEs, we will use α-GAN, a solution born by combining both models, in the next tutorial.

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Follow us

© 2019, pynet developers .
Inspired by AZMIND template.