You must know least squares

Let’s step away from the neural network hype and go back to the basics a bit, to the times when it was actually clear why things work.

Marin Vlastelica Pogančić

Machine learning is a really hot topic: everybody wants to do it, use it somehow in a product, or cut business operating costs with it. Machine learning seems to be the answer to the prayers of the business world. Yet it can be pretty astonishing how people jump into it. Yes, I know, neural networks are the “thing”, but jumping straight into neural networks is, in my opinion, misguided.

Not all of machine learning revolves around neural networks

Yes, they can do cool stuff, with a lot, really a lot, of compute. But there was a time when machine learning solutions were elegant, efficient and quick. In this article, we’ll get down to the basics and talk about least squares regression. Why should we talk about that? Because, believe it or not, this simple method is still used in statistical analysis, wherever there is a lot of data to analyze. In Kaggle competitions it is constantly used to examine trends in the data, variance and so on. Least squares is a thing, and not enough people know the math behind it, although it is dead easy. So, let’s get cracking.

If you have ever taken a simple class in linear algebra, you know what a matrix is and you have most probably seen this kind of equation:
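
    Ax = b

where A is a matrix, x is the vector of unknowns and b is the right-hand side.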

If we want to get the solution to this equation, we need the inverse of matrix A (we are just going to consider the case where A is invertible, and the equation is not overdetermined). So the solution looks as simple as this:
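
    x = A^{-1} b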

So now we know how to get a solution to this simple equation, or rather to this set of equations, since we can look at each row of the matrix A (together with the corresponding entry of b) as a linear equation by itself. The matrix formulation is just a compact way to describe a set of linear equations.

But this is not what we actually want to do in the end. We want to learn something, we want to do machine learning! So where is the learning in this… It is just a bunch of linear equations, so boring. Well, what do we do in machine learning exactly? We have a certain function that we want to fit; call it f, a function that gives us a value y for some data point x. Written down simply:
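
    y = f(x)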

As simple as that. So we want to learn f, and in our case we assume that f is a linear function (we are not going into problems of generalization and so on, just keeping it as simple as possible). I admit the notation is a bit overloaded at this point: previously we used x as a point that has to satisfy a set of equations, whereas now we want to find the coefficients w of the linear function that describes the data, with the data points stacked as the rows of a matrix X. We can write this in our matrix form:
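
    y = Xw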

This expression comes naturally when you think about it, since each row of X is a data point and we basically take the dot product of that data point with w. Now we need to define our learning algorithm, which here takes only a single step, since there is a closed-form solution. For that, we are going to need a bit of calculus. First, we need to look at the following never-aging objective function in machine learning, the mean squared error:
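
    MSE(w) = ||Xw - y||^2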

The || brackets are there because, when we talk about vectors, we rather talk about their norms; in this case the norm is shorthand for the sum of the squared errors on the individual data points. If you want to know how important norms and dot products are in ML, take a look at this article. The above error term has some nice properties, one of which is that it is quadratic in w. In optimization, we like quadratic functions: they are relatively easy to optimize because they are convex, so they have a single minimum, which is global. And how do we find the solution, i.e. how do we minimize this function? Standard practice in calculus: we take the derivative and look for the point where it equals 0.

Which is basically saying that the gradient is equal to 0:
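
    ∇MSE(w) = 0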

And by taking the chain rule, the gradient of the MSE is equal to:
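
    ∇MSE(w) = 2 X^T (Xw - y)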

Now, with a bit of matrix multiplication and setting the gradient to 0, we can write out the solution in good old closed form:
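
    2 X^T (Xw - y) = 0
    X^T X w = X^T y
    w = (X^T X)^{-1} X^T y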

If you are doubting that it is that simple… Yes, it is. You are actually doing machine learning with this small piece of math, or, to put it another way, optimization.

Note that this is the solution for the case where y is one-dimensional; extending it to the multi-dimensional case is trivial and can be left as an exercise! This small piece of math can be implemented in a few lines of code; actually, the least squares part takes only 2 lines of code!
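
For example, with NumPy it could look something like this (the toy data and variable names here are just for illustration):

    import numpy as np

    # Toy data: a noisy linear relationship, made up just for this example.
    rng = np.random.default_rng(0)
    x = rng.uniform(-5, 5, size=(100, 1))
    X = np.hstack([x, np.ones_like(x)])   # add a bias column
    y = 2.0 * x[:, 0] - 1.0 + rng.normal(scale=0.5, size=100)

    # The least squares part: solve X^T X w = X^T y, i.e. w = (X^T X)^{-1} X^T y.
    A = X.T @ X
    w = np.linalg.solve(A, X.T @ y)

    print(w)  # should come out close to [2.0, -1.0]

Note that solving the linear system with np.linalg.solve is preferable to forming the inverse explicitly; it computes the same solution more stably.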

This leads to the following figure; notice how it finds the best-fitting line through the data (the line that minimizes the squared error).

We can see that the plane goes through the middle of the data cluster where it minimizes the squared error.

Obviously, we cannot fit all data with a straight line or a plane, but with powerful feature extractors we may be able to reduce our problem to a much simpler one. To put it into perspective, this is effectively what neural networks do; the only difference is that we use some nonlinearity as the activation function in the last layer. If we removed it, we could look at the last layer of a neural network as a least squares problem, i.e. fitting a plane to the data (the activations from the previous layers).
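
As a rough sketch of that idea (the random “feature extractor” and the toy data below are made up just to illustrate the point, not a real trained network):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.uniform(-3, 3, size=(200, 1))
    y = np.sin(x[:, 0]) + rng.normal(scale=0.1, size=200)   # clearly not linear in x

    # A fixed, random hidden layer plays the role of the feature extractor.
    W1 = rng.normal(size=(1, 50))
    b1 = rng.normal(size=50)
    H = np.tanh(x @ W1 + b1)   # activations of the "previous layers", shape (200, 50)

    # The linear last layer is then just a least squares fit on the activations.
    w, *_ = np.linalg.lstsq(H, y, rcond=None)
    print(np.mean((H @ w - y) ** 2))   # squared error of the fit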

So, that was a short, painless introduction to least squares. If you are interested, here are some other machine learning articles:

  1. Causal vs. Statistical Inference
  2. Kernel Secrets in Machine Learning
  3. The Central Limit Theorem and its Implications

Till next time!