Helper Module for Deep Learning.
Beta VAE disentangling
Credit: A Grigis
# Imports
import os
import sys
# Stop the script right away when the CI_MODE environment variable is set,
# so that nothing below (data loading, model training) is executed.
if os.environ.get("CI_MODE") is not None:
    sys.exit()
import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
from torch.distributions import Normal, kl_divergence
from pynet import NetParameters
from pynet.datasets import DataManager
from pynet.datasets.dsprites import DSprites
from pynet.interfaces import VAENetEncoder
from pynet.plotting import Board, update_board
from pynet.losses import get_vae_loss
from pynet.models.vae.utils import (
reconstruct_traverse, make_mosaic_img, add_labels)
# Global parameters
# Root directory for the dataset, the checkpoints and the saved figures.
WDIR = "/tmp/beta_vae_disentangling"
# Optimisation settings shared by every loss variant.
BATCH_SIZE = 64
N_EPOCHS = 30
ADAM_LR = 5e-4
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# When True, the figures are also shown interactively at the end of the run.
DISPLAY = False
# Load the data: the dSprites images stored under WDIR, served in
# mini-batches of BATCH_SIZE drawn by a random sampler.
dataset = DSprites(WDIR)
manager = DataManager.from_dataset(
    train_dataset=dataset,
    batch_size=BATCH_SIZE,
    sampler="random",
)
# Test different losses: keyword arguments forwarded to ``get_vae_loss``,
# keyed by the ``loss_name`` used in the training loop.
# NOTE(review): the names presumably stand for the Higgins beta-VAE
# ("betah"), the Burgess capacity-annealed beta-VAE ("betab") and
# beta-TCVAE ("btcvae") -- confirm against pynet.losses.
loss_params = dict(
    betah=dict(beta=4, steps_anneal=0, use_mse=True),
    betab=dict(
        C_init=0.5, C_fin=25, gamma=100, steps_anneal=100000, use_mse=True),
    btcvae=dict(
        dataset_size=len(dataset), alpha=1, beta=1, gamma=6, is_mss=True,
        steps_anneal=0, use_mse=True),
)
def plot_losses(cache, filename):
    """Plot the per-dimension KL history stored in a loss cache and save it.

    Two panels are drawn: the KL of every latent dimension against the
    training iteration (left), and against the log-likelihood (right).

    Parameters
    ----------
    cache: dict
        loss cache holding the ``"kl"`` and ``"ll"`` histories. ``"ll"`` is
        assumed to hold one scalar per recorded iteration and ``"kl"`` one
        value per latent dimension per iteration (TODO confirm against
        pynet.losses). Nothing is plotted when either key is missing or when
        no iteration was recorded.
    filename: str
        path of the saved figure.
    """
    if "kl" not in cache or "ll" not in cache:
        return
    ll = np.asarray(cache["ll"]).reshape(-1)
    if ll.size == 0:
        return
    # One row per iteration, one column per latent dimension. Unlike a bare
    # ``squeeze``, this keeps a single-dimension latent space 2-D, so that
    # ``kl.T`` still yields one curve per dimension.
    kl = np.asarray(cache["kl"]).reshape(ll.size, -1)
    fig, axs = plt.subplots(nrows=1, ncols=2)
    colors = list(mcolors.TABLEAU_COLORS.keys())
    for idx, dim_kl in enumerate(kl.T):
        # TABLEAU_COLORS only has 10 entries: cycle through the palette
        # instead of failing on larger latent spaces.
        color = colors[idx % len(colors)]
        label = "dim{0}".format(idx + 1)
        axs[0].plot(dim_kl, color=color, label=label)
        axs[1].plot(ll, dim_kl, color=color, label=label)
    axs[0].set_xlabel("Training iterations")
    axs[0].set_ylabel("KL")
    axs[1].set_xlabel("Log Likelihood")
    axs[1].set_ylabel("KL")
    axs[1].legend(loc="upper left")
    fig.tight_layout()
    fig.savefig(filename)
def _checkpoint_sort_key(path):
    """Natural-sort key: runs of digits in ``path`` compare as integers."""
    import re
    # With a capturing group, even positions are non-digit text and odd
    # positions are digit runs, so keys of different paths stay comparable.
    parts = re.split(r"(\d+)", path)
    return [int(part) if part.isdigit() else part for part in parts]


def plot_reconstructions(model, data, checkpointdir, filename=None):
    """Reconstruct ``data`` with every checkpoint saved in ``checkpointdir``.

    The first mosaic row shows the input images and each following row shows
    the reconstruction obtained with one ``*.pth`` checkpoint. Checkpoints
    are taken in natural order, so ``..._epoch_5.pth`` comes before
    ``..._epoch_10.pth``. Because the weights are loaded in place, ``model``
    is left holding the weights of the last checkpoint in that order.

    Parameters
    ----------
    model: torch.nn.Module
        network exposing ``load_state_dict`` and ``reconstruct``.
    data: torch.Tensor
        batch of images to reconstruct.
    checkpointdir: str
        folder containing the checkpoints; each one is expected to store the
        model weights under the ``"model"`` key.
    filename: str, default None
        if given, the labelled mosaic is also saved to this path.

    Returns
    -------
    concatenated: PIL.Image
        the labelled mosaic (the value returned by ``add_labels``).
    """
    weights_files = sorted(
        glob.glob(os.path.join(checkpointdir, "*.pth")),
        key=_checkpoint_sort_key)
    stages = [np.expand_dims(data.cpu().numpy(), axis=0)]
    labels = ["orig"]
    for idx, path in enumerate(weights_files):
        # Load on the CPU so that checkpoints written by a GPU run can be read
        # on any host; load_state_dict copies the values into the model's
        # own parameters, whatever their device.
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        reconstruction = model.reconstruct(data, sample=False)
        stages.append(np.expand_dims(reconstruction, axis=0))
        labels.append("rec stage {0}".format(idx + 1))
    mosaic = make_mosaic_img(np.concatenate(stages, axis=0))
    concatenated = add_labels(Image.fromarray(mosaic), labels)
    if filename is not None:
        concatenated.save(filename)
    return concatenated
for loss_name in ("betah", "betab", "btcvae"):
    # Train the model, or reload it when the final checkpoint already exists.
    checkpointdir = os.path.join(WDIR, "checkpoints", loss_name)
    os.makedirs(checkpointdir, exist_ok=True)
    weights_filename = os.path.join(
        checkpointdir, "model_0_epoch_{0}.pth".format(N_EPOCHS))
    params = NetParameters(
        input_channels=1,
        input_dim=DSprites.img_size,
        conv_flts=[32, 32, 32, 32],
        dense_hidden_dims=[256, 256],
        latent_dim=10,
        noise_out_logvar=-3,
        noise_fixed=False,
        act_func=None,
        dropout=0,
        sparse=False)
    loss = get_vae_loss(loss_name=loss_name, **loss_params[loss_name])
    is_trained = os.path.isfile(weights_filename)
    encoder_kwargs = dict(
        optimizer_name="Adam",
        learning_rate=ADAM_LR,
        loss=loss,
        use_cuda=(DEVICE.type == "cuda"))
    if is_trained:
        encoder_kwargs["pretrained"] = weights_filename
    vae = VAENetEncoder(params, **encoder_kwargs)
    if not is_trained:
        vae.board = Board(
            port=8097, host="http://localhost", env="beta-vae")
        vae.add_observer("after_epoch", update_board)
        # NOTE(review): N_EPOCHS + 1 epochs presumably so that the
        # "model_0_epoch_{N_EPOCHS}.pth" checkpoint tested above gets
        # written -- confirm against pynet's epoch numbering.
        train_history, valid_history = vae.training(
            manager=manager,
            nb_epochs=(N_EPOCHS + 1),
            checkpointdir=checkpointdir,
            fold_index=0,
            with_validation=False,
            save_after_epochs=5)
        plot_losses(vae.loss.cache,
                    os.path.join(WDIR, "loss_{0}.png".format(loss_name)))
    print(vae.model)

    # Display results
    # Draw 100 random images. The permutation is sliced before indexing the
    # image array: ``dataset.imgs[index][:100]`` would first copy the whole
    # (permuted) image array only to keep 100 images of it.
    index = np.arange(len(dataset))
    np.random.shuffle(index)
    data = torch.unsqueeze(torch.from_numpy(
        dataset.imgs[index[:100]].astype(np.float32)), dim=1).to(DEVICE)
    vae.model.eval()
    name = "traverse_posteriror_{0}".format(loss_name)
    filename = os.path.join(WDIR, "{0}.png".format(name))
    mosaic_traverse = reconstruct_traverse(
        vae.model, data, n_per_latent=8, n_latents=None, is_posterior=True,
        filename=filename)
    filename = os.path.join(
        WDIR, "reconstruction_stages_{0}.png".format(loss_name))
    plot_reconstructions(vae.model, data[:8], checkpointdir, filename=filename)
    if DISPLAY:
        plt.figure()
        plt.imshow(np.asarray(mosaic_traverse))
        plt.title(name)
        plt.axis("off")

if DISPLAY:
    plt.show()
Total running time of the script: (0 minutes 0.000 seconds)
Gallery generated by Sphinx-Gallery
Follow us
© 2019, pynet developers.
Inspired by AZMIND template.