Monday, December 12, 2011

K-Means Clustering Algorithm

This post is dedicated to the K-Means clustering algorithm, which is used for unsupervised learning. Here is a short introduction to unsupervised learning. This video shows the algorithm at work.

So, how does it work?

1. First of all, we have the number K (the number of clusters) and the data set X1,...,Xm as inputs, where K < m and Xj ∈ ℝ^n for ∀j=1..m. Considering that clustering is used as a task in data mining (including text mining), it is a nice exercise indeed to formalise a problem in such a way that the data is represented by vectors :)

2. At this step we randomly initialise K cluster centroids μ1,...,μK, μj ∈ ℝ^n for ∀j=1..K. The recommended way is to pick the centroids at random from the data set X1,...,Xm and make sure that μi≠μj for ∀ i≠j.
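Step 2 could be sketched in Python roughly as follows. This is only an illustration, not the code from our page: the function name `init_centroids` and its parameters are made up for this example, and duplicates are removed first so that the chosen centroids are pairwise distinct.

```python
import random

def init_centroids(points, k, rng=None):
    """Pick k pairwise-distinct centroids from the data set (step 2)."""
    rng = rng or random.Random()
    # Remove duplicate points first, so that no two centroids coincide.
    distinct = list(dict.fromkeys(tuple(p) for p in points))
    if k > len(distinct):
        raise ValueError("need at least k distinct points in the data set")
    # random.sample draws without replacement, so all picks differ.
    return [list(p) for p in rng.sample(distinct, k)]
```

Passing a seeded `random.Random` as `rng` makes a run reproducible, which helps when comparing results.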

3. At this step we execute the loop:

Repeat {
  //cluster assignment step
  For i=1..m {
    Find the closest centroid for Xi, i.e.
    min{||Xi - μj||}, ∀j=1..K, e.g. it is μt;
    Assign ci=t;
    //in this case ci is the index of
    //the closest centroid for Xi
  }

  //move centroids step, if a particular cluster 
  //centroid has no points assigned to it, 
  //then eliminate it
  For j=1..K {
    old_μj = μj;
    μj = average (mean) of all Xi where ci=j, ∀i=1..m;
  }
}

This loop ends when all the centroids stop "moving", i.e. ||old_μj - μj|| < ε, ∀j=1..K, where ε is an error we are happy with (e.g. 0.01 or 0.001).
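For readers who prefer real code to pseudocode, here is a minimal Python sketch of the loop from step 3. The name `k_means` and its parameters are assumptions made for this example. Note one deliberate simplification: a centroid with no points assigned to it keeps its old position here instead of being eliminated.

```python
import math

def k_means(points, centroids, eps=1e-3):
    """Run step 3 from the given starting centroids.

    Returns (assignments, centroids), where assignments[i] is the
    index of the centroid that points[i] was assigned to.
    """
    centroids = [list(c) for c in centroids]
    while True:
        # Cluster assignment step: index of the closest centroid per point.
        assignments = [
            min(range(len(centroids)), key=lambda j: math.dist(x, centroids[j]))
            for x in points
        ]
        # Move centroids step: each centroid becomes the mean of its points.
        new_centroids = []
        for j, old in enumerate(centroids):
            members = [x for x, c in zip(points, assignments) if c == j]
            if members:
                new_centroids.append(
                    [sum(col) / len(members) for col in zip(*members)]
                )
            else:
                # Simplification: keep an empty cluster's centroid as it is.
                new_centroids.append(old)
        # Stop when no centroid moved by eps or more.
        moved = max(math.dist(o, n) for o, n in zip(centroids, new_centroids))
        centroids = new_centroids
        if moved < eps:
            return assignments, centroids
```

For example, with the points (0,0), (0,1), (10,10), (10,11) and the starting centroids (0,0) and (10,10), the loop settles on the centroids (0, 0.5) and (10, 10.5).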

This is pretty much the algorithm. However, in this form, there is a risk of getting stuck in a local minimum. By a local minimum I mean a local minimum of the cost function:
J(X1,...,Xm,c1,...,cm,μ1,...,μK) = (1 ⁄ m)⋅∑||Xi – μ_ci||^2, where the sum is over i=1..m and μ_ci is the centroid of the cluster that Xi is assigned to
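In code, the cost function is just the mean squared distance between each point and its assigned centroid. A small Python sketch (the name `cost` and the argument layout are assumptions made for this example):

```python
def cost(points, assignments, centroids):
    """J: mean squared distance from each point to its assigned centroid."""
    total = 0.0
    for x, c in zip(points, assignments):
        # Squared Euclidean distance between x and its centroid.
        total += sum((a - b) ** 2 for a, b in zip(x, centroids[c]))
    return total / len(points)
```

Here `assignments[i]` plays the role of ci, i.e. it is the index of the centroid assigned to `points[i]`.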

In order to address this, the algorithm (steps 2 and 3) is executed several times (e.g. 100). Each time the cost function J(...) is computed, and the clustering with the lowest J(...) is picked. E.g.

For p=1..100 {
  Perform step 2;
  Perform step 3;
  Calculate cost function J(...);
}
Pick the clustering with the lowest cost J(...);
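The restart logic itself is independent of the clustering details, so it can be sketched as a small helper that takes the two operations as functions. The names `best_of`, `run_once` and `cost_of` are hypothetical: `run_once` is assumed to perform steps 2 and 3 and return a clustering, and `cost_of` is assumed to compute J(...) for that clustering.

```python
def best_of(run_once, cost_of, restarts=100):
    """Call run_once() `restarts` times and keep the lowest-cost result."""
    best, best_cost = None, float("inf")
    for _ in range(restarts):
        result = run_once()        # steps 2 and 3
        current = cost_of(result)  # cost function J(...)
        if current < best_cost:
            best, best_cost = result, current
    return best, best_cost
```

More restarts give a better chance of avoiding a poor local minimum, at the price of a proportionally longer run.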

In order to make it more interactive, my son and I spent a couple of hours implementing a JavaScript version of this algorithm using the <canvas> tag from HTML5 (my son studies HTML at school, so it was an interesting exercise). Here is the link to the page (excuse me and my son for our non-OOP approach to JavaScript :)). If you want to have a look at the sources, feel free to download that page (we didn't provide comments, but hopefully this article explains everything). We tested it with Google Chrome and Internet Explorer 9 (if you have problems with IE9, please consult the following link).

Note: While I was writing this article, a friend of mine suggested this nice JavaScript library for plotting.