If you have been doing any machine or deep learning lately, it's likely that you have stumbled upon TensorBoard. TensorBoard is great because it lets you interactively monitor training curves, plot graphs, show histograms and distributions of variables, and include images and audio, among many other useful utilities that make experiments more interpretable. However, one thing it lacks is the out-of-the-box ability to add popular matplotlib plots. Python users typically have a variety of figure templates lying around, and it would be a shame not to fully exploit them. Unfortunately, in contrast to adding images stored as arrays, adding arbitrary figures isn't a built-in TensorBoard option. In this post, we will see how a small trick lets us add any plot to TensorBoard directly.
TensorBoard was originally developed for TensorFlow, but there is also a more generic extension called TensorBoardX for other deep learning frameworks such as PyTorch, Chainer or MXNet. In principle, the only requirement of this extension is an interface to NumPy arrays. To keep the rest of this text framework-agnostic, we will therefore operate directly on NumPy arrays, since all of the frameworks mentioned above can readily convert their tensors to them.
The default: adding standard images
Let's quickly refresh our memory with an example of how to add a ready-to-go image to TensorBoard. In PyTorch, for instance, torchvision can generate a grid of validation-set images from a dataset loader (here Fashion-MNIST). The resulting image, stored as a NumPy array, can be added to TensorBoard directly with the writer.add_image() function.
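The tiling step itself does not need torchvision. The following is a framework-agnostic sketch in NumPy. The helper make_image_grid() is hypothetical: it only mimics the layout of torchvision.utils.make_grid, placing images row by row with padding and repeating the gray channel to RGB. writer is assumed to be a SummaryWriter from torch.utils.tensorboard or tensorboardX. Its add_image(tag, img_tensor, global_step) accepts NumPy arrays and expects a channels-first (CHW) layout by default.

```python
import math

import numpy as np


def make_image_grid(images, nrow=8, padding=2):
    """Tile grayscale images of shape (N, H, W) into one (3, H', W') RGB grid.

    Hypothetical NumPy stand-in for torchvision.utils.make_grid.
    """
    n, h, w = images.shape
    ncol = min(nrow, n)                  # images per grid row
    nrows = math.ceil(n / ncol)          # number of grid rows
    grid = np.zeros((nrows * (h + padding) + padding,
                     ncol * (w + padding) + padding), dtype=images.dtype)
    for idx, img in enumerate(images):
        row, col = divmod(idx, ncol)
        top = padding + row * (h + padding)
        left = padding + col * (w + padding)
        grid[top:top + h, left:left + w] = img
    # Repeat the single gray channel three times: (H', W') -> (3, H', W')
    return np.repeat(grid[np.newaxis], 3, axis=0)


def log_image_batch(writer, images, step, tag="validation_images"):
    """Add a grid of images (floats in [0, 1]) to TensorBoard as one CHW image."""
    writer.add_image(tag, make_image_grid(images), global_step=step)
```

Consider a validation batch of shape (64, 28, 28) scaled to [0, 1]. Calling log_image_batch(writer, batch, step=epoch) would add one 8 × 8 grid per call.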
The resulting image in TensorBoard will be listed under the “IMAGES” tab:
Adding matplotlib figures: conversion from fig to numpy
The main reason we cannot add matplotlib figures to TensorBoard directly is their format: figures aren't stored as NumPy arrays, and there doesn't seem to be a simple built-in function that does this conversion for us. So we are going to use a little trick. We start from a plotting function that we would otherwise use with savefig() or imshow() to get our rendered result. In machine learning, this could be a function that renders a confusion matrix, showing the inter-class confusion in a classification scenario. We keep this function unmodified, apart from adding a single call to an auxiliary function that does the conversion.
Such a confusion-matrix plotting function takes four arguments. The first is an integer specifying the time step (an epoch, iteration etc.). The second is a matrix whose rows and columns contain accuracy/confusion values in percent. The remaining two are the TensorBoard writer instance and a dictionary mapping integer labels to string class names. The actual plot can be done with matshow(), but the figure won't be rendered until we pass it to our auxiliary function plot_to_tensorboard().
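Here is a minimal sketch of the plotting part of such a function. It differs from the version described above in one respect: the hypothetical confusion_matrix_figure() returns the figure instead of taking the writer and step, so the TensorBoard call stays with the caller. The helper percent_confusion_matrix() is also an assumption. It builds the percentage matrix from integer labels and predictions and normalizes each row (true class) to 100%. The Agg backend is selected on the assumption that training runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: training runs headless, no display needed
import matplotlib.pyplot as plt
import numpy as np


def percent_confusion_matrix(y_true, y_pred, num_classes):
    """Hypothetical helper: confusion matrix with each row (true class) in percent."""
    counts = np.zeros((num_classes, num_classes), dtype=np.float64)
    np.add.at(counts, (np.asarray(y_true), np.asarray(y_pred)), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    return 100.0 * counts / np.maximum(row_sums, 1)   # empty rows stay all zero


def confusion_matrix_figure(matrix, class_names):
    """Draw a (K, K) percentage matrix with matshow() and return the figure.

    class_names maps the integer labels 0..K-1 to class-name strings.
    """
    num_classes = matrix.shape[0]
    labels = [class_names[i] for i in range(num_classes)]
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(matrix, cmap="Blues", vmin=0, vmax=100)
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")
    # Write each percentage into its cell (column = x, row = y)
    for (row, col), value in np.ndenumerate(matrix):
        ax.text(col, row, f"{value:.0f}", ha="center", va="center", fontsize=8)
    fig.tight_layout()
    return fig
```

In the per-epoch setting of this post, the matrix would be computed from the validation predictions. The resulting figure would then be handed to plot_to_tensorboard() for conversion.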
The plot_to_tensorboard() function receives the integer step defined above (used for naming), the TensorBoard writer instance and the figure to be rendered. Optionally, a string argument with a name could be added to distinguish different types of plots. To turn the matplotlib figure into a NumPy array, we draw the figure on a canvas. The canvas is then read out pixel by pixel: the returned byte string is converted into uint8 RGB values in the 0-255 range. We query the canvas width and height to reshape this flat array into a three-dimensional (height, width, 3) one. Before adding the array to TensorBoard, we may have to make two adjustments, depending on the versions of TensorFlow, TensorBoard, PyTorch etc. The first is normalizing it into the 0-1 range by dividing by 255. The second is swapping its axes. The common pattern seems to be that older versions expect the color channels as the last array dimension, whereas more recent versions of the API expect them as the first dimension.
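Below is a minimal sketch of plot_to_tensorboard() following these steps. It assumes the Agg backend. It does not read a byte string (for example via tostring_rgb(), which recent Matplotlib releases deprecate). It reads the canvas through buffer_rgba() instead, which already yields a (height, width, 4) uint8 array, so no manual reshape is needed, and the alpha channel is dropped. The channels_first and normalize flags are hypothetical switches for the version differences described above. writer is assumed to be a SummaryWriter from torch.utils.tensorboard or tensorboardX.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless, non-interactive rendering
import matplotlib.pyplot as plt
import numpy as np


def figure_to_array(fig, channels_first=True, normalize=True):
    """Render a matplotlib figure and return its pixels as an RGB array."""
    fig.canvas.draw()                               # rasterize onto the Agg canvas
    rgba = np.asarray(fig.canvas.buffer_rgba())     # (H, W, 4) uint8, values 0-255
    image = rgba[..., :3].copy()                    # drop alpha -> (H, W, 3)
    if normalize:
        image = image.astype(np.float32) / 255.0    # 0-255 ints -> 0-1 floats
    if channels_first:
        image = np.transpose(image, (2, 0, 1))      # HWC -> CHW
    return image


def plot_to_tensorboard(step, writer, fig, tag="matplotlib",
                        channels_first=True, normalize=True):
    """Add a matplotlib figure to TensorBoard as an image and close the figure."""
    image = figure_to_array(fig, channels_first, normalize)
    writer.add_image(tag, image, global_step=step)
    plt.close(fig)                                  # release the figure's memory
```

Suppose a confusion matrix is logged once per epoch with a call such as plot_to_tensorboard(epoch, writer, fig, tag="confusion_matrix"). All epochs then stay under one tag. On recent torch.utils.tensorboard and tensorboardX versions, passing dataformats="HWC" to add_image is an alternative to transposing manually.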
Finally, here is an example of an MLP trained for 10 epochs on the Fashion-MNIST dataset, with a confusion matrix logged for each epoch:
Concluding remarks: while writing this post, I ran into several challenges. Among them were the API changes in color-channel ordering mentioned above, changing expectations about 0-255 integer versus 0-1 float images, and matplotlib being difficult to work with when using the draw() function. If you receive a blank white image (all values 255) for any of your plots, try normalizing the array or calling imshow(). With specific matplotlib versions, the draw() function seems to have a bug that prevents the renderer from being invoked.
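Two small helpers can make such failures easier to spot. The sketch below is an assumption, not the fix used in this post. render_with_agg() draws through an explicitly created FigureCanvasAgg, so rendering does not depend on whichever backend pyplot happened to select. looks_blank() flags an image in which every value equals the white level (255 for uint8, or 1.0 after normalizing).

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def render_with_agg(fig):
    """Render a figure through an explicit Agg canvas; returns (H, W, 3) uint8."""
    canvas = FigureCanvasAgg(fig)        # attaches a fresh Agg canvas to the figure
    canvas.draw()                        # force the renderer to run
    return np.asarray(canvas.buffer_rgba())[..., :3].copy()


def looks_blank(image, white=255):
    """Return True if every pixel equals the white value, i.e. nothing was drawn."""
    return bool(np.all(image == white))


# Usage: a figure built without pyplot, checked before it is logged
fig = Figure(figsize=(2, 2), dpi=50)
fig.add_subplot().plot([0, 1], [1, 0])
image = render_with_agg(fig)
if looks_blank(image):
    raise RuntimeError("figure rendered as a blank white image")
```

The check is deliberately strict. A real plot almost always contains at least one non-white pixel (axes, ticks or data), so an all-white result points to a rendering problem rather than to the plot content.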