The paper’s basic premise is simple: scale up GAN training to benefit from larger models and larger batches. Although training a baseline SA-GAN architecture at a larger scale led to a significant performance improvement, it also made the models unstable. To deal with this issue, the authors introduced two architectural changes and modified the regularization scheme. Not only did these changes lead to better scalability and performance, but they also had a useful side effect: the modified architecture, BigGAN, became amenable to the “truncation trick”, a sampling technique that allows explicit, fine-grained control of the trade-off between sample fidelity and variety.
Architecture & Approach
BigGAN generator network
The BigGAN model uses the ResNet GAN architecture but with the channel pattern in the discriminator network (D) modified so that the number of filters in the first convolutional layer of each block is equal to the number of output filters. A single shared class embedding and skip connections for the latent vector z (skip-z) are used in the generator (G). Hierarchical latent spaces are employed to split the latent vector z along its channel dimension into chunks of equal size. Each chunk is concatenated to the shared class embedding and passed to a corresponding residual block as a conditioning vector. Each block’s conditioning is linearly projected to produce per-sample gains and biases for the block’s BatchNorm layers. The bias projections are zero-centered, while the gain projections are centered at 1.
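The conditioning path described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the authors' implementation: the function names, the dimensions, and the use of flat (batch, channels) features instead of convolutional feature maps are all hypothetical simplifications.

```python
import numpy as np

def split_latent(z, num_blocks):
    """Split z of shape (batch, dim_z) into equal chunks along the channel axis."""
    assert z.shape[1] % num_blocks == 0, "dim_z must be divisible by num_blocks"
    return np.split(z, num_blocks, axis=1)

def block_conditioning(z_chunk, class_embedding):
    """Concatenate one latent chunk with the shared class embedding."""
    return np.concatenate([z_chunk, class_embedding], axis=1)

def conditional_batchnorm(x, cond, w_gain, w_bias, eps=1e-5):
    """Normalize x over the batch, then apply per-sample gain and bias
    obtained by linearly projecting the conditioning vector."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    gain = 1.0 + cond @ w_gain  # gain projection centered at 1
    bias = cond @ w_bias        # bias projection centered at 0
    return gain * x_hat + bias

# Hypothetical sizes chosen only for the demonstration.
batch, dim_z, embed_dim, channels, num_blocks = 4, 120, 128, 16, 6
rng = np.random.default_rng(0)
z = rng.standard_normal((batch, dim_z))
class_embedding = rng.standard_normal((batch, embed_dim))

chunks = split_latent(z, num_blocks)
cond = block_conditioning(chunks[0], class_embedding)

# Zero-initialized projections give gain = 1 and bias = 0 at the start.
w_gain = np.zeros((cond.shape[1], channels))
w_bias = np.zeros((cond.shape[1], channels))
x = rng.standard_normal((batch, channels))
out = conditional_batchnorm(x, cond, w_gain, w_bias)
```

With zero-initialized projection weights the layer starts out as plain batch normalization, and the conditioning vector only begins to modulate the features once the projections are trained.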
Truncation Trick
Truncating a z vector by resampling the values with magnitude above a chosen threshold improves individual sample quality at the cost of reduced overall sample variety. As IS (Inception Score) does not penalize lack of variety in class-conditional models, reducing the truncation threshold leads to a direct increase in IS (analogous to precision). FID (Fréchet Inception Distance) penalizes lack of variety (analogous to recall) but also rewards precision, so a moderate improvement in FID is seen initially, but as truncation approaches zero and variety diminishes, FID deteriorates sharply (its value rises, since a lower FID is better).
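The resampling step described above can be sketched as follows. This is a minimal illustration, assuming a standard normal prior; the function name `truncated_normal_resample` and the threshold values are hypothetical and not taken from the paper's code.

```python
import numpy as np

def truncated_normal_resample(batch_size, dim_z, threshold, seed=None):
    """Draw z ~ N(0, I) and resample every entry whose magnitude
    exceeds `threshold` until all entries fall inside the range."""
    assert threshold > 0, "threshold must be positive"
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((batch_size, dim_z))
    mask = np.abs(z) > threshold
    while mask.any():
        # Redraw only the offending entries.
        z[mask] = rng.standard_normal(int(mask.sum()))
        mask = np.abs(z) > threshold
    return z

# A lower threshold concentrates the latents near zero (less variety).
z_wide = truncated_normal_resample(8, 128, threshold=2.0, seed=0)
z_narrow = truncated_normal_resample(8, 128, threshold=0.5, seed=0)
print(z_wide.std(), z_narrow.std())
```

The TensorFlow Hub helper used later in this article takes a slightly different route: it draws from a normal distribution truncated at ±2 and multiplies the result by the truncation value.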
Sampling with different latents than those seen in training causes a problematic distribution shift for many models. Some of the larger BigGAN models are not amenable to truncation and produce saturation artifacts when fed truncated noise. To counteract this, amenability to truncation needs to be enforced by conditioning G to be smooth so that the full space of z will map to good output samples. Orthogonal Regularization is used for this conditioning of the generator network:
$$R_\beta(W) = \beta \lVert W^\top W - I \rVert_F^2$$
Here W is a weight matrix and β a hyperparameter. This constraint is often too limiting, so the authors explored several variants that relax it while still imparting the desired smoothness. The best version removes the diagonal terms from the regularization and minimizes the pairwise cosine similarity between filters but does not constrain their norm:
$$R_\beta(W) = \beta \lVert W^\top W \odot (\mathbf{1} - I) \rVert_F^2$$
Here 1 denotes a matrix with all elements set to 1, and ⊙ denotes element-wise multiplication.
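The two penalties discussed above (the full orthogonality constraint and the relaxed variant without diagonal terms) can be compared with a small NumPy sketch. The function names and the default β are assumptions made for this illustration; in practice W would be a layer's weight tensor reshaped to a 2-D matrix with one filter per column.

```python
import numpy as np

def ortho_reg(W, beta=1e-4):
    """Full penalty: beta * ||W^T W - I||_F^2."""
    gram = W.T @ W
    identity = np.eye(gram.shape[0])
    return beta * np.sum((gram - identity) ** 2)

def ortho_reg_offdiag(W, beta=1e-4):
    """Relaxed penalty: beta * ||W^T W * (1 - I)||_F^2.
    Only off-diagonal terms are penalized, so filter norms are free."""
    gram = W.T @ W
    off_diag_mask = np.ones_like(gram) - np.eye(gram.shape[0])
    return beta * np.sum((gram * off_diag_mask) ** 2)

# Columns are mutually orthogonal but have norm 2 instead of 1.
W = 2.0 * np.eye(3)
print(ortho_reg(W, beta=1.0))          # 27.0: the norm mismatch is penalized
print(ortho_reg_offdiag(W, beta=1.0))  # 0.0: only filter overlap is penalized
```

The example shows the difference in behavior: a matrix whose filters are orthogonal but not unit-length is penalized by the full version and left alone by the relaxed one.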
Synthesizing images using the BigGAN generator
The following code is based on the TensorFlow implementation of BigGAN available on TensorFlow Hub.
- Import the necessary libraries and classes
    # set all global behaviors to TensorFlow 1.x
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    import os
    import io
    import IPython.display
    import numpy as np
    import PIL.Image
    from scipy.stats import truncnorm
    import tensorflow_hub as hub
- Load the 256×256 generator module from TensorFlow hub
    module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1'

    tf.reset_default_graph()
    print('Loading BigGAN module from:', module_path)
    module = hub.Module(module_path)
    inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
              for k, v in module.get_input_info_dict().items()}
    output = module(inputs)
- Create helper functions for one-hot encoding labels, sampling, and displaying images.
    input_z = inputs['z']
    input_y = inputs['y']
    input_trunc = inputs['truncation']

    dim_z = input_z.shape.as_list()[1]
    vocab_size = input_y.shape.as_list()[1]

    def truncated_z_sample(batch_size, truncation=1., seed=None):
        state = None if seed is None else np.random.RandomState(seed)
        values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
        return truncation * values

    def one_hot(index, vocab_size=vocab_size):
        index = np.asarray(index)
        if len(index.shape) == 0:
            index = np.asarray([index])
        assert len(index.shape) == 1
        num = index.shape[0]
        output = np.zeros((num, vocab_size), dtype=np.float32)
        output[np.arange(num), index] = 1
        return output

    def one_hot_if_needed(label, vocab_size=vocab_size):
        label = np.asarray(label)
        if len(label.shape) <= 1:
            label = one_hot(label, vocab_size)
        assert len(label.shape) == 2
        return label

    def sample(sess, noise, label, truncation=1., batch_size=8,
               vocab_size=vocab_size):
        noise = np.asarray(noise)
        label = np.asarray(label)
        num = noise.shape[0]
        if len(label.shape) == 0:
            label = np.asarray([label] * num)
        label = one_hot_if_needed(label, vocab_size)
        ims = []
        for batch_start in range(0, num, batch_size):
            s = slice(batch_start, min(num, batch_start + batch_size))
            feed_dict = {input_z: noise[s], input_y: label[s],
                         input_trunc: truncation}
            ims.append(sess.run(output, feed_dict=feed_dict))
        ims = np.concatenate(ims, axis=0)
        assert ims.shape[0] == num
        ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
        ims = np.uint8(ims)
        return ims

    def imgrid(imarray, cols=5, pad=1):
        pad = int(pad)
        assert pad >= 0
        cols = int(cols)
        assert cols >= 1
        N, H, W, C = imarray.shape
        rows = N // cols + int(N % cols != 0)
        batch_pad = rows * cols - N
        assert batch_pad >= 0
        post_pad = [batch_pad, pad, pad, 0]
        pad_arg = [[0, p] for p in post_pad]
        imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
        H += pad
        W += pad
        grid = (imarray
                .reshape(rows, cols, H, W, C)
                .transpose(0, 2, 1, 3, 4)
                .reshape(rows*H, cols*W, C))
        if pad:
            grid = grid[:-pad, :-pad]
        return grid

    def imshow(a, format='png', jpeg_fallback=True):
        a = np.asarray(a, dtype=np.uint8)
        data = io.BytesIO()
        PIL.Image.fromarray(a).save(data, format)
        im_data = data.getvalue()
        try:
            disp = IPython.display.display(IPython.display.Image(im_data))
        except IOError:
            if jpeg_fallback and format != 'jpeg':
                print(('Warning: image was too large to display in format "{}"; '
                       'trying jpeg instead.').format(format))
                return imshow(a, format='jpeg')
            else:
                raise
        return disp
- Create a TensorFlow session, initialize variables and generate some images using the BigGAN.
    # create TensorFlow session and initialize variables
    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(initializer)

    # set noise seed, num of images, truncation and category to be sampled
    num_samples = 10
    truncation = 0.5
    noise_seed = 0
    category = "971) bubble"

    z = truncated_z_sample(num_samples, truncation, noise_seed)
    y = int(category.split(')')[0])

    ims = sample(sess, z, y, truncation=truncation)
    imshow(imgrid(ims, cols=min(num_samples, 5)))
Colab Notebook of the above implementation.
Synthesizing images from text prompts using CLIP and BigGAN generator
The following code is taken from the simplified notebook for BigSleep, which Ryan Murdock created by combining OpenAI’s CLIP with the generator from a BigGAN.
- Install BigSleep
pip install big-sleep --upgrade
- Generate images from text prompts
    from tqdm.notebook import trange
    from IPython.display import Image, display
    from big_sleep import Imagine

    TEXT = 'upside down tree'
    SAVE_EVERY = 100
    SAVE_PROGRESS = True
    LEARNING_RATE = 5e-2
    ITERATIONS = 1000
    SEED = 0

    model = Imagine(
        text = TEXT,
        save_every = SAVE_EVERY,
        lr = LEARNING_RATE,
        iterations = ITERATIONS,
        save_progress = SAVE_PROGRESS,
        seed = SEED
    )

    for epoch in trange(20, desc = 'epochs'):
        for i in trange(1000, desc = 'iteration'):
            model.train_step(epoch, i)

            if i == 0 or i % model.save_every != 0:
                continue

            filename = TEXT.replace(' ', '_')
            image = Image(f'./{filename}.png')
            display(image)