how to use a computer to check your derivative calculations

calculus
Author

Alex Hayes

Published

October 18, 2017

## Motivation

Suppose you have some loss function $$\mathcal{L}(\beta) : \mathbb{R}^n \to \mathbb{R}$$ you want to minimize with respect to some model parameters $$\beta$$. You understand how gradient descent works and you have a correct implementation of $$\mathcal{L}$$, but you aren’t sure whether you derived the gradient correctly or implemented it correctly in code.

## Solution

We can compare our implementation of the gradient of $$\mathcal{L}$$ to a finite difference approximation of the gradient. Recall that the directional derivative of $$\mathcal{L}$$ in a direction $$d \in \mathbb{R}^n$$ at a point $$x \in \mathbb{R}^n$$ is the gradient $$\nabla_\mathcal{L}(x)$$ projected onto $$d$$, and can be written as

$d^T \nabla_\mathcal{L}(x) = \lim_{\epsilon \to 0} \frac{\mathcal{L}(x + \epsilon \cdot d) - \mathcal{L}(x - \epsilon \cdot d)}{2 \epsilon}$

If we take $$\epsilon$$ to be fixed and small, we can use this formula to approximate the gradient in any direction. By approximating the gradient in each unit direction, we construct an approximation of the gradient of $$\mathcal{L}$$ at a particular point $$x$$.
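To see why the symmetric difference is a reasonable choice, here is a short Taylor-expansion sketch. It assumes $$\mathcal{L}$$ is three times continuously differentiable near $$x$$, and it writes $$H(x)$$ for the Hessian of $$\mathcal{L}$$; neither the assumption nor the symbol appears elsewhere in this post.

$\mathcal{L}(x \pm \epsilon d) = \mathcal{L}(x) \pm \epsilon \, d^T \nabla_\mathcal{L}(x) + \frac{\epsilon^2}{2} d^T H(x) \, d + O(\epsilon^3)$

Subtracting the two expansions cancels the constant and second-order terms, so

$\frac{\mathcal{L}(x + \epsilon \cdot d) - \mathcal{L}(x - \epsilon \cdot d)}{2 \epsilon} = d^T \nabla_\mathcal{L}(x) + O(\epsilon^2)$

For a fixed small $$\epsilon$$ the truncation error therefore shrinks like $$\epsilon^2$$, although floating point rounding error eventually grows as $$\epsilon$$ gets very small.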

## Example: Checking the gradient of linear regression

Suppose that we have $$n = 20$$ data points in $$\mathbb{R}^2$$ with responses $$y \in \mathbb{R}$$. Linear regression assumes the responses $$y$$ are related linearly to the data matrix $$X$$ via the equation

$y = X \beta + \epsilon$

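where $$\epsilon$$ now denotes a noise term rather than the step size used above. We want to find an estimate $$\hat \beta$$ that minimizes the sum of squared errors of the predicted values $$\hat y = X \hat \beta$$, scaled by $$\frac{1}{2n}$$ for convenience: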

$\mathcal{L}(\beta) = \frac{1}{2n} \sum_i (y_i - \hat y_i)^2 = \frac{1}{2n} \sum_i (y_i - x_i \beta)^2 = \frac{1}{2n} (y - X \beta)^T (y - X \beta)$

In the final step above we recognize that the sum of squared residuals can be written as a dot product. Next we’d like to take the gradient of this dot product. There’s a beautiful explanation of how to take the gradient of a quadratic form here. The gradient (in matrix notation) is

$\nabla_\mathcal{L}(\beta) = -\frac{1}{n} (y - X \beta)^T X$
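One way to arrive at this formula is via differentials. The sketch below introduces the residual vector $$r(\beta) = y - X \beta$$ as shorthand; the symbol $$r$$ is an assumption of this sketch and is not used elsewhere in the post.

$\mathcal{L}(\beta) = \frac{1}{2n} r^T r, \qquad dr = -X \, d\beta$

$d\mathcal{L} = \frac{1}{2n} \left( dr^T r + r^T dr \right) = \frac{1}{n} r^T dr = -\frac{1}{n} (y - X \beta)^T X \, d\beta$

Reading off the factor that multiplies $$d\beta$$ gives the gradient stated above, written as a row vector.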

We can now implement an analytical version of $$\nabla_\mathcal{L}(\beta)$$ and compare it to a finite difference approximation. First we simulate and visualize some data:

library(tidyverse)

n <- 20
X <- matrix(rnorm(n * 2), ncol = 2)
y <- rnorm(n)

ggplot(NULL, aes(x = X[, 1], y = y)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE) +
  labs(title = "One dimension of the simulated data", x = expression(X[1])) +
  theme_classic()

Next we implement our loss and gradient functions. We assume the loss function is implemented correctly but want to check the analytical_grad implementation.

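loss <- function(beta) {
  resid <- y - X %*% beta
  sum(resid^2) / (2 * n)
}

analytical_grad <- function(beta) {
  grad <- -t(y - X %*% beta) %*% X / n
  as.vector(grad)  # drop the 1 x 2 matrix shape so we get a plain vector
}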

To perform this check, we first need to approximate the derivative of a function in a direction $$d$$:

#' @param f function that takes a single vector argument x
#' @param x point at which to evaluate derivative of f (vector)
#' @param d direction in which to take derivative of f (vector)
#' @param eps epsilon to use in the gradient approximation
numerical_directional_grad <- function(f, x, d, eps = 1e-8) {
  (f(x + eps * d) - f(x - eps * d)) / (2 * eps)
}
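As a quick sanity check of the formula, here is a minimal sketch that applies it to a function whose gradient is known in closed form. The test function `f`, the point `x`, and the direction `d` are hypothetical choices made for illustration; the approximation function is repeated so the snippet runs on its own.

```r
# Central-difference approximation of a directional derivative
# (same formula as above, repeated so this snippet is self-contained)
numerical_directional_grad <- function(f, x, d, eps = 1e-8) {
  (f(x + eps * d) - f(x - eps * d)) / (2 * eps)
}

# Hypothetical test function with a known gradient: f(x) = sum(x^2), gradient 2 * x
f <- function(x) sum(x^2)

x <- c(1, 2)
d <- c(1, 1) / sqrt(2)  # unit vector along the diagonal

approx_dd <- numerical_directional_grad(f, x, d)
exact_dd <- sum(d * 2 * x)  # d^T times the exact gradient

abs(approx_dd - exact_dd)  # should be tiny compared with exact_dd
```

Because `f` is quadratic, the central difference has no truncation error here, so any remaining discrepancy comes from floating point rounding.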

And then to approximate the entire gradient, we need to combine directional derivatives in each of the unit directions:

zeros_like <- function(x) {
  rep(0, length(x))
}

numerical_grad <- function(f, x, eps = 1e-8) {
  grad <- zeros_like(x)
  for (dim in seq_along(x)) {
    # unit vector along coordinate `dim`
    unit <- zeros_like(x)
    unit[dim] <- 1
    grad[dim] <- numerical_directional_grad(f, x, unit, eps)
  }
  grad
}

relative_error <- function(want, got) {
  (want - got) / want  # assumes want is not zero
}

Now we can check the relative error between our analytical implementation of the gradient and the numerical approximation.

b <- c(2, 3)  # point in parameter space to check gradient at

num_grad <- numerical_grad(loss, b)
ana_grad <- analytical_grad(b)

num_grad
[1] 2.112107 3.286946
ana_grad
[1] 2.112107 3.286946
relative_error(num_grad, ana_grad)
[1] -2.810374e-08  1.553777e-08
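It can also help to see what a failed check looks like. The sketch below is self-contained and deliberately introduces a bug: the hypothetical `buggy_grad` drops the $$\frac{1}{n}$$ factor. The seed, the `sapply`-based gradient helper, and `eps = 1e-6` are assumptions made for this illustration rather than part of the code above.

```r
set.seed(27)  # arbitrary seed so the sketch is reproducible

n <- 20
X <- matrix(rnorm(n * 2), ncol = 2)
y <- rnorm(n)

loss <- function(beta) {
  resid <- y - X %*% beta
  sum(resid^2) / (2 * n)
}

# Hypothetical buggy gradient: the 1 / n factor has been dropped
buggy_grad <- function(beta) {
  as.vector(-t(y - X %*% beta) %*% X)
}

# Central differences along each coordinate direction
numerical_grad <- function(f, x, eps = 1e-6) {
  sapply(seq_along(x), function(i) {
    unit <- rep(0, length(x))
    unit[i] <- 1
    (f(x + eps * unit) - f(x - eps * unit)) / (2 * eps)
  })
}

relative_error <- function(want, got) {
  (want - got) / want  # assumes want is not zero
}

b <- c(2, 3)
rel_err <- relative_error(numerical_grad(loss, b), buggy_grad(b))
rel_err  # far from zero: the buggy gradient is n times too large
```

A relative error many orders of magnitude larger than the roughly `1e-8` seen for the correct gradient is the signal that either the derivation or the implementation is wrong.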
This post is based on Tim Vieira’s fantastic post on how to use numerical gradient checks in practice, but with R code. See also the numDeriv package.