
Introduction

\(K\) Nearest Neighbors (or KNN for short) is a supervised learning algorithm for sorting data points into classes. Suppose we are sorting the data into \(n\) classes \(C_1, C_2, \dots, C_n.\) We have a set of labeled data points which we use as training data. Let \(\overrightarrow{x}_1, \overrightarrow{x}_2, \dots, \overrightarrow{x}_m\) be the data points and let \(y_i\) be the class of data point \(\overrightarrow{x}_i.\)

Given a new data point \(\overrightarrow{x},\) the problem we are trying to solve is determining which class \(\overrightarrow{x}\) should belong to. The \(K\) Nearest Neighbors algorithm will look at the classes of the \(K\) nearest neighbors of \(\overrightarrow{x}\) (according to some distance measure) and assign \(\overrightarrow{x}\) the class of the majority of its neighbors.

It is possible that there are ties. Ties can be handled in different ways, and the best approach depends on the classification problem the KNN is being used to solve.
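As a concrete sketch, here is a minimal version of the algorithm in JavaScript. It is only an illustration, not the code used by this page (that appears in the Code Sample section below): the names knnClassify, trainX, and trainY are placeholders, the distance measure is squared Euclidean distance, and a tie is broken in favor of the class that reaches the winning count with closer neighbors.

function knnClassify(trainX, trainY, point, K) {
  // Pair each training point with its squared distance to the query point
  var neighbors = trainX.map(function (p, i) {
    var dx = p[0] - point[0]
    var dy = p[1] - point[1]
    return { dist: dx*dx + dy*dy, label: trainY[i] }
  })
  // Sort by distance and keep only the K nearest neighbors
  neighbors.sort(function (a, b) { return a.dist - b.dist })
  var nearest = neighbors.slice(0, K)
  // Take a majority vote over the labels of the K nearest neighbors;
  // ties go to the class that reaches the winning count first
  var counts = {}
  var best = null
  for (var i = 0; i < nearest.length; i++) {
    var label = nearest[i].label
    counts[label] = (counts[label] || 0) + 1
    if (best === null || counts[label] > counts[best]) {
      best = label
    }
  }
  return best
}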

Example

We are classifying data into a red class and a blue class. Suppose we have the following: \begin{align} & \text{Data } & \text{Class} \\ & (1,1) \mapsto & \text{red} \\ & (1,4) \mapsto & \text{blue} \\ & (5,2) \mapsto & \text{red} \\ & (8,5) \mapsto & \text{red} \\ & (6,3) \mapsto & \text{blue} \end{align} A new data point is observed at \((1,3)\) and we want to predict whether it will be a blue point or a red point.

We will consider the classification for two different values of \(K.\) Since there are \(2\) classes, if we use odd values of \(K\) there will be no ties.

\(K = 1:\) We only need to look at the class of the nearest point. The nearest point to \((1,3)\) that we have already classified is \((1,4).\) Since \((1,4)\) is blue, the KNN will classify \((1,3)\) as blue.

\(K = 3:\) We need to consider the classes of the three nearest points to \((1, 3).\) The three nearest points are \((1,4),\) \((1,1)\) and \((5,2).\) Two of the points are in the red class and one is in the blue class. Since more of the nearby points are red, we will classify \((1,3)\) as red.
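Plugging the example data into the knnClassify sketch from the introduction (again, the names are only placeholders for illustration) reproduces both predictions:

// Training data from the example above
var trainX = [[1, 1], [1, 4], [5, 2], [8, 5], [6, 3]]
var trainY = ['red', 'blue', 'red', 'red', 'blue']

console.log(knnClassify(trainX, trainY, [1, 3], 1))  // "blue"
console.log(knnClassify(trainX, trainY, [1, 3], 3))  // "red"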


As you can see from the example, the choice of \(K\) can determine which class the KNN will predict for a given data point. Small values of \(K\) make it more likely that only nearby points are considered, but they also give a less stable prediction. Large values of \(K\) give more stable predictions but may cause the KNN to consider points that are not near the given data point.

You can draw points in the red class and blue class on the plane, then run the KNN algorithm to see the regions over which future points would be classified as red or blue. The value of \(K\) can be adjusted with the buttons below.




You can create examples where some values of \(K\) give better results than others.



Consider the case where most of the top left points are red and most of the bottom right points are blue. There are one or two outliers. In this case, \(K = 3\) gives a nice result because it ignores the outliers:



If we use \(K = 1\) we get regions which are too sensitive to the outliers. There are regions near the outliers where unlikely predictions are made:




Now consider the case where the blue points are in the upper left corner, but they are rare. If we use \(K = 5\), then the blue class will not be recognized at all.



However, using \(K = 1\), the blue class shows up in the predictions.



Code Sample

The JavaScript below implements the classifier. The main steps are:

1. First check that there are at least \(K\) points. If not, the KNN classifier cannot be used.
2. To make a prediction about an unclassified point, find the \(K\) points in the training data that are closest to the unclassified point.
   - First, set up arrays holding \(K\) placeholder entries. Each entry has class 0 and a distance larger than the distance between any two points on the screen.
   - Loop through each point in the training data and compute its distance to the point we are classifying. We are classifying the point with \((x,y)\)-coordinates \((4i, 4j).\)
   - If the distance from the training data point to the unclassified point is less than the farthest of the \(K\) closest points found so far, replace that farthest point with the current point.
   - Keep track of which of the \(K\) closest points is farthest away. This is the point that will be replaced if a closer point is found.
3. Count how many of the \(K\) closest points are in the red class.
4. If more than \(K/2\) of the \(K\) closest points are in the red class, predict the unclassified point is red. Otherwise, predict the unclassified point is blue.
5. Repeat the classification for every point in a square grid.
// dataPoints is an array of data points.
// dataPoints[d] has 3 coordinates: x, y, color in hex

function classifyPoints() {
  if (K > dataPoints.length) {
    return
  }
  for (var i = 0; i < 250; i++) {
    for (var j = 0; j < 250; j++) {
      //Find K closest points
      var currClass = []
      var currDist = []
      var currMaxIndex = 0
      for (var fillup = 0; fillup < K; fillup++) {
        currClass.push(0)
        currDist.push(1000000)
      }
      for (var d = 0; d < dataPoints.length; d++) {
        // Squared distance from training point d to the grid point (4i, 4j)
        var dist = (dataPoints[d][0] - 4*i)**2 + (dataPoints[d][1] - 4*j)**2
        if (dist < currDist[currMaxIndex]) {
          currDist[currMaxIndex] = dist
          currClass[currMaxIndex] = dataPoints[d][2]
          // Find which of the K current nearest points is now farthest away
          for (var loop = 0; loop < K; loop++) {
            if (currDist[loop] > currDist[currMaxIndex]) {
              currMaxIndex = loop
            }
          }
        }
      }
      var countRed = 0
      for (var countLoop = 0; countLoop < K; countLoop++) {
        if (currClass[countLoop] == 0xff0000) {
          countRed++
        }
      }
      // Color the grid cell red if more than half of the K nearest points are red
      if (2*countRed > K) {
        square(ctx, 4*i, 4*j, 2, 2, 0xff0000)
      }
      else {
        square(ctx, 4*i, 4*j, 2, 2, 0x0000ff)
      }
    }
  }
}
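Note that instead of sorting all of the distances, classifyPoints keeps a running list of the \(K\) closest points seen so far, so classifying one grid cell costs at most about \(K\) comparisons per training point, which is more than fast enough for points drawn by hand. The square helper called in the last few lines is part of the page's own drawing code and is not shown; a plausible stand-in, assuming ctx is a 2D canvas context and the last argument is a color given as a hexadecimal number, might look like this:

// Hypothetical stand-in for the page's square helper (an assumption, not
// the page's actual code): draw a filled w-by-h rectangle at (x, y) in the
// color given as a hex number such as 0xff0000
function square(ctx, x, y, w, h, color) {
  ctx.fillStyle = '#' + color.toString(16).padStart(6, '0')
  ctx.fillRect(x, y, w, h)
}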