Hands-On Guide To BigGAN With Python Code

The basic premise of the BigGAN paper is simple: scale up GAN training to benefit from larger models and larger batches. Training a baseline SA-GAN architecture at a larger scale did improve performance significantly, but it also made the models unstable. To deal with this issue, the authors introduced two architectural changes and modified the regularization scheme. These changes not only improved scalability and performance but also had a useful side effect: the modified architecture, BigGAN, became amenable to the “truncation trick”, a sampling technique that allows explicit, fine-grained control of the trade-off between sample fidelity and variety.

Architecture & Approach

BigGAN generator network

The BigGAN model uses the ResNet GAN architecture, but with the channel pattern in the discriminator network (D) modified so that the number of filters in the first convolutional layer of each block equals the number of output filters. The generator (G) uses a single shared class embedding and skip connections for the latent vector z (skip-z). With hierarchical latent spaces, z is split along its channel dimension into chunks of equal size. Each chunk is concatenated with the shared class embedding and passed to a corresponding residual block as a conditioning vector. Each block’s conditioning vector is linearly projected to produce per-sample gains and biases for the block’s BatchNorm layers. The bias projections are zero-centered, while the gain projections are centered at 1.
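
The conditioning path described above can be sketched in NumPy. This is a minimal illustration, not the BigGAN implementation: the sizes (`dim_z = 120`, `embed_dim = 128`, six blocks, 16 channels), the zero-initialized projection matrices `w_gain` and `w_bias`, and the helper name `conditional_batchnorm` are all assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not values taken from the article)
batch, dim_z, embed_dim, num_blocks, channels = 4, 120, 128, 6, 16


def conditional_batchnorm(x, cond, w_gain, w_bias, eps=1e-5):
    """Normalize x (N, H, W, C), then apply per-sample gain and bias
    obtained by linearly projecting the conditioning vector cond (N, D)."""
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    gain = 1.0 + cond @ w_gain  # (N, C), centered at 1
    bias = cond @ w_bias        # (N, C), centered at 0
    return x_hat * gain[:, None, None, :] + bias[:, None, None, :]


# Latent vector and shared class embedding (random stand-ins)
z = rng.standard_normal((batch, dim_z))
class_embedding = rng.standard_normal((batch, embed_dim))

# Split z along its channel dimension into equal chunks, one per block
z_chunks = np.split(z, num_blocks, axis=1)

# Each block's conditioning vector: its z chunk plus the shared embedding
conds = [np.concatenate([chunk, class_embedding], axis=1) for chunk in z_chunks]
cond_dim = conds[0].shape[1]

# Hypothetical zero-initialized projections: gain starts at 1, bias at 0
w_gain = np.zeros((cond_dim, channels))
w_bias = np.zeros((cond_dim, channels))

x = rng.standard_normal((batch, 8, 8, channels))
out = conditional_batchnorm(x, conds[0], w_gain, w_bias)
print(len(z_chunks), z_chunks[0].shape, conds[0].shape, out.shape)
```

With zero projections the layer reduces to plain batch normalization; as the projections are trained, each sample receives its own gain and bias from its class and its chunk of z.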

ResBlock up used in the BigGAN generator and ResBlock down used in the discriminator

Truncation Trick

The effects of increasing truncation. From left to right, the threshold is set to 2, 1, 0.5, 0.04.

Truncating a z vector by resampling the values whose magnitude exceeds a chosen threshold improves individual sample quality at the cost of reduced overall sample variety. Because the Inception Score (IS) does not penalize a lack of variety in class-conditional models, lowering the truncation threshold directly increases IS (analogous to precision). The Fréchet Inception Distance (FID), for which lower is better, penalizes a lack of variety (analogous to recall) but also rewards precision. FID therefore improves moderately at first, but as the threshold approaches zero and variety diminishes, it worsens sharply.
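
The resampling step can be sketched as follows. This is a minimal NumPy illustration under stated assumptions: the function name `truncated_normal_by_resampling` and the latent shape `(1000, 128)` are hypothetical, and the loop simply redraws out-of-range values until none remain.

```python
import numpy as np


def truncated_normal_by_resampling(shape, threshold, seed=None):
    """Draw z ~ N(0, 1) and redraw every entry with |z| > threshold."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape)
    mask = np.abs(z) > threshold
    while mask.any():
        # Redraw only the entries that fall outside the threshold
        z[mask] = rng.standard_normal(int(mask.sum()))
        mask = np.abs(z) > threshold
    return z


# Lower thresholds concentrate the latents near zero (less variety)
for threshold in (2.0, 1.0, 0.5):
    z = truncated_normal_by_resampling((1000, 128), threshold, seed=0)
    print(threshold, float(np.abs(z).max()), float(z.std()))
```

The shrinking standard deviation is the loss of variety the paragraph above describes: the generator only ever sees latents from the high-density centre of the prior.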

Saturation artifacts produced by poorly conditioned models

Sampling with latents different from those seen in training causes a problematic distribution shift for many models. Some of the larger BigGAN models are not amenable to truncation and produce saturation artifacts when fed truncated noise. To counteract this, amenability to truncation is encouraged by conditioning G to be smooth, so that the full space of z maps to good output samples. Orthogonal Regularization is used for this conditioning of the generator network:

$$R_\beta(W) = \beta \lVert W^\top W - I \rVert_F^2$$

Here W is a weight matrix and β is a hyperparameter. This constraint is often too limiting, so the authors explored several variants that relax it while still imparting the desired smoothness. The best version removes the diagonal terms from the regularization; it minimizes the pairwise cosine similarity between filters but does not constrain their norm:

$$R_\beta(W) = \beta \lVert W^\top W \odot (\mathbf{1} - I) \rVert_F^2$$

Here 1 denotes a matrix with all elements set to 1.
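
The difference between the two penalties can be checked numerically. The sketch below is an illustration only: it assumes the columns of W are the filters, uses an arbitrary default `beta=1e-4`, and the function names `ortho_reg_full` and `ortho_reg_offdiag` are hypothetical.

```python
import numpy as np


def ortho_reg_full(W, beta=1e-4):
    """Original penalty: beta * ||W^T W - I||_F^2."""
    gram = W.T @ W
    return beta * np.sum((gram - np.eye(gram.shape[0])) ** 2)


def ortho_reg_offdiag(W, beta=1e-4):
    """Relaxed penalty: beta * ||W^T W * (1 - I)||_F^2 (diagonal removed)."""
    gram = W.T @ W
    off_diagonal = gram * (1.0 - np.eye(gram.shape[0]))
    return beta * np.sum(off_diagonal ** 2)


rng = np.random.default_rng(0)

# Q has 8 orthonormal columns; 3 * Q keeps them orthogonal but not unit-norm
Q, _ = np.linalg.qr(rng.standard_normal((16, 8)))
W_scaled = 3.0 * Q

print(ortho_reg_full(Q), ortho_reg_offdiag(Q))
print(ortho_reg_full(W_scaled), ortho_reg_offdiag(W_scaled))
```

For the rescaled matrix the full penalty is nonzero because the column norms differ from 1, while the relaxed penalty stays at essentially zero: it only penalizes overlap between different filters.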

Synthesizing images using the BigGAN generator

The following code uses the TensorFlow implementation of BigGAN available on TensorFlow Hub.

  1. Import the necessary libraries and classes
 # set all global behaviors to TensorFlow 1.x 
 import tensorflow.compat.v1 as tf
 tf.disable_v2_behavior()
 import os
 import io
 import IPython.display
 import numpy as np
 import PIL.Image
 from scipy.stats import truncnorm
 import tensorflow_hub as hub 
  2. Load the 256×256 generator module from TensorFlow Hub
 module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1'
 tf.reset_default_graph()
 print('Loading BigGAN module from:', module_path)
 module = hub.Module(module_path)
 inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
           for k, v in module.get_input_info_dict().items()}
 output = module(inputs) 
  3. Create helper functions for one-hot encoding labels, sampling, and displaying images.
 input_z = inputs['z']
 input_y = inputs['y']
 input_trunc = inputs['truncation']
 dim_z = input_z.shape.as_list()[1]
 vocab_size = input_y.shape.as_list()[1]

 def truncated_z_sample(batch_size, truncation=1., seed=None):
   state = None if seed is None else np.random.RandomState(seed)
   values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
   return truncation * values

 def one_hot(index, vocab_size=vocab_size):
   index = np.asarray(index)
   if len(index.shape) == 0:
     index = np.asarray([index])
   assert len(index.shape) == 1
   num = index.shape[0]
   output = np.zeros((num, vocab_size), dtype=np.float32)
   output[np.arange(num), index] = 1
   return output

 def one_hot_if_needed(label, vocab_size=vocab_size):
   label = np.asarray(label)
   if len(label.shape) <= 1:
     label = one_hot(label, vocab_size)
   assert len(label.shape) == 2
   return label

 def sample(sess, noise, label, truncation=1., batch_size=8,
            vocab_size=vocab_size):
   noise = np.asarray(noise)
   label = np.asarray(label)
   num = noise.shape[0]
   if len(label.shape) == 0:
     label = np.asarray([label] * num)
   label = one_hot_if_needed(label, vocab_size)
   ims = []
   for batch_start in range(0, num, batch_size):
     s = slice(batch_start, min(num, batch_start + batch_size))
     feed_dict = {input_z: noise[s], input_y: label[s], input_trunc: truncation}
     ims.append(sess.run(output, feed_dict=feed_dict))
   ims = np.concatenate(ims, axis=0)
   assert ims.shape[0] == num
   ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
   ims = np.uint8(ims)
   return ims

 def imgrid(imarray, cols=5, pad=1):
   pad = int(pad)
   assert pad >= 0
   cols = int(cols)
   assert cols >= 1
   N, H, W, C = imarray.shape
   rows = N // cols + int(N % cols != 0)
   batch_pad = rows * cols - N
   assert batch_pad >= 0
   post_pad = [batch_pad, pad, pad, 0]
   pad_arg = [[0, p] for p in post_pad]
   imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
   H += pad
   W += pad
   grid = (imarray
           .reshape(rows, cols, H, W, C)
           .transpose(0, 2, 1, 3, 4)
           .reshape(rows*H, cols*W, C))
   if pad:
     grid = grid[:-pad, :-pad]
   return grid

 def imshow(a, format='png', jpeg_fallback=True):
   a = np.asarray(a, dtype=np.uint8)
   data = io.BytesIO()
   PIL.Image.fromarray(a).save(data, format)
   im_data = data.getvalue()
   try:
     disp = IPython.display.display(IPython.display.Image(im_data))
   except IOError:
     if jpeg_fallback and format != 'jpeg':
       print(('Warning: image was too large to display in format "{}"; '
              'trying jpeg instead.').format(format))
       return imshow(a, format='jpeg')
     else:
       raise
   return disp 
  4. Create a TensorFlow session, initialize variables, and generate some images using BigGAN.
 # create TensorFlow session and initialize variables
 initializer = tf.global_variables_initializer()
 sess = tf.Session()
 sess.run(initializer)

 # set noise seed, num of images, truncation and category to be sampled
 num_samples = 10 
 truncation = 0.5 
 noise_seed = 0 
 category = "971) bubble" 

 z = truncated_z_sample(num_samples, truncation, noise_seed)
 y = int(category.split(')')[0])
 ims = sample(sess, z, y, truncation=truncation)
 imshow(imgrid(ims, cols=min(num_samples, 5))) 
Images sampled by BigGAN

Colab Notebook of the above implementation.

Synthesizing images from text prompts using CLIP and BigGAN generator

The following code has been taken from the simplified BigSleep notebook created by Ryan Murdock by combining OpenAI’s CLIP and the generator from a BigGAN.

  1. Install BigSleep

pip install big-sleep --upgrade

  2. Generate images from text prompts
 from tqdm.notebook import trange
 from IPython.display import Image, display
 from big_sleep import Imagine
 TEXT = 'upside down tree' 
 SAVE_EVERY = 100 
 SAVE_PROGRESS = True 
 LEARNING_RATE = 5e-2 
 ITERATIONS = 1000 
 SEED = 0 

 model = Imagine(
     text = TEXT,
     save_every = SAVE_EVERY,
     lr = LEARNING_RATE,
     iterations = ITERATIONS,
     save_progress = SAVE_PROGRESS,
     seed = SEED
 )

 for epoch in trange(20, desc = 'epochs'):
     for i in trange(ITERATIONS, desc = 'iteration'):
         model.train_step(epoch, i)
         if i == 0 or i % model.save_every != 0:
             continue
         filename = TEXT.replace(' ', '_')
         image = Image(f'./{filename}.png')
         display(image) 
Images generated using BigSleep (CLIP + BigGAN)
Note: The “Flying Car” and “Upside down tree” images are not the final outputs. My internet connection dropped during generation, and these were the last images saved by Colab.

Aditya Singh

A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.
