
Suppose you have a collection of songs that you want to classify by genre, or you have data about a set of documents and you want to classify them by topic.

\(K\)-means is a method for grouping similar data points. It puts each data point in a class, and data points in the same class are supposed to be similar. There are \(K\) classes, determined by \(K\) centroids. See the whiteboard below for more details.

The whiteboard is an example of \(K\)-means with \(K = 3\).

Click in the box to add points, then click the Cluster! button below to create 3 clusters.






You will need at least 3 points before the code will run.

The squares are the centroids. Each point will inherit the color of the closest centroid.

The \(K\)-means Algorithm

The \(K\)-means algorithm will run on a set of data points. The algorithm can be summarized as follows:

  1. Choose \(K\) random starting points for the centroids. Do not worry yet about what a centroid is. What is more important is how we use them. For now, just think of them as \(K\) random locations. The centroids are numbered 1 through \(K\).
  2. Assign each data point a class value in \(\{1, 2, \dots, K\}\). A data point gets class value \(c\) if the closest centroid is the centroid numbered \(c\).
  3. For each \(c \in \{1, 2, \dots, K\}\), let \(m_c\) be the center of all the points labeled \(c\).
  4. Move the centroid numbered \(c\) to the location \(m_c\).
  5. Repeat steps 2 through 4 until none of the centroids change their location.
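The classical steps above can be sketched in Python. This is a minimal sketch, not the whiteboard's code: it assumes 2D points stored as `(x, y)` tuples, and it assumes the "random starting points" of step 1 are \(K\) distinct data points picked at random (one common choice; the text only says random locations). The names `kmeans`, `seed`, and `max_iters`, and the rule that an empty class keeps its old centroid, are illustrative assumptions. Classes are numbered 0 through \(K - 1\) here, as is usual for Python indices.

```python
import math
import random

Point = tuple[float, float]


def distance(p: Point, q: Point) -> float:
    # Standard 2D Euclidean distance
    return math.hypot(p[0] - q[0], p[1] - q[1])


def kmeans(points: list[Point], k: int, seed: int = 0,
           max_iters: int = 100) -> tuple[list[Point], list[int]]:
    rng = random.Random(seed)
    # Step 1 (assumption): k distinct data points, chosen at random, are the starting centroids
    centroids = rng.sample(points, k)
    labels: list[int] = []
    for _ in range(max_iters):
        # Step 2: label each point with the index of its nearest centroid (ties go to the lower index)
        labels = [min(range(k), key=lambda c: distance(p, centroids[c])) for p in points]
        # Steps 3 and 4: move each centroid to the mean of the points labeled with it
        new_centroids: list[Point] = []
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:
                new_centroids.append((sum(x for x, _ in members) / len(members),
                                      sum(y for _, y in members) / len(members)))
            else:
                # Assumption: a class that lost all of its points keeps its old centroid
                new_centroids.append(centroids[c])
        # Step 5: stop once no centroid moved
        if new_centroids == centroids:
            break
        centroids = new_centroids
    return centroids, labels


# Two well-separated groups of three points each
pts = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
centroids, labels = kmeans(pts, k=2, seed=1)
print(labels)  # the first three points share one label, the last three share the other
```

The `max_iters` cap is only a safety net; on data like this the loop stops as soon as an update leaves every centroid where it was.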


Let's break the steps into further details.

In step 1, the \(K\) centroids (numbered 1 through \(K\)) are generated at random locations. On the whiteboard, I used the colors blue, green, and red for the classes because it is much easier to see the colors than scattered numbers.
In step 2, every point is assigned a class based on the nearest centroid. "Nearest," of course, implies there is a distance. On the whiteboard I use the standard 2D Euclidean distance: the distance between \((x_1, y_1)\) and \((x_2, y_2)\) is \(\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}\). Other applications will call for other distance formulas.
In step 3, the center of each class is computed. On the whiteboard, for each class I compute the average of all the \(x\)-values, \(\overline{x}\), and the average of all the \(y\)-values, \(\overline{y}\). The center of the class is \((\overline{x}, \overline{y})\).
In step 4, the centroids are moved to the center locations found in step 3. Hence the name: a centroid sits at the center of its class. Think of this as the best location for being close to all the data points in that class, in the sense that it minimizes the sum of the squared distances to those points. Moving a centroid may also take it farther from points that do not belong in that class and closer to points that do, so the next pass through step 2 can change some labels.
Step 5 says to repeat the process so that the data points end up in the appropriate classes. Repeat until the centroids stop moving. Once the centroids stop moving, step 2 assigns every point the same class as before, so nothing more is happening and the algorithm can stop. Every point will be classified into one of the \(K\) classes.
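Steps 2 and 3, as done on the whiteboard, can be sketched as small Python helpers. This is an illustrative sketch, assuming points are stored as `(x, y)` tuples; the function names are hypothetical and not taken from the whiteboard's code. Replacing `euclidean_distance` is where a different distance formula would go.

```python
import math

Point = tuple[float, float]


def euclidean_distance(p: Point, q: Point) -> float:
    # sqrt((x1 - x2)^2 + (y1 - y2)^2)
    return math.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2)


def nearest_centroid(point: Point, centroids: list[Point]) -> int:
    # Step 2: class value c (numbered from 1, as in the text) of the closest centroid;
    # ties go to the lower-numbered centroid
    distances = [euclidean_distance(point, m) for m in centroids]
    return distances.index(min(distances)) + 1


def class_center(points: list[Point]) -> Point:
    # Step 3: (x-bar, y-bar), the averages of the x- and y-values; the class must be non-empty
    n = len(points)
    return (sum(x for x, _ in points) / n, sum(y for _, y in points) / n)


print(euclidean_distance((0, 0), (3, 4)))                   # 5.0
print(nearest_centroid((1, 1), [(0, 0), (5, 5), (-3, 2)]))  # 1
print(class_center([(1, 2), (3, 4), (5, 0)]))               # (3.0, 2.0)
```

Why the average in step 3? For the points \(p_1, \dots, p_n\) labeled \(c\), the sum of squared distances \(f(m) = \sum_{i=1}^{n} \lVert p_i - m \rVert^2\) has gradient \(-2 \sum_{i=1}^{n} (p_i - m)\), which is zero exactly when \(m = \frac{1}{n} \sum_{i=1}^{n} p_i\), that is, at \((\overline{x}, \overline{y})\). Since \(f\) is convex, this is its minimum. Note that the average does not in general minimize the sum of plain (unsquared) distances; that point is the geometric median.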

You can press the Iterate button above to see how the algorithm works step by step. However, I did not follow this algorithm exactly. Below I describe some liberties I took.

Algorithm on the whiteboard

I think what I described above is the more "classical" \(K\)-means algorithm. The whiteboard uses a slightly different version:



  1. Use \(K = 3\).
  2. Do not allow the algorithm to run unless there are at least 3 data points.
  3. When the algorithm starts, instead of choosing random starting points for the centroids, assign each point to a class. The points will be assigned red, green, blue, red, green, blue, etc. in the order that you made them.
  4. Put the centroids in the middle of each cluster of points.
  5. Change the color of each data point to the color of the nearest centroid.
  6. Find the center of the points in each class by taking the average of the \(x\) and \(y\) coordinates for each class.
  7. Repeat steps 4 through 6 until the window closes or the "Reset" button is pressed.
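The whiteboard variant can be sketched the same way. This is a hedged reconstruction from the steps listed above, not the page's actual code: it assumes points are `(x, y)` tuples listed in the order they were added, replaces "until the window closes or the Reset button is pressed" with a hypothetical fixed `iterations` parameter, and assumes a class that loses all its points keeps its previous centroid (the text does not say what happens then). Classes 0, 1, 2 stand for red, green, blue.

```python
import math

COLORS = ["red", "green", "blue"]


def class_centers(points, labels, k, previous):
    # Mean of the points in each class; an emptied class keeps its previous centroid (assumption)
    centers = []
    for c in range(k):
        members = [p for p, label in zip(points, labels) if label == c]
        if members:
            centers.append((sum(x for x, _ in members) / len(members),
                            sum(y for _, y in members) / len(members)))
        else:
            centers.append(previous[c])
    return centers


def whiteboard_kmeans(points, iterations=10):
    k = 3  # Step 1: K = 3
    # Step 2: do not run with fewer than 3 points
    if len(points) < k:
        raise ValueError("need at least 3 points")
    # Step 3: red, green, blue, red, green, blue, ... in the order the points were added
    labels = [i % k for i in range(len(points))]
    # Step 4: put each centroid in the middle of its class (every class is non-empty here)
    centroids = class_centers(points, labels, k, previous=[None] * k)
    # Step 7 (hypothetical stand-in): a fixed number of passes instead of "until Reset"
    for _ in range(iterations):
        # Step 5: recolor each point with the color of its nearest centroid
        labels = [min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points]
        # Steps 6 and 4: recompute each class center and move the centroid there
        centroids = class_centers(points, labels, k, previous=centroids)
    return [COLORS[label] for label in labels], centroids


# Points listed in the order they were clicked
pts = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0), (0.0, 1.0), (10.0, 1.0), (0.0, 11.0)]
colors, centroids = whiteboard_kmeans(pts, iterations=5)
print(colors)     # ['red', 'green', 'blue', 'red', 'green', 'blue']
print(centroids)  # [(0.0, 0.5), (10.0, 0.5), (0.0, 10.5)]
```

In this example the click order happens to give each color one natural group, so the alternating start is already stable. With a different click order, the alternating start mixes groups and the later passes have to sort them out.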

Assigning the starting classes in rotation was an arbitrary choice that I made on the fly when I was programming.
The last step says that the process never stops until the user makes it stop. This is done because the data can be updated in real time: if you run the clustering algorithm and then add more points, the clusters will adjust.

The first time you press the Iterate button, you will see the alternating class assignments on the data points, and the centroids (the squares) will be at the center of each class. Each time you press the button again, every data point changes to the color of its closest centroid, and the centroids are once again moved to the centers of their respective classes.

How many clusters?

In the program, I used 3 clusters. In the general algorithm described above, \(K\) is fixed: when you run the \(K\)-means algorithm, you always get \(K\) clusters. However, it is possible to choose a bad value for \(K\). For example, what if we are classifying 2 genres of music but use \(K = 5\) to make classes? Then the points will be divided into classes that don't mean anything. On the other hand, what if there should really be 8 classes and I tell the algorithm to make 3? You can try this above by making 10 small clusters of points. You will notice that several clusters end up being lumped into the same class.

The \(K\)-means algorithm does not give any information about what \(K\) should be. However, there are ways to determine whether one value of \(K\) works better than another. See the post on silhouette scores for more information!
