
Fast AI - Notes on Stochastic Gradient Descent

In the last third of the lesson 2 video, Jeremy goes into detail about linear regression and stochastic gradient descent (SGD). I found this part of the lesson really helpful, so I took some in-depth notes! If you want to follow along interactively, all the code that I refer to in this post can be found in the Fast.ai repo.

Building a dataset

First, we need to create some fake data points that we can later apply SGD to.

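A tensor is a fancy name for an array in which all of the rows are the same length (more generally, the array is regular along every axis). The rank of a tensor is how many axes (dimensions) it has: a single number has rank 0, a vector has rank 1, and a matrix has rank 2.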
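As a quick illustration of rank (the tensors below are arbitrary examples), PyTorch exposes the rank as the ndim attribute, while shape gives the length along each axis. Trying to build a tensor from rows of different lengths fails, which is the "same length" rule in action.

```python
import torch

scalar = torch.tensor(3.0)               # rank 0: a single number
vector = torch.tensor([1.0, 2.0, 3.0])   # rank 1: one axis
matrix = torch.ones(4, 2)                # rank 2: rows and columns

for name, t in [("scalar", scalar), ("vector", vector), ("matrix", matrix)]:
    # ndim is the rank; shape is the length along each axis
    print(name, "rank:", t.ndim, "shape:", tuple(t.shape))

# Rows of different lengths cannot form a tensor
try:
    torch.tensor([[1.0, 2.0], [3.0]])
    ragged_failed = False
except ValueError:
    ragged_failed = True
print("ragged input rejected:", ragged_failed)
```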

import torch
n = 100  # number of data points (100 assumed here)
x = torch.ones(n,2)
x[:,0].uniform_(-1.,1)
x[:5]

Breaking this down:

x = torch.ones(n, 2): creates an n (rows) × 2 (columns) tensor filled with ones and assigns it to x.
x[:,0].uniform_(-1.,1): fills the first column with random numbers drawn uniformly between -1 and 1. In PyTorch, any method whose name ends in an underscore (like uniform_) works in place, i.e. it modifies the existing tensor instead of creating and returning its own copy of the data.
x[:,0] selects only the first column, so the transform is applied to that column alone. The second column stays full of ones.
x[:5] gives us the first 5 rows of the tensor.
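To see the difference between in-place and out-of-place operations, here is a small sketch using add and add_ on an arbitrary tensor, followed by the same uniform_ call on a column slice. Because x[:, 0] is a view into x, filling it in place changes x itself.

```python
import torch

t = torch.zeros(3)

u = t.add(1.0)   # out-of-place: returns a new tensor, t is unchanged
print(t, u)      # t is still all zeros, u is all ones

t.add_(1.0)      # in-place: modifies t itself (note the trailing underscore)
print(t)         # t is now all ones

# uniform_ fills the column view in place with values between -1 and 1;
# the second column of ones is left untouched.
x = torch.ones(5, 2)
x[:, 0].uniform_(-1.0, 1.0)
print(x)
```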

import matplotlib.pyplot as plt
plt.scatter(x[:,0], y);

(Figure: scatter plot of the generated data, x[:,0] against y)
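The scatter plot needs a target y, but the steps above only build x. One way to produce similar data, assuming y is a noisy linear function of x, is sketched below. The true coefficients 3. and 2., the noise from torch.rand, the seed, and n = 100 are all illustrative assumptions, not values taken from this post. Passing x[:, 0] and y to plt.scatter then gives a plot like the one above.

```python
import torch

torch.manual_seed(0)                 # make the random data reproducible
n = 100                              # assumed number of points

x = torch.ones(n, 2)                 # column 1 stays 1 (the intercept input)
x[:, 0].uniform_(-1.0, 1.0)          # column 0: random inputs between -1 and 1

a_true = torch.tensor([3.0, 2.0])    # hypothetical true slope and intercept
y = x @ a_true + torch.rand(n)       # linear trend plus uniform noise in [0, 1)

print(x[:5], y[:5])
```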

If we can find a line that describes the data well enough, we can use it to predict y for new values of x.

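The equation for a line: $y = a_1 x_1 + a_2 x_2$. Because $x_2$ is the column of ones, this is simply $y = a_1 x_1 + a_2$, so $a_1$ is the slope and $a_2$ is the intercept.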

We want to find values for the parameters (or weights) $a_1$ and $a_2$ such that the line they define minimizes the error between the line and the points that we plotted above.
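Putting the parameters into a vector a lets a single matrix multiplication evaluate the line for every point at once. The sketch below, using made-up inputs and example values a1 = 3, a2 = 2, checks that x @ a gives the same result as writing out a1*x1 + a2*x2 by hand.

```python
import torch

x = torch.ones(5, 2)
x[:, 0] = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])  # example inputs; column 1 stays 1
a = torch.tensor([3.0, 2.0])                          # example values: a1 = slope, a2 = intercept

via_matmul = x @ a                          # one matrix-vector product for all points
by_hand = a[0] * x[:, 0] + a[1] * x[:, 1]   # a1*x1 + a2*x2, with x2 = 1

print(via_matmul)
print(by_hand)
```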

Mean Squared Error

Mean Squared Error (MSE) is the mean of the squared differences between the predicted values (the values on the line) and the actual data points.
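Written as a formula, with $\hat{y}_i$ the predicted value for point $i$, $y_i$ its actual value, and $n$ the number of points:

```latex
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2
```

Squaring makes every error positive, so errors above and below the line cannot cancel out, and it penalizes large misses more heavily than small ones.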

def mse(y_hat, y): return ((y_hat-y)**2).mean()

Going from inside to outside: y_hat - y subtracts each element of y from the corresponding element of y_hat, **2 squares each of those differences, and .mean() averages them into a single number. Now let’s compute the MSE for a first guess at the line.
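A tiny worked example makes each step visible. The predictions and targets below are made up; the differences are [0, 1, -2], their squares are [0, 1, 4], and the mean is 5/3. PyTorch's built-in torch.nn.functional.mse_loss (which averages by default) computes the same quantity.

```python
import torch
import torch.nn.functional as F

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

y_hat = torch.tensor([1.0, 2.0, 3.0])  # made-up predictions
y = torch.tensor([1.0, 1.0, 5.0])      # made-up actual values

manual = mse(y_hat, y)        # (0 + 1 + 4) / 3
builtin = F.mse_loss(y_hat, y)
print(manual, builtin)
```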

a = torch.tensor([-1., 1.])  # initial guess: slope -1, intercept 1
y_hat = x@a
mse(y_hat, y)

First, we create a tensor of [-1., 1.] (both floating point) as our guess for $a_1$ and $a_2$. Then we use matrix multiplication (x@a) to compute a prediction for every data point. Finally, we get the mean squared error of those predictions, which comes out to 6.7554.
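To see that a lower MSE really means a better-fitting line, the sketch below compares the starting guess [-1., 1.] with a line closer to how the data was generated. The data, including the true coefficients 3 and 2 and the uniform noise, is an assumption for illustration, so the loss values will not match the 6.7554 above.

```python
import torch

torch.manual_seed(0)
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1.0, 1.0)
y = x @ torch.tensor([3.0, 2.0]) + torch.rand(n)  # assumed synthetic data

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

guess = torch.tensor([-1.0, 1.0])   # the starting line from this post
better = torch.tensor([3.0, 2.5])   # a line close to the assumed true one

loss_guess = mse(x @ guess, y)
loss_better = mse(x @ better, y)
print(loss_guess, loss_better)
```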

(Figure: the data points with the line given by a = [-1., 1.])

As we can see, this particular line has a large error, so it doesn’t fit the data very well. How do we minimize error? In other words, how can we get our line to better fit and predict the data? This concept is the core of gradient descent.

Gradient Descent

The two numbers we care about are the intercept of our line and its slope. With gradient descent, we experiment with making the intercept higher or lower, and the slope more positive or more negative, and see which changes give us a smaller loss.

A derivative tells us how the loss changes when we change a parameter slightly. Its sign tells us which way to move that parameter (and therefore the line) to reduce the loss.
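Here is a one-parameter sketch of that idea, using a hypothetical loss (w - 3)^2 that is smallest at w = 3. At w = 1, PyTorch's autograd gives a derivative of 2 * (1 - 3) = -4. Nudging w up and down by a small amount (a central difference) gives the same slope, which is the "try higher, try lower" experiment from the previous paragraph made precise.

```python
import torch

def loss_fn(w):
    # hypothetical one-parameter loss, smallest at w = 3
    return (w - 3.0) ** 2

w = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
loss = loss_fn(w)
loss.backward()            # autograd: d(loss)/dw = 2 * (w - 3)
print(w.grad)

# The same slope estimated by nudging w up and down
eps = 1e-3
with torch.no_grad():
    numeric = (loss_fn(w + eps) - loss_fn(w - eps)) / (2 * eps)
print(numeric)

# The derivative is negative, so increasing w lowers the loss:
# gradient descent steps in the opposite direction of the derivative.
```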

The update() function below shows how repeatedly calculating the mean squared error and subtracting the derivative (scaled by a learning rate, lr) gives us a better line.

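import torch

a = torch.nn.Parameter(a)  # make a trainable so PyTorch tracks its gradient

def update():
    y_hat = x@a
    loss = mse(y_hat, y)
    if t % 10 == 0: print(loss)  # t is the loop counter from the training loop
    loss.backward()
    with torch.no_grad():
        a.sub_(lr * a.grad)
        a.grad.zero_()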

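y_hat is our prediction.
loss is our mean squared error.
loss.backward() calculates the derivative of the loss with respect to a and stores it in an attribute called a.grad.
a.sub_(lr * a.grad) subtracts the derivative, scaled by the learning rate lr, from a in place.
torch.no_grad() stops PyTorch from tracking the update step itself, and a.grad.zero_() resets the gradient, because PyTorch adds new gradients to any existing ones on each call to backward().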
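For a linear model with MSE loss, the derivative that loss.backward() computes also has a closed form: (2/n) * X^T (Xa - y). The sketch below checks that autograd and this formula agree. The synthetic data (true coefficients 3 and 2, uniform noise) is an assumption for illustration only.

```python
import torch

torch.manual_seed(0)
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1.0, 1.0)
y = x @ torch.tensor([3.0, 2.0]) + torch.rand(n)  # assumed synthetic data

a = torch.nn.Parameter(torch.tensor([-1.0, 1.0]))  # requires_grad=True

loss = ((x @ a - y) ** 2).mean()
loss.backward()                     # fills a.grad

# Closed-form gradient of the MSE for a linear model: (2/n) * X^T (Xa - y)
with torch.no_grad():
    manual = (2.0 / n) * (x.T @ (x @ a - y))

print(a.grad, manual)
```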

Now, let’s run this 100 times and plot the line that we get:

lr = 1e-1  # learning rate: how big a step to take on each update
for t in range(100): update()
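Putting all the pieces together, here is a self-contained sketch of the same full-batch loop that records the loss instead of printing it. The synthetic data (true coefficients 3 and 2 plus uniform noise with mean 0.5) is an assumption, so the fitted intercept should land near 2.5 rather than 2.

```python
import torch

torch.manual_seed(0)
n = 100
x = torch.ones(n, 2)
x[:, 0].uniform_(-1.0, 1.0)
y = x @ torch.tensor([3.0, 2.0]) + torch.rand(n)  # assumed synthetic data

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

a = torch.nn.Parameter(torch.tensor([-1.0, 1.0]))  # starting guess
lr = 1e-1
losses = []

for t in range(100):
    loss = mse(x @ a, y)
    losses.append(loss.item())
    loss.backward()               # compute a.grad
    with torch.no_grad():
        a.sub_(lr * a.grad)       # step downhill
        a.grad.zero_()            # clear the gradient for the next step

print(losses[0], losses[-1], a.detach())
```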

(Figure: the data points with the line found after 100 updates)

This is clearly a lot better!

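A final note: in practice, it’s extremely time-consuming to calculate the loss on the whole dataset (which may contain millions of images) for every update. Instead, we grab a small random subset of the data, called a mini-batch, calculate the loss and gradient on it, and update the weights after each one. Using these random mini-batches is what makes the gradient descent stochastic.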
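The mini-batch idea can be sketched by changing the loop so that each update only looks at a random slice of the data. The synthetic data, batch_size = 32, the epoch count, and shuffling with torch.randperm once per epoch are illustrative assumptions, one common choice rather than the only one; in real projects, torch.utils.data.DataLoader usually handles the batching.

```python
import torch

torch.manual_seed(0)
n = 1000
x = torch.ones(n, 2)
x[:, 0].uniform_(-1.0, 1.0)
y = x @ torch.tensor([3.0, 2.0]) + torch.rand(n)  # assumed synthetic data

def mse(y_hat, y):
    return ((y_hat - y) ** 2).mean()

a = torch.nn.Parameter(torch.tensor([-1.0, 1.0]))
lr = 1e-1
batch_size = 32   # hypothetical mini-batch size
epochs = 10       # passes over the full dataset

for epoch in range(epochs):
    perm = torch.randperm(n)                   # shuffle once per epoch
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]   # indices for this mini-batch
        loss = mse(x[idx] @ a, y[idx])         # loss on the mini-batch only
        loss.backward()
        with torch.no_grad():
            a.sub_(lr * a.grad)
            a.grad.zero_()

with torch.no_grad():
    full_loss = mse(x @ a, y)                  # evaluate on all the data
print(a.detach(), full_loss)
```

Each update is cheaper and noisier than a full-batch one, but over many updates the parameters still head toward the same line.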