Using GPyTorch — A Researcher’s Experience
I began my research on efficient Gaussian process regression (GPR) over a year ago. In that time, I have gained a much better understanding of Gaussian processes (GPs), both theoretically and from an engineering perspective, largely by studying the open source frameworks developed for this line of research over the past few years. GPyTorch is the one that has helped me the most. In this post, I'll share my experience exploring and using GPyTorch, and how open source projects can benefit computational and statistical research.
What Is GPR?
First, let me briefly introduce GPR. A Gaussian process is a probabilistic supervised machine learning model that makes predictions based on prior knowledge (kernels) and can measure the uncertainty at test points. It is a collection of random variables such that the joint distribution of every finite subset is multivariate Gaussian. Formally, GPR predicts the distribution of the test points conditioned on the observed training points. Given training inputs $X$ and noisy observations $y$:

$$y = f(X) + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2 I), \qquad f \sim \mathcal{GP}\big(0,\, k(\cdot, \cdot)\big)$$

Then, for one or more test points $X_*$ and their corresponding function values $f_*$, we can express the joint distribution as:

$$\begin{bmatrix} y \\ f_* \end{bmatrix} \sim \mathcal{N}\left(0,\; \begin{bmatrix} K(X, X) + \sigma^2 I & K(X, X_*) \\ K(X_*, X) & K(X_*, X_*) \end{bmatrix}\right)$$

The conditional distribution of the test points is then calculated by conditioning this joint Gaussian on the observations $y$. Therefore, we have:

$$f_* \mid X, y, X_* \sim \mathcal{N}\big(\bar{f}_*,\, \mathrm{cov}(f_*)\big)$$

$$\bar{f}_* = K(X_*, X)\,\big[K(X, X) + \sigma^2 I\big]^{-1} y$$

$$\mathrm{cov}(f_*) = K(X_*, X_*) - K(X_*, X)\,\big[K(X, X) + \sigma^2 I\big]^{-1} K(X, X_*)$$
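These predictive formulas translate almost line for line into code. Below is a minimal NumPy sketch of exact GPR prediction with an RBF kernel; the kernel choice, lengthscale, and noise values are my own illustrative picks, not anything prescribed above.

```python
import numpy as np

def rbf(X1, X2, lengthscale=1.0):
    """RBF (squared exponential) kernel matrix K(X1, X2)."""
    d2 = ((X1[:, None, :] - X2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale**2)

def gp_posterior(X, y, Xs, noise=1e-2):
    """Posterior mean and covariance of f_* at test inputs Xs."""
    K = rbf(X, X) + noise * np.eye(len(X))   # K(X, X) + sigma^2 I
    Ks = rbf(X, Xs)                          # K(X, X*)
    Kss = rbf(Xs, Xs)                        # K(X*, X*)
    L = np.linalg.cholesky(K)                # stable alternative to a direct inverse
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    mean = Ks.T @ alpha                      # K(X*, X) [K(X, X) + sigma^2 I]^{-1} y
    v = np.linalg.solve(L, Ks)
    cov = Kss - v.T @ v                      # K(X*, X*) - K(X*, X) [...]^{-1} K(X, X*)
    return mean, cov

# Toy data: five samples of sin(2 pi x) on [0, 1], predicted at x = 0.5
X = np.linspace(0, 1, 5)[:, None]
y = np.sin(2 * np.pi * X[:, 0])
mean, cov = gp_posterior(X, y, np.array([[0.5]]))
```

Note the Cholesky factorization in place of an explicit matrix inverse; this is the standard numerically stable way to implement the conditional formulas.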
How I Came to Know GPyTorch
To help me get started with GPR, my advisor recommended one of the most important books for beginners in this area, “Gaussian Processes for Machine Learning.” It presents the fundamentals of GPR from the perspectives of both statistics and machine learning. The authors also developed the GPML toolbox, a MATLAB and GNU Octave implementation of GPR inference. This was the first open source tool I came across in this field, and it helped me dive deeper into the engineering side. However, its functionality was a bit limited for our purposes, so I kept searching for other implementations of GPR. That search led me to GPyTorch, the open source framework that would be a huge help later on.
How GPyTorch Helped Me
While I was searching for related works, one paper, “Fast Matrix Square Roots with Applications to Gaussian Processes and Bayesian Optimization,” stood out. Coincidentally, I found out after some digging that this was the same team behind the highly optimized framework GPyTorch. So I immediately jumped into studying how they built this framework.
OK, back up a little. Until I found that paper, I was getting a little frustrated with the research progress. Most of the papers I read were extremely theoretical and contained background knowledge that I was not familiar with, making it hard for me to find a breakthrough. However, thanks to the GPyTorch team open sourcing their code, I was offered a way in.
At first, I followed their documentation and tweaked the “getting started” examples a little here and there to observe the results. This helped me quickly ramp up on the important components involved in the regression. Then I dug a little deeper into directly setting some training parameters using Python-style setters, which I discovered by studying how they constructed their class objects. I studied the code for each necessary component and how they all fit together.
By comparing the algorithms described in the paper with their actual implementations, I found that a lot of things that were previously hard to understand became much clearer. For example, the technique of “stochastic trace estimation” is used to estimate one of the terms in the derivative.
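To make the idea concrete, here is a small NumPy sketch of the Hutchinson-style estimator that stochastic trace estimation is built on (my own illustration, not code from GPyTorch): the trace of a matrix is approximated by averaging quadratic forms $z^\top A z$ over random sign vectors $z$, so only matrix-vector products with $A$ are ever needed.

```python
import numpy as np

rng = np.random.default_rng(0)

def hutchinson_trace(matvec, n, num_probes=10):
    """Estimate tr(A) from matrix-vector products alone.

    For z with i.i.d. Rademacher (+/-1) entries, E[z^T A z] = tr(A),
    so averaging over a handful of probes gives an unbiased estimate.
    """
    total = 0.0
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)
        total += z @ matvec(z)
    return total / num_probes

# Sanity check on a dense symmetric matrix (in practice, matvec would be
# a fast implicit product, e.g. with a structured kernel matrix)
G = rng.standard_normal((50, 50))
A = G @ G.T
est = hutchinson_trace(lambda v: A @ v, 50, num_probes=200)
exact = np.trace(A)
```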
The paper itself presented the idea with some formal yet complicated math derivations based on “Lanczos quadrature,” but because the codebase is highly modularized, I knew the logic had to show up in a specific module mentioning the same keywords, and it did.
From there, I found out more about how it is actually used in the forward and backward passes. Some parameters not directly mentioned in the paper, such as the default of 10 random probe vectors for this procedure, were a nice addition to what I took away. Reviewing the implementation of this statistical method made the question of why and how it is used much more approachable for me, and it probably would be for many other newcomers to this area as well.
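Putting the pieces together, the random-probe trick can be combined with Lanczos quadrature to estimate quantities like $\log \det(A)$ from matrix-vector products alone. The sketch below is my own simplified reading of the technique, not GPyTorch's actual module: each probe drives a short Lanczos recurrence, and the eigendecomposition of the resulting tridiagonal matrix supplies the quadrature nodes and weights.

```python
import numpy as np

rng = np.random.default_rng(1)

def lanczos(matvec, z, m):
    """Run m Lanczos steps (with full reorthogonalization for stability).

    Returns the symmetric tridiagonal matrix T whose eigenvalues (Ritz
    values) approximate A's spectrum as seen from the starting vector z.
    """
    n = z.shape[0]
    Q = np.zeros((n, m))
    alphas, betas = np.zeros(m), np.zeros(m - 1)
    Q[:, 0] = z / np.linalg.norm(z)
    for j in range(m):
        w = matvec(Q[:, j])
        alphas[j] = Q[:, j] @ w
        w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)  # orthogonalize against all previous q's
        if j < m - 1:
            betas[j] = np.linalg.norm(w)
            Q[:, j + 1] = w / betas[j]
    return np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)

def slq_logdet(matvec, n, num_probes=10, m=25):
    """Stochastic Lanczos quadrature estimate of log det(A), A symmetric positive definite."""
    total = 0.0
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe, ||z||^2 = n
        theta, U = np.linalg.eigh(lanczos(matvec, z, m))
        tau2 = U[0, :] ** 2                  # quadrature weights: squared first components
        total += n * np.sum(tau2 * np.log(theta))
    return total / num_probes

# Well-conditioned SPD test matrix
B = rng.standard_normal((40, 40))
A = B @ B.T / 40 + np.eye(40)
est = slq_logdet(lambda v: A @ v, 40, num_probes=30)
```

The probe count and number of Lanczos steps trade accuracy for cost; the values here are illustrative, not tuned defaults.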
Besides helping me grasp the statistical formulations from the original paper, going through the implementation of GPyTorch also gave me more insight into the mechanisms of PyTorch, yet another famous open source framework, upon which GPyTorch is built. Before this, my understanding of PyTorch was limited to some basic multivariate calculus, where the chain rule is employed to calculate the gradient for each parameter, plus some useful library functions.
But to figure out how the components in GPyTorch fit together for GPR based on iterative methods, I had to study how torch.autograd works. By reading the GPyTorch code together with PyTorch's documentation, I realized that instead of one giant forward step, the framework breaks the computation of the loss function down into:
- Sequentially solving linear systems with the covariance matrix, from which it computes the two linear-solve terms
- Adding the noise term to the covariance
- And, at the same time, saving some tensors required for backpropagation to avoid any redundant computations later in the backward step
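To see why caching tensors in the forward pass pays off, here is a hand-rolled NumPy sketch of a forward/backward pair for the GP negative marginal log-likelihood. This illustrates the pattern, not GPyTorch's actual code: the forward pass saves the Cholesky factor and the linear-solve term, and the backward pass reuses them instead of recomputing anything.

```python
import numpy as np

class NegMarginalLogLik:
    """Mimics an autograd Function: forward saves tensors, backward reuses them."""

    def forward(self, K, y, noise):
        n = y.shape[0]
        Khat = K + noise * np.eye(n)          # add the noise term to the covariance
        L = np.linalg.cholesky(Khat)          # factor the covariance once
        alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # the linear-solve term
        self.saved = (L, alpha, n)            # cached for the backward step
        # nll = 0.5 y^T Khat^{-1} y + 0.5 log|Khat| + (n/2) log(2 pi)
        return 0.5 * y @ alpha + np.log(np.diag(L)).sum() + 0.5 * n * np.log(2 * np.pi)

    def backward(self):
        # d(nll)/d(Khat) = 0.5 (Khat^{-1} - alpha alpha^T), built from saved tensors
        L, alpha, n = self.saved
        Khat_inv = np.linalg.solve(L.T, np.linalg.solve(L, np.eye(n)))
        return 0.5 * (Khat_inv - np.outer(alpha, alpha))

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 4))
K = G @ G.T + 4 * np.eye(4)   # small SPD covariance for demonstration
y = rng.standard_normal(4)
f = NegMarginalLogLik()
nll = f.forward(K, y, noise=0.1)
grad = f.backward()
```

PyTorch's torch.autograd.Function follows the same contract via ctx.save_for_backward; the point in both cases is that the expensive factorization happens only once.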
All of this gave me a fresh perspective on how PyTorch works and why it is so powerful. Without this worked example, I probably would have needed a lot more time to figure this out, and even more to come up with my own implementation, an essential next step for my research. Now, with more knowledge of the entire pipeline, I can build my own version of GPR, or pretty much any PyTorch model, much faster, and I can make better design choices when working with large systems of many coordinating code components.
Lastly, aside from the technical proof of concept that GPyTorch provided for our research, it also served as an important confidence booster: it helped me realize that the research idea my advisor and I had was not so far from reality. From the outside, GPR can seem complicated and mind-boggling, but by studying the GPyTorch project, I found that it is still just small pieces grouped together, like any other challenge we meet in our daily lives.
Other Frameworks Like This
Besides GPyTorch, there are other open source frameworks for GPR that target different training purposes, including scikit-learn, GPy, GPflow and Pyro. I would definitely recommend checking out each of these projects; any of them could prove just as useful and insightful as our beloved friend GPyTorch.
Acknowledgement: First, I would like to thank my advisor, Dr. Vivek Sarin, for shaping my thoughts and giving me generous feedback on this blog post. I would also like to thank my mentor at the Linux Foundation, Shuah Khan, for reviewing and improving this article. Lastly, I'm grateful to my dear husband, Tong Qiu, for helping me proofread and polish the content.