Including arbitrary matplotlib plots in TensorFlow’s TensorBoard

If you have been doing any machine or deep learning lately, it’s likely that you have stumbled upon TensorBoard. TensorBoard is great because it lets you interactively monitor training curves, plot graphs, show histograms and distributions of variables, and include images and audio, among many other useful utilities that make experiments more interpretable. However, one thing missing is the out-of-the-box ability to add popular matplotlib plots. Python users typically have a variety of figure templates lying around and it would be a shame not to fully exploit them. Unfortunately, in contrast to adding images stored as arrays, adding arbitrary figures isn’t an inherently available TensorBoard option. In this post, we will see how we can use a small trick to add any plot directly to TensorBoard.

Although TensorBoard was originally developed for TensorFlow, there is also a more generic extension called TensorBoardX for other deep learning frameworks such as PyTorch, Chainer or MXNet. In principle, the only requirement of this extension is an interface to Numpy arrays. To keep the rest of this text framework-agnostic, we will thus operate directly on these Numpy arrays, as any of the above-mentioned frameworks can readily convert their tensors.

The default: adding standard images

Let’s quickly refresh with an example of how to add a ready-to-go image to TensorBoard. The code below is an example in PyTorch that uses torchvision to generate a grid of validation set images from a dataset loader (here Fashion-MNIST). The resulting image grid can be added to TensorBoard directly with the writer.add_image() function.

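import torchvision
from tensorboardX import SummaryWriter

# save_path (the log directory) and dataset (exposing a val_loader)
# are assumed to be defined elsewhere
writer = SummaryWriter(save_path)

# Grab one validation batch and arrange it as a single grid image
inputs, classes = next(iter(dataset.val_loader))
imgs = torchvision.utils.make_grid(inputs, nrow=8, padding=5, normalize=True)
writer.add_image('valdata_snapshot', imgs)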

The resulting image in TensorBoard will be listed under the “IMAGES” tab:

[Screenshot: image grid listed under the TensorBoard “IMAGES” tab]

Adding matplotlib figures: conversion from fig to numpy

The reason we cannot add our matplotlib figures to TensorBoard directly is primarily a format issue: figures aren’t stored as Numpy arrays. There also doesn’t seem to be a trivial-to-invoke function that does this conversion for us, so we are going to use a little trick. We start from a plotting function that we would otherwise use with savefig() or imshow() to get our rendering result. In machine learning this could be a function that renders a confusion matrix, plotting the inter-class confusion in a classification scenario. We will keep this function unmodified except for a single auxiliary function call that does the conversion. Such a function to plot a confusion matrix could look like the following:

import matplotlib.pyplot as plt

def vis_confusion(writer, step, matrix, class_dict):
    """
    Visualization of confusion matrix

    Parameters:
        writer (tensorboard.SummaryWriter): TensorBoard SummaryWriter instance.
        step (int): Counter usually specifying steps/epochs/time.
        matrix (numpy.ndarray): Square-shaped array of size class x class
            holding cross-class counts or rates. Each row is normalized
            to sum to 1 (range 0-1) before plotting.
        class_dict (dict): Dictionary specifying class names as keys and
            corresponding integer labels/targets as values.
    """

    # Class names ordered by their integer label
    all_categories = sorted(class_dict, key=class_dict.get)

    # Normalize by dividing every row by its sum
    # (astype returns a copy, so the caller's array is left untouched)
    matrix = matrix.astype(float)
    for i in range(len(class_dict)):
        matrix[i] = matrix[i] / matrix[i].sum()

    # Create the figure
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Show the matrix and define a discretized color bar
    cax = ax.matshow(matrix)
    fig.colorbar(cax, boundaries=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])

    # Place one tick per class, then label the ticks.
    # Rotate the x labels by 90 degrees.
    ticks = range(len(all_categories))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(all_categories, rotation=90)
    ax.set_yticklabels(all_categories)

    # Turn off the grid for this plot. Enforce a tight layout to reduce white margins
    ax.grid(False)
    plt.tight_layout()

    # Call our auxiliary to-TensorBoard function to render the figure
    plot_to_tensorboard(writer, fig, step)

This function takes the TensorBoard writer instance, an integer step that specifies a time-stamp (an epoch, iteration, etc.), a matrix whose rows and columns hold the cross-class accuracy/confusion values, and a dictionary mapping string class names to integer labels. The actual plot is created with matshow(), but the figure won’t actually be rendered until we call our auxiliary function “plot_to_tensorboard”:

import matplotlib.pyplot as plt
import numpy as np

def plot_to_tensorboard(writer, fig, step):
    """
    Takes a matplotlib figure handle and converts it, via the
    canvas and its raw RGB byte string, to a numpy array that can be
    visualized in TensorBoard using the add_image function

    Parameters:
        writer (tensorboard.SummaryWriter): TensorBoard SummaryWriter instance.
        fig (matplotlib.figure.Figure): Matplotlib figure handle.
        step (int): counter usually specifying steps/epochs/time.
    """

    # Draw figure on canvas
    fig.canvas.draw()

    # Read the canvas as raw RGB bytes and reshape to (height, width, 3).
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    # Note: tostring_rgb() is deprecated in recent matplotlib releases;
    # there, np.asarray(fig.canvas.buffer_rgba())[..., :3] gives the same pixels.
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    # Normalize into 0-1 range for TensorBoard(X). Move the color channels
    # to the first dim for newer versions where the API expects them there
    img = img / 255.0
    # img = np.transpose(img, (2, 0, 1)) # if your TensorFlow + TensorBoard version are >= 1.8

    # Add figure in numpy "image" to TensorBoard writer
    writer.add_image('confusion_matrix', img, step)
    plt.close(fig)

This function receives the TensorBoard writer instance, the figure to be rendered and our previously defined integer step, which tells TensorBoard where the image belongs in the sequence. In addition, one could add a string argument with a name to distinguish different types of plots. To turn the matplotlib figure into a Numpy array, we first draw the figure on a canvas. The canvas is then read out as a raw byte string, which is interpreted as uint8 (0-255 range) RGB values. We query the canvas width and height to reshape the flat array into a three-dimensional one of shape height x width x 3. Before adding this array to TensorBoard we may need to normalize it into the 0-1 range by dividing by 255, or reorder the axes, depending on the versions of TensorFlow, TensorBoard, PyTorch etc. The common pattern seems to be that older versions expect color channels as the last array dimension, whereas more recent versions expect color channels as the first array dimension.
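To make the conversion and the two channel layouts concrete, here is a minimal standalone sketch. It is not the exact function from this post: the helper name figure_to_array and its channels_first flag are made up for illustration, it selects the headless Agg backend so that no display is needed, and it reads the pixels through the canvas method buffer_rgba() instead of the byte string.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; an assumption so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def figure_to_array(fig, channels_first=False):
    """Hypothetical helper: render a figure and return float RGB values in the 0-1 range."""
    # Rendering has to happen before the pixel buffer can be read
    fig.canvas.draw()
    # buffer_rgba() exposes the rendered pixels as (height, width, 4) uint8 values
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop the alpha channel and scale 0-255 integers to 0-1 floats
    img = rgba[..., :3] / 255.0
    if channels_first:
        # height x width x color -> color x height x width
        img = np.transpose(img, (2, 0, 1))
    return img


fig = plt.figure(figsize=(4, 3), dpi=100)
ax = fig.add_subplot(111)
ax.plot([0, 1, 2], [0, 1, 4])

hwc = figure_to_array(fig)                       # color channels last
chw = figure_to_array(fig, channels_first=True)  # color channels first
plt.close(fig)

print(hwc.shape, chw.shape)
```

Note that np.transpose with the axis order (2, 0, 1) keeps height before width, whereas swapping only the first and last axes would also mirror the image along its diagonal. Recent tensorboardX versions additionally accept a dataformats argument in add_image (for example dataformats='HWC'), which may make the manual reordering unnecessary, depending on your installed version.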

Finally, here is an example of an MLP trained for 10 epochs on the Fashion-MNIST dataset and confusion matrices for each epoch:

[Screenshot: per-epoch confusion matrices in TensorBoard]
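The post does not show how the matrix for each epoch is collected, so the following is only one possible sketch. The helper name accumulate_confusion and the toy labels are invented for illustration. Rows index the true class and columns the predicted class, which matches the row-wise normalization inside vis_confusion.

```python
import numpy as np


def accumulate_confusion(targets, predictions, num_classes):
    """Hypothetical helper: count how often each true class (row) is predicted as each class (column)."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at increments correctly even when the same (target, prediction) pair repeats
    np.add.at(matrix, (targets, predictions), 1)
    return matrix


# Toy labels standing in for one epoch of validation results
targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])

counts = accumulate_confusion(targets, predictions, num_classes=3)
print(counts)
```

The raw counts can be passed on as the matrix argument, e.g. vis_confusion(writer, epoch, counts, class_dict), since the plotting function normalizes each row itself. This assumes every class occurs at least once in the targets; otherwise a row sums to zero and the normalization divides by zero.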

Concluding remarks: While writing this post I stumbled upon multiple challenges, such as the previously mentioned API changes in color channel ordering, changing expectations of 0-255 integer vs. 0-1 float encoded images, and matplotlib proving difficult when using the draw() function. If you happen to receive a blank “white” (255-valued) image with any of your plots, try normalizing the array or making a call to imshow(), as with specific matplotlib versions the draw() function seems to have a bug that prevents a call to the renderer.
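If the blank image persists, another route that might be worth trying is to let savefig() write a PNG into an in-memory buffer and read that back as an array. This is a different technique from the one used in this post, and whether it avoids the issue in the affected matplotlib versions is an assumption. The helper name figure_to_array_via_png is made up for illustration.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend; an assumption so the sketch runs without a display
import matplotlib.pyplot as plt


def figure_to_array_via_png(fig):
    """Hypothetical helper: round-trip a figure through an in-memory PNG."""
    buf = io.BytesIO()
    # savefig renders the figure itself, so no explicit canvas draw is needed
    fig.savefig(buf, format="png")
    buf.seek(0)
    # PNG data is read back as floats in the 0-1 range with an alpha channel
    img = plt.imread(buf, format="png")
    # Keep only the RGB channels: height x width x 3
    return img[..., :3]


fig = plt.figure(figsize=(4, 3), dpi=100)
ax = fig.add_subplot(111)
ax.plot([0, 1, 2], [0, 1, 4])

img = figure_to_array_via_png(fig)
plt.close(fig)

print(img.shape, img.dtype)
```

The returned array is already in the 0-1 float range with color channels last, so it can be handed to add_image in the same way as the array produced by plot_to_tensorboard.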
