An Intro To Linear Regression Using A Single Neuron

How one neuron can learn and adapt to data

It is exciting to know that thousands or millions of neurons can work in concert to perform complicated tasks such as image classification or speech recognition, but let’s start with a simpler example.

When learning something new, it is important to start with the fundamentals and keep things simple, without losing sight of the big picture. Let's zoom in on how a single neuron, with a single input and a single output, can make a decision.

Above is a model of a single neuron, inspired by the previous post. This neuron is a small part of a larger network that can make more complicated decisions. What we are going to explore is how a neural network can learn a single one of its weights from some input data. We are going to strip down the diagram below and only look at the first input, the first weight, and its output.

The way neurons work is pretty straightforward. They multiply their input (x) by some value (w) to get some output (y), giving the equation y = w*x. We call the middle value w the weight, because it determines how much signal gets passed through to the other side. Even with this simple model of a single neuron, we can start to model and learn some simple information.
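As a minimal sketch, the single-input neuron y = w*x fits in one Python function. The name `neuron` is hypothetical, chosen only for illustration:

```python
def neuron(x: float, w: float) -> float:
    """A single neuron with one input and one weight: y = w * x."""
    return w * x

# With a weight of 2.0, an input of 3.0 comes out the other side as 6.0.
print(neuron(3.0, 2.0))  # prints 6.0
```

The weight is the only thing that can change here; the input is whatever data the neuron is given.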

Let's anthropomorphize our neuron, and call her Linear Lisa.

Lisa loves to learn about the world, so we give her a notebook and send her out into the world. She decides she wants to learn how people's heights relate to their weights, so she starts writing down measurements as she meets people. She then plots this data on a graph and takes a look.

This data is very linear, meaning she should be able to draw a simple line over the top of the graph and follow it to get a reasonable estimate of weight given height. For example, say our line is the light blue one estimated below.

Given any value on the x-axis, we can follow it straight up to this line to get a prediction.

How do we figure out where exactly this line should be? This is where Lisa's internal weight (w) comes into play. By adjusting the value of w, she can decide where this line falls on the graph.

Let’s say she starts by setting her weight to w = 2.0. She sees that the data goes up and to the right, and this line does too, so she thinks this might be a reasonable first guess.

After plotting this line, it is pretty obvious that the slope is a little too steep and that the line would not do a good job of making predictions on the far right, so she decides to try w = 1.5, which will be a little less steep.

This looks better, but she decides to keep going: what about w = 1.0?

This is starting to look pretty reasonable, but with her simple model of y = w*x she is limited to lines that go through the origin (0,0). It seems like if she were able to move the line up the y-axis a little, she could better represent the data when x is closer to 0.

Lisa thinks back to her algebra class and remembers that the equation for a line also has another term, that controls the y-intercept of the line.
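For reference, the slope-intercept form Lisa is remembering is usually written as follows, with m the slope and b the y-intercept:

```latex
$$y = mx + b$$
```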

So far we have been using "w" instead of "m" for the slope, but most people will probably be familiar with the equation above, where "m" is the slope of the line and "b" is the y-intercept.

The term "b" is called the bias term, and it helps our machine learning models learn a better model of the data. One way to think about the bias term in neural networks is as a threshold: if the bias is negative, the weighted input must be greater than the magnitude of the bias for the output to be positive. A network might learn a negative bias in order to dampen some weights.
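The threshold view of a negative bias can be checked with a small sketch. The function name `neuron_with_bias` and the values w = 1.0 and b = -3.0 are hypothetical, picked only to make the threshold easy to see:

```python
def neuron_with_bias(x: float, w: float, b: float) -> float:
    """One input, one weight, and a bias term: y = w * x + b."""
    return w * x + b

# With a negative bias of -3.0 and w = 1.0, the output only becomes positive
# once the weighted input w * x exceeds 3.0 (the magnitude of the bias).
for x in [2.0, 3.0, 4.0]:
    print(x, neuron_with_bias(x, w=1.0, b=-3.0))  # -1.0, then 0.0, then 1.0
```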

Lisa updates her internal model to have another weight and another input for the bias. This is starting to look like a more fleshed-out neural network, like the one we started with, but it is still rather simple.

Notice that there are two new terms in the equation, "b" and "w_1". In practice we always pass b = 1.0, so that w_1 acts as the bias term, which simplifies the equation to the same form as y = wx + b. The only reason it is organized like this is to make clear that the "weights" are learnable and the inputs are not. This will make the math and code easier and more consistent to write later.

Lisa uses her new knowledge about linear functions to estimate new weight values, flattening out the slope to 0.56 and moving the line up the y-axis with a bias of 2.5.
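The bias-as-a-weight arrangement described above can be sketched with Lisa's new guesses, w_0 = 0.56 and w_1 = 2.5. The function name `predict` and the list layout `[w_0, w_1]` are assumptions made for illustration:

```python
def predict(x: float, weights: list[float]) -> float:
    """Linear neuron where the bias is learned as an ordinary weight.

    The inputs are [x, 1.0]: the constant 1.0 is the fixed "b" input,
    so weights[1] (w_1) plays the role of the bias.
    """
    inputs = [x, 1.0]
    # Weighted sum of inputs: w_0 * x + w_1 * 1.0
    return sum(w * i for w, i in zip(weights, inputs))

lisa_weights = [0.56, 2.5]  # w_0 (slope) and w_1 (bias)
print(predict(10.0, lisa_weights))  # same as 0.56 * 10.0 + 2.5
```

Because the constant input never changes, only the two numbers in `lisa_weights` would need to be learned.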

This looks better than her previous guesses, but how can we be sure? Is w_0 = 0.57 better than w_0 = 0.56? Is w_1 = 3.0 better than w_1 = 2.5?

Figuring out these weights is a well-defined problem in the fields of machine learning and statistics. It is called the Linear Regression problem. We are missing a key piece, though: figuring out the optimal line requires us to know how well our line fits the data in the first place.

Luckily Lisa is not alone in her 2D world of data, and she calls up her friend "Edgar the error function" in order to find the optimal line.

Edgar has one job: to look at an actual value from the dataset and compare it to a predicted value from our line. He computes a score that we will call the "error". Edgar relays this score back to Lisa to tell her how she is doing.

For example, if Lisa thinks someone's weight should be 300 lbs given a height of 6 ft, Edgar will look at a real value from the dataset for someone who is 6 ft tall and tell Lisa the difference between her guess and the actual weight.

Edgar did a little research on Linear Regression and found an error function that he thinks will help solve Lisa's problem: the mean squared error function. This error function is commonly used to measure how close two sets of values are, such as predictions and targets. It is sometimes referred to as the L2 loss because it is based on the squared L2 norm of the difference between the two sets.
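Written out for N datapoints, a standard form of the mean squared error, matching the term-by-term walkthrough that follows (the line's prediction wx_i + b, subtracted from the target y_i, squared, then averaged), is:

```latex
$$\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - (w x_i + b) \right)^2$$
```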

There are a few terms in this equation, but it is not as complicated as it looks. When approaching equations like this, I find it easiest to start with the innermost parentheses, find the variables we know, and move outward. Inside the innermost parentheses here we have "wx + b".

This is just the prediction from our line. The subscript "i" on the "x" term indicates that we are looking at an individual "x" value from the dataset.

If we move out to the next set of parentheses, we see that we are subtracting this value from "y" with subscript "i".

"y_i" represents the actual target value in our dataset for a given "x_i". On a graph, it would be the y value of the point (x_i, y_i). The subtraction tells us how far our prediction is from the target datapoint. For each datapoint in the dataset we also square the difference, so that no term is negative.

Instead of squaring the differences, we could take their absolute values, and the error would then be based on the L1 norm. Squaring the difference makes the "error" between our model and the data more dramatic: large differences are penalized much more heavily than small ones.
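A quick sketch makes the difference between squaring and taking absolute values concrete. The residuals below are hypothetical prediction errors, not values from Lisa's dataset:

```python
residuals = [-0.5, 1.0, -2.0, 4.0]  # hypothetical (target - prediction) values

squared = [r ** 2 for r in residuals]   # what a squared (L2-style) error uses
absolute = [abs(r) for r in residuals]  # what an absolute (L1-style) error uses

print(squared)   # [0.25, 1.0, 4.0, 16.0]
print(absolute)  # [0.5, 1.0, 2.0, 4.0]
```

Both versions remove the sign, but squaring shrinks errors smaller than 1 and inflates larger ones, so the residual of 4.0 dominates the squared version.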

Finally, we find the average error of our model by summing all the squared differences and dividing by the total number of them.

You can think of this as a single number describing how far, on average, the line is from the datapoints (measured as squared distance). If we plug the values for "w" and "b" from our line into this mean squared error equation and consider each datapoint in our dataset, we will know exactly how well we are fitting our data.
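Edgar's job can be sketched as a small Python function that follows the equation from the inside out. The name `mean_squared_error` and the list-of-(x, y)-tuples input format are assumptions for illustration:

```python
def mean_squared_error(points: list[tuple[float, float]], w: float, b: float) -> float:
    """Average of the squared differences between each target y_i and the
    line's prediction w * x_i + b."""
    total = 0.0
    for x_i, y_i in points:
        prediction = w * x_i + b          # innermost parentheses: the line's guess
        total += (y_i - prediction) ** 2  # distance to the target, squared
    return total / len(points)            # average over the dataset

# Points that sit exactly on the line y = 1.0 * x + 1.0 give zero error.
print(mean_squared_error([(0.0, 1.0), (1.0, 2.0)], w=1.0, b=1.0))  # 0.0
```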

Let's take a look at a small dataset and a linear model to see the mean squared error function in action. In this dataset we will only have two points: (2,3) and (4,5).

Let's say Lisa chooses a random line as a starting point. This line has a slope of -0.5 and a bias of 2.5.

This line clearly does not model the data well, so let's ask Edgar exactly how bad the estimate is.

For the first point in the dataset (2,3), Edgar subtracts the value on the line at x=2 (which is y=1.5) from the target value at x=2 (which is y=3) and squares the difference.

This gives us an error of 2.25 for the first point. We now need to calculate the error for the next point (4,5), so we can get the average error over the dataset.

This point is much further away from the line: at x=4 the line gives y=0.5, so the error is (5 - 0.5)^2 = 20.25. If we average these two errors we get (2.25 + 20.25) / 2 = 11.25.
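The same arithmetic can be reproduced in a few lines of Python as a sanity check, using the two-point dataset and Lisa's starting line (slope -0.5, bias 2.5):

```python
points = [(2.0, 3.0), (4.0, 5.0)]  # the two-point dataset from the example
w, b = -0.5, 2.5                   # Lisa's starting slope and bias

errors = []
for x, y in points:
    prediction = w * x + b               # 1.5 at x=2, 0.5 at x=4
    errors.append((y - prediction) ** 2) # squared difference for this point

print(errors)                     # [2.25, 20.25]
print(sum(errors) / len(errors))  # 11.25
```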

This is pretty bad. In this example, and the example above, our neuron was just guessing (kind of randomly) which values of "w" and "b" would get us the best fit. While guessing and checking is one way of learning the best values, we can do better. In the next post I will go over an optimization technique called gradient descent, which is commonly used in neural networks to learn weights and minimize error.

