Most machine learning models are never visualized. Visualizing a model and its parameters often leads to immediate insights or bugfixes, but getting a good visual requires a lot of one-off work.

How do we get useful visualizations without requiring too much human overhead? I think code is underrated as a visualization. In this blog post I show a family of pragmatic visualizations that are each created by simply printing code expressions and rendering their parameters inline. Often, the printed code is not runnable, but is instead a visually optimized version of the model’s code.

I start with two real-life examples where these visualizations provided valuable insights, using these examples to demonstrate the visualizations. Then I turn to my main focus: eliminating the one-off work. I share rows2prose, a Python library that generates these visuals from a dataframe of model parameters and any styled text (not just code). This leaves only the problems of tracing model parameters to a dataframe and printing visually optimized model code. I show how scientific computing frameworks can help with this, in this case using a Lisp-macro-like approach which I demonstrate with Vexpr.

Example 1: The Model That Was Good, Then Bad, Then Good

Last year I ran a machine learning experiment and noticed that the baseline model from BoTorch had a weird curve.

Schematic

Chart 1: Scaling curve for a Gaussian Process performing hold-one-out cross-validation. Given a dataset, the model is trained on all but one point in the dataset, then it predicts the output for the held-out point. I repeat the experiment on 50 different random subsets of a larger dataset and plot the 10th percentile, 90th percentile, and geometric mean.
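To make the evaluation protocol concrete, here is a minimal sketch of hold-one-out cross-validation repeated over random subsets. The `fit` and `predict` functions are hypothetical stand-ins for training the Gaussian Process and querying it; they are not code from this post.

```python
import numpy as np

def hold_one_out_errors(X, y, fit, predict):
    """Train on all but one point and score the prediction of the held-out point."""
    errors = []
    for i in range(len(X)):
        mask = np.arange(len(X)) != i
        model = fit(X[mask], y[mask])        # hypothetical training routine
        y_hat = predict(model, X[i:i + 1])   # hypothetical prediction routine
        errors.append(abs(y_hat - y[i]))
    return np.array(errors)

def scaling_curve(X, y, sizes, fit, predict, n_repeats=50, seed=0):
    """Repeat the experiment on random subsets and summarize, as in Chart 1."""
    rng = np.random.default_rng(seed)
    curve = {}
    for n in sizes:
        scores = []
        for _ in range(n_repeats):
            idx = rng.choice(len(X), size=n, replace=False)
            scores.append(hold_one_out_errors(X[idx], y[idx], fit, predict).mean())
        scores = np.array(scores)
        # 10th percentile, geometric mean (assumes positive scores), 90th percentile.
        curve[n] = (np.percentile(scores, 10),
                    np.exp(np.log(scores).mean()),
                    np.percentile(scores, 90))
    return curve
```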

The model is a Gaussian Process, but that detail really doesn’t matter for this blog post. This is a strange scaling curve for any machine learning model. We would expect a model to continually get better at prediction as it receives larger datasets, with diminishing returns only at the end. Instead, this one has a big lull in the middle, giving it an almost staircase shape. What could this mean? (Consider pausing here and guessing the reason. I had a guess, and my guess turned out to be wrong.)

An obvious next step is to look at the model itself, not just the model’s results. How do the final trained model parameters change as we scale up the size of the dataset? I visualize each individual parameter in diagrams like this:

Schematic

The vertical axis separates different experiments—in this case, different dataset sizes—from top to bottom. The horizontal axis shows parameter values across different repetitions of that experiment. To visualize the entire model, I print a simplified version of the model's code, rendering each parameter in place using one of these diagrams.

Here’s what the model looked like for each point in Chart 1. Feel free to zoom in, but don’t get bogged down in the details; just observe that each parameter changes in interesting ways about halfway down, coinciding with the interesting changes in the experiment's results.

Model Visualization

Predict using a Gaussian Process with the following covariance and mean.

Covariance kernel: Use distance between points as follows:

[Each ▯ marks a parameter that is rendered as an inline plot in the interactive visualization.]

    ▯ * sum([
        # Kernel: Factorized scalar vs choice parameters
        ▯ * sum([
            # Scalar parameters
            ▯ * matern_25(
                norm_l2([compare('log_epochs') / ▯,
                         compare('log_batch_size') / ▯,
                         compare('log_conv1_weight_decay') / ▯,
                         compare('log_conv2_weight_decay') / ▯,
                         compare('log_conv3_weight_decay') / ▯,
                         compare('log_dense1_weight_decay') / ▯,
                         compare('log_dense2_weight_decay') / ▯,
                         compare('log_1cycle_initial_lr_pct') / ▯,
                         compare('log_1cycle_final_lr_pct') / ▯,
                         compare('log_1cycle_pct_warmup') / ▯,
                         compare('log_1cycle_max_lr') / ▯,
                         compare('log_1cycle_momentum_max_damping_factor') / ▯,
                         compare('log_1cycle_momentum_min_damping_factor_pct') / ▯,
                         compare('log_1cycle_beta1_max_damping_factor') / ▯,
                         compare('log_1cycle_beta1_min_damping_factor_pct') / ▯,
                         compare('log_beta2_damping_factor') / ▯,
                         compare('log_conv1_channels') / ▯,
                         compare('log_conv2_channels') / ▯,
                         compare('log_conv3_channels') / ▯,
                         compare('log_dense1_units') / ▯])),
            # Choice parameters
            ▯ * exp(
                -norm_l1([compare('choice_nhot0') / ▯,
                          compare('choice_nhot1') / ▯,
                          compare('choice_nhot2') / ▯,
                          compare('choice_nhot3') / ▯]))]),
        # Kernel: Joint scalar and choice parameters
        ▯ * prod([
            matern_25(
                norm_l2([compare('log_epochs') / ▯,
                         compare('log_batch_size') / ▯,
                         compare('log_conv1_weight_decay') / ▯,
                         compare('log_conv2_weight_decay') / ▯,
                         compare('log_conv3_weight_decay') / ▯,
                         compare('log_dense1_weight_decay') / ▯,
                         compare('log_dense2_weight_decay') / ▯,
                         compare('log_1cycle_initial_lr_pct') / ▯,
                         compare('log_1cycle_final_lr_pct') / ▯,
                         compare('log_1cycle_pct_warmup') / ▯,
                         compare('log_1cycle_max_lr') / ▯,
                         compare('log_1cycle_momentum_max_damping_factor') / ▯,
                         compare('log_1cycle_momentum_min_damping_factor_pct') / ▯,
                         compare('log_1cycle_beta1_max_damping_factor') / ▯,
                         compare('log_1cycle_beta1_min_damping_factor_pct') / ▯,
                         compare('log_beta2_damping_factor') / ▯,
                         compare('log_conv1_channels') / ▯,
                         compare('log_conv2_channels') / ▯,
                         compare('log_conv3_channels') / ▯,
                         compare('log_dense1_units') / ▯])),
            exp(
                -norm_l1([compare('choice_nhot0') / ▯,
                          compare('choice_nhot1') / ▯,
                          compare('choice_nhot2') / ▯,
                          compare('choice_nhot3') / ▯]))])])

When comparing a point to itself, add noise value: (log scale)


Mean: constant


Visualization 1: All of the parameters of a Gaussian Process model rendered in context. The covariance kernel at the top contains a few parameter types: a multiplicative positive scale (top-left) and four multiplicative mixing weights (the other four parameters along the left side), while the rest are "lengthscale" parameters used to scale distances. Additionally, there are noise and mean parameters (bottom). The noise parameter always ends up very low, both because the default prior pushes it toward zero and because the points in the dataset are spaced apart by pseudorandom Sobol generation, so the model is never forced to account for variance at a single location.

The sudden jump in accuracy in Chart 1 corresponds to a number of jumps in the parameters. Perhaps the most dramatic change is in the top half of the covariance kernel, where a number of parameters stay fixed at about 0.3 until they suddenly jump to large values.

This strongly suggests a theory: the model’s priors are too strong. With small datasets, the gradients from better predicting the data are not strong enough to overcome the gradients from the priors. Once the dataset size crosses some threshold, the parameters are able to break free.

Let’s loosen the prior on the parameters and see if the issue is solved.

Cross-validation results with looser priors

Chart 2: Hold-one-out cross-validation results, comparing the BoTorch baseline to a new configuration with a weaker lengthscale prior.

The model improved significantly. How do its parameters look?

Model Visualization

Predict using a Gaussian Process with the following covariance and mean.

Covariance kernel: Use distance between points as follows:

[Same covariance expression as in Visualization 1, with each parameter's values rendered inline.]

When comparing a point to itself, add noise value: (log scale)


Mean: constant


Visualization 2: Model parameters now that the priors on the "lengthscales" have been loosened. Specifically, I changed the prior on the lengthscales—the parameters along the right side—from its default value Gamma(3.0, 6.0) to Gamma(1.125, 0.375), a distribution with the same mode but higher variance, so during training there is a weaker gradient pushing each parameter toward the mode. (The model tunes its mixing weights to favor the top half of the kernel, so the parameters in the top half are the ones that increase, while those in the bottom are still pulled toward 0.3.)
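For reference, here is roughly what such a prior swap looks like in GPyTorch (the library beneath BoTorch). This is a minimal sketch on a stock SingleTaskGP using the Gamma parameters from the caption, not the factorized kernel from this post.

```python
import torch
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.priors import GammaPrior
from botorch.models import SingleTaskGP

d = 20  # number of input features (matching the scalar parameters above)

# Replace the default lengthscale prior, Gamma(3.0, 6.0), with the looser
# Gamma(1.125, 0.375): same mode, much higher variance.
covar_module = ScaleKernel(
    MaternKernel(
        nu=2.5,
        ard_num_dims=d,
        lengthscale_prior=GammaPrior(1.125, 0.375),  # default was GammaPrior(3.0, 6.0)
    )
)

train_X = torch.rand(60, d, dtype=torch.double)
train_Y = torch.rand(60, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y, covar_module=covar_module)
```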

The discrete change in parameters is now mostly gone. I ran further experiments with even weaker priors to make the discrete change fade further; this worked and further improved results for small datasets, but it began to harm results for large datasets.

So I’ve learned that with this model, I’ll get the best results if I use weaker priors, at least for small-to-medium datasets. BoTorch / Ax’s built-in priors did not serve me well. That doesn’t mean the default priors are wrong; rather, it suggests that users need to be willing to look closely at their machine learning models if they want good results. If using a machine learning model always gave users visualizations like these, I think many more people would use them well.

Example 2: The Pitfalls of Parallel Cross-Validation

I think people ought to always see their model, including while it trains. I built this experience for myself, and it quickly provided an insight. Here is a visualization I watched in real time as the model above trained. This is a batch cross-validation task with 60 datapoints, so I am training 60 models in parallel and visualizing their parameters (hence the multiple dots per parameter). Click the button below to watch the models train.

Model Visualization

Predict using a Gaussian Process with the following covariance and mean.

Covariance kernel: Use distance between points as follows:

[Same covariance expression as in Visualization 1, with each parameter showing one dot per model as the 60 models train.]

When comparing a point to itself, add noise value: (log scale)


Mean: constant


Visualization 3: Cross-validation with dataset size 60. I used the looser lengthscale prior from above, and I used a different noise prior that doesn't endlessly push noise toward 0.

As I watched this in my notebook, I got a sudden impression: many of these points are converging much faster than others. In fact, I think there are hundreds of steps where 59 of the 60 models have converged, and we’re just waiting for the last one. This is particularly evident if you watch the parameters in the lower half of the kernel, where one single faint blue dot slowly approaches the cluster of overlapping dots. This is concerning because all 60 models are being evaluated on every step, even though the last few hundred steps are unnecessary for most of the models.

I tested this theory by comparing the two training approaches: training the 60 models in batch versus training each one separately.

Schematic
Chart 3: Training trajectory for 60 models trained in batch, compared to that of training each of them separately.

The problem is worse than I thought. Not only do some optimizations finish well before others, but every optimization takes many more steps when trained in batch. When I count model evaluations, 18,817 total evaluations happen when training models one at a time, while 92,820 happen when training in parallel, so we are doing approximately 5 times too many operations. I describe this in more depth in this post’s extended material. I am inclined to implement batch training differently, maybe by implementing a single training run and then using something like JAX’s vmap in conjunction with JAX’s while_loop.
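To sketch that idea: this is not the training code from this post, just a toy quadratic loss standing in for the GP objective, but it shows a single-model training loop written with lax.while_loop carrying its own convergence test, which jax.vmap then turns into a batched trainer.

```python
import jax
import jax.numpy as jnp

def train_one(init_params, data, lr=0.01, tol=1e-6, max_steps=10_000):
    """Gradient descent on one model, stopping when the loss stops improving."""
    def loss(p):
        # Toy objective standing in for the real GP training loss.
        return jnp.sum((p - data) ** 2)

    def cond(state):
        p, prev_loss, step = state
        return (step < max_steps) & (jnp.abs(loss(p) - prev_loss) > tol)

    def body(state):
        p, _, step = state
        new_p = p - lr * jax.grad(loss)(p)
        return new_p, loss(p), step + 1

    init = (init_params, jnp.asarray(jnp.inf, dtype=init_params.dtype), jnp.asarray(0))
    params, _, steps = jax.lax.while_loop(cond, body, init)
    return params, steps

# vmap gives each model its own convergence test. (The batched loop still steps
# until every lane has converged; already-converged lanes are held fixed.)
train_batch = jax.vmap(train_one)

init = jax.random.normal(jax.random.PRNGKey(0), (60, 5))   # 60 models, 5 params each
data = jax.random.normal(jax.random.PRNGKey(1), (60, 5))
params, steps = train_batch(init, data)                    # `steps` differs per model
```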

I wouldn’t have noticed this problem if I hadn’t been able to see my model during training. Of course, other standard visualizations could have revealed this issue; Chart 3 is fairly standard, and it would have also done the job. But I didn’t have Chart 3, I didn’t know I should be building it, and building it is actually difficult and inefficient with BoTorch. I think a visualized expression is a useful jumping-off point, and maybe we should try to always have it available to us.

A Recipe for Pragmatic Model Visualization

How do we give ourselves a playful environment where our models are always visualized by default?

schematic

I solve part of the problem with a new Python library called rows2prose.

You give rows2prose two things:

  • a dataframe containing scalars that should be visualized
  • a visually-useful string that describes the model (for example, the model’s code), including placeholder text for visualized values

schematic

You can use rows2prose in Jupyter notebooks, and it can output HTML visualization files from arbitrary Python scripts. It generated every visual in this blog post, and you can use it today.
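To make those two inputs concrete, here is a rough sketch. The dataframe and the template string are ordinary Python; the placeholder syntax and the commented-out display call are hypothetical stand-ins rather than rows2prose's actual API—see its README for the real entry points.

```python
import pandas as pd

# One row per (experiment, repetition); each column is a scalar to visualize.
df = pd.DataFrame({
    "dataset_size":           [10, 10, 10, 80, 80, 80],
    "lengthscale_log_epochs": [0.31, 0.29, 0.33, 1.8, 2.1, 1.7],
    "noise":                  [1e-4, 2e-4, 1e-4, 9e-5, 1e-4, 1.1e-4],
})

# A visually useful description of the model, with a placeholder wherever a
# column's values should be rendered inline. (Placeholder syntax is illustrative.)
text = """
matern_25(norm_l2([compare('log_epochs') / {lengthscale_log_epochs}]))
When comparing a point to itself, add noise value: {noise}
"""

# Hypothetical call, for illustration only:
# rows2prose.notebook.display(text, df, group_by="dataset_size")
```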

The remaining gap has multiple solutions

How do you take your model and get a visually-useful summary of it? There are, of course, many ways to do this, including crazy new approaches like asking an LLM to generate one for you.

But here’s how I did it.

I created Vexpr, a Python library that takes inspiration from Lisp. In Vexpr, you build up expression data structures (“Vexprs”) similar to Lisp S-expressions. The expressions you see in these visualizations are simply printed Vexprs.

(JAX fans may be familiar with “Jaxprs”; Vexprs are similar, but they are more user-facing. A user of Vexpr is intentionally building up an elegant Vexpr, while a user of JAX doesn’t really care what their Jaxpr looks like. Vexprs and Jaxprs solve different problems, and I plan to use them together.)

Just like Lisp, Vexpr lets you use macros to modify these expressions. For my previous post, I used macros to vectorize expressions; for this post, I used them to visually optimize the expressions. For example, the actual Vexpr program for my model uses an elementwise division between two arrays to divide N different distances by N different lengthscales, but I wanted to visualize this as N separate divisions, with each parameter rendered next to its corresponding “compare” feature. So I changed the code with a macro, deliberately unvectorizing the division operations to make them prettier.

Here’s another example: the printed Vexpr was more verbose than I wanted, so I used macros to convert it into pseudocode. Each visualization above contains the function call compare('log_epochs'), which doesn’t actually exist. In the runnable expression, this compare term is a larger subexpression that extracts the "log_epochs" feature from two vectors (x1 and x2) and computes the distance between them. I wanted a succinct visual, so I “visually optimized” the expression, removing those details, and I never implemented compare. Code can be much more succinct when it doesn’t actually need to run.
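Vexpr's real API isn't shown here; as a toy stand-in for the idea, here is a sketch in which expressions are plain tuples and a macro-like function rewrites a vectorized division into per-feature compare(...) terms for display.

```python
# Toy stand-in for Vexpr (not its real API): an expression is ('op', args...).
def unvectorize_division_macro(expr, feature_names):
    """Rewrite ('div', ('distances', ...), ('lengthscales',)) into per-feature
    pseudocode terms like ('div', ('compare', 'log_epochs'), ('lengthscale', 0))."""
    op, *args = expr
    if op == "div" and isinstance(args[0], tuple) and args[0][0] == "distances":
        return tuple(
            ("div", ("compare", name), ("lengthscale", i))
            for i, name in enumerate(feature_names)
        )
    # Otherwise, recurse into sub-expressions, leaving leaves untouched.
    return (op, *[
        unvectorize_division_macro(a, feature_names) if isinstance(a, tuple) else a
        for a in args
    ])

expr = ("norm_l2", ("div", ("distances", "x1", "x2"), ("lengthscales",)))
print(unvectorize_division_macro(expr, ["log_epochs", "log_batch_size"]))
```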

In addition to macros, another useful idea that filled this gap was partial evaluation. I take a Vexpr, plug in its parameters, then evaluate all parts of the expression that are ready to be evaluated. This computes the “unvectorize” from the previous paragraph, taking arrays of divisors and indexing into them. This is also useful when machine learning models put constraints on parameters; often they implement constraints by storing “raw” versions of the parameters and passing them through an exp or sigmoid to move them into a constrained interval. Partial evaluation moves the values into the constrained interval so that they are ready to be visualized. One thing that made me laugh, putting the ideas of these two paragraphs together: even after converting my runnable Vexpr into pseudocode, I still ran partial evaluation on it, evaluating all expressions that could be evaluated. It feels funny when you tell your computer to “evaluate the parts of this code that are not pseudocode”.
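Again using a toy tuple representation rather than Vexpr's actual API, partial evaluation walks the expression tree, substitutes known parameters, and collapses any subexpression whose inputs are all known—for example, passing a "raw" noise parameter through exp so the constrained value is what gets visualized.

```python
import math

# Toy expression format (not Vexpr's API): ('op', args...) with numbers as leaves.
OPS = {"add": lambda a, b: a + b, "exp": math.exp, "div": lambda a, b: a / b}

def partial_eval(expr, env):
    """Substitute known parameters and evaluate every fully-known subexpression."""
    if isinstance(expr, str):                 # a named parameter
        return env.get(expr, expr)            # unknown names stay symbolic
    if not isinstance(expr, tuple):           # already a literal
        return expr
    op, *args = expr
    args = [partial_eval(a, env) for a in args]
    if op in OPS and all(isinstance(a, (int, float)) for a in args):
        return OPS[op](*args)
    return (op, *args)

# 'raw_noise' is stored unconstrained; exp() moves it into a positive interval.
expr = ("add", ("compare_term",), ("exp", "raw_noise"))
print(partial_eval(expr, {"raw_noise": -9.2}))
# -> ('add', ('compare_term',), 0.000101...)  the noise is now ready to visualize
```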

Putting all of this together, here is the final architecture underlying these visualizations.

schematic

The “Macros + Partial Evaluation” functionality I used is all present in Vexpr, but the ideas are still baking. In a future post I might try to convince you to use them.

But what about Deep Learning?

This blog post featured human-comprehensible machine learning models like Gaussian Processes. In these models each parameter has a very clear meaning. Is this blog post applicable to Deep Learning?

First, let me argue that comprehensible models are important, and that people playing with Deep Learning ought to be among their most enthusiastic users. Suppose I grant you the extreme position that a deep learning model is a black box that isn’t worth looking into. In that extreme, you have a great use case for comprehensible models: exploring the space of Deep Learning architectures and training regimes. You get to take the giant space of models and regimes, design your own hand-engineered features of that space like “learning rate” or “number of attention heads”, generate your own datasets of experiment results, and conduct symphonies of computers to explore the space. Deep Learning system design is what got me into these models in the first place.

Regardless, I think it’s possible to build useful, pragmatic visuals for Deep Networks. My main design goals would be: (1) enable the user to detect when something in the network is broken or not being used, and (2) put an expression in front of the user to encourage playful tweaking of the architecture.

Conclusion

These visualizations were immediately useful, and they are pragmatic because they are not specific to any model type. If you can extract a text description of your model, and if you can trace a set of useful-to-visualize scalars, then you can visualize your model.

I think that somehow we should give all users of machine learning models access to visuals like these. Using some combination of our shared frameworks, our example code, and our crazy LLM tools, we should take on the responsibility of not only performing our desired computation, but also rendering a useful expression of it.

(This post has an appendix. This project is supported by a GCP cloud compute grant from ML Collective, which has been super helpful. Thanks, also, to Rosanne Liu for useful feedback on drafts of this post.)