Love your k Nearest Neighbors : The Basics

k-Nearest Neighbors (KNN) is a machine learning algorithm that predicts a discrete or continuous value for a new example based on the known examples closest to it. It is one of the most basic algorithms for supervised learning, that is, learning from previously known data in order to classify or approximate a feature of a new piece of data.

The KNN algorithm performs the following steps:

  1. Measure the distance from the new data point (the one you want a prediction for) to each of the known data samples. Note that “distance” can be measured in different ways, as discussed later.
  2. Sort the distances in ascending order.
  3. Specify an integer value k.
  4. Select the top k data samples that are closest to the new data point.
  5. If you are performing regression (i.e., computing an estimated value), then get the mean of the values of the k closest data samples. The mean is the output of the algorithm.
  6. If you are performing classification, then get the mode of the values of the k closest data samples. The mode is the output of the algorithm.
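The bird example in the rest of this post covers the classification branch (step 6). For the regression branch (step 5), here is a minimal sketch. It assumes made-up one-dimensional samples given as (feature, value) pairs; the function name `knn_regress` and the numbers are hypothetical and only there for illustration.

```python
from statistics import mean

def knn_regress(samples, query, k):
    """Estimate a value for `query` as the mean of its k nearest samples.

    samples: list of (feature, value) pairs (hypothetical 1-D data).
    """
    # Steps 1-2: order the samples by their distance from the query.
    by_distance = sorted(samples, key=lambda s: abs(s[0] - query))
    # Steps 3-4: keep the k closest samples.
    nearest = by_distance[:k]
    # Step 5: the mean of their values is the prediction.
    return mean(value for _, value in nearest)

# Made-up samples: (feature, value)
samples = [(1, 10.0), (2, 12.0), (3, 14.0), (6, 30.0), (7, 33.0)]
print(knn_regress(samples, query=2.5, k=3))  # 12.0
```

The three samples nearest to 2.5 have the values 12.0, 14.0, and 10.0, so the estimate is their mean.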

By applying the wisdom of the saying “birds of a feather flock together”, KNN predicts the type of a bird based on the types of the birds that are nearest to it.

As the definition suggests, KNN can be used in both classification and regression problems. Let’s take a look at a classification scenario where we can apply the most basic form of the algorithm. First, I’ll go through an intuitive approach to the algorithm, and then explain it in terms of how it is implemented in machine learning. We’ll then take a look at some problems of the algorithm based on the example that was presented.

Doing KNN, the human way

Let’s say you’re in a park and see a bunch of birds on one side. You know there are parrots, eagles, doves, and swans, but there is one bird in the area that you cannot classify. You call it Bird X. For simplicity, let’s assume that all the birds are stationary. Their positions are illustrated below:

Bird X is in the middle of a flock of birds of different types. Can you guess what bird it is?

To find out what type of bird Bird X is, you decide to apply the KNN algorithm. The first step is to calculate the distance from Bird X to each of the birds in the area. You decide to measure the straight-line distance in meters. Let’s also assume that you have some way of measuring the distances without disturbing the birds!

Straight-line distance of Bird X to each of the birds in the area

You specify your k to be 5, for some reason. So, the top 5 birds that are closest to Bird X are:

  1. parrot: 0.5 meters away
  2. eagle: 0.8 meters away
  3. eagle: 1 meter away
  4. dove: 1.2 meters away
  5. eagle: 1.8 meters away

Top k birds closest to Bird X, where k=5

To get the mode, just find the most frequently occurring bird type among the top 5 birds closest to Bird X. Without much difficulty, you determine that the mode is eagle, which appears 3 times. Thus, Bird X is classified as an eagle.
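Finding the mode by hand can be mirrored in a few lines of Python. This is only a sketch that reuses the five neighbors listed above; the variable names are illustrative.

```python
from collections import Counter

# The five nearest birds from the list above, closest first.
nearest_birds = ['parrot', 'eagle', 'eagle', 'dove', 'eagle']

# Count how often each type appears among the neighbors.
votes = Counter(nearest_birds)

# The most common type is the mode.
bird_type, count = votes.most_common(1)[0]
print(bird_type, count)  # eagle 3
```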

If you tried to implement KNN on a computer using the scenario above, you’d run into a few complications. Firstly, machines do not have the same notion of physical location as humans do (unless you implement computer vision techniques, but that gets far more complicated than intended). For instance, where a human says “one eagle is over there”, a computer would say “one eagle object is at point (x,y) of some coordinate system”. Secondly, since a computer uses coordinates in an n-dimensional space to represent location, it computes distance differently. It cannot just say, “Bird X is 2 meters away from the swan”. Rather, it calculates the distance between two points in a coordinate system.

The next section describes how KNN is performed in machine learning using the same bird classification scenario, but replacing the human element of describing location and distance.

Doing KNN, the machine learning way

We plot the position of each bird as X and Y coordinates on a graph. The X and Y coordinates become features of each bird. A feature can be any information that tells something about the object in question, such as color of feathers, size of beak, etc. In this example, we will only consider the X and Y coordinates to predict the type of Bird X.

Position of birds in the park, represented as points on a 2D Cartesian plane

Each point in the graph is assigned a color to represent the type of bird that is there. For example, point (3,6) is a parrot, and next to it at point (4,6) is an eagle. The point represented by X (5,4) is our unknown Bird X.

We will use a straight line to calculate the distance between each point and point X. In other words, we will use the Euclidean distance formula:

$$d(A, B) = \sqrt{(A_1 - B_1)^2 + (A_2 - B_2)^2}$$

where $A_1$ is the first coordinate of point A and $B_1$ is the first coordinate of point B; $A_2$ is the second coordinate of point A and $B_2$ is the second coordinate of point B.
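The same formula extends to any number of features by adding one squared difference per coordinate. Here is a minimal sketch, assuming points are given as equal-length tuples. The helper name `euclidean` is only for illustration; the standard library's `math.dist` (Python 3.8+) does the same job.

```python
import math

def euclidean(a, b):
    """Straight-line distance between two points with the same number of coordinates."""
    if len(a) != len(b):
        raise ValueError("points must have the same number of coordinates")
    # Sum the squared difference of each coordinate pair, then take the root.
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

# The eagle at (4, 6) and Bird X at (5, 4), as plotted on the graph.
d = euclidean((4, 6), (5, 4))
print(round(d, 2))  # 2.24

# Agrees with the standard-library helper (Python 3.8+).
print(math.isclose(d, math.dist((4, 6), (5, 4))))  # True
```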

Recall that I mentioned that “distance” can be calculated in different ways, and the Euclidean distance is just one of them. The other distance metrics in machine learning can be a topic for another post next time.

So, applying the distance formula for each point and point X, we come up with the table below:

| Point | Formula | Distance from Bird X: (5,4) |
|---|---|---|
| parrot: (2,4) | sqrt((2-5)^2 + (4-4)^2) | 3 |
| parrot: (3,6) | sqrt((3-5)^2 + (6-4)^2) | 2.83 |
| parrot: (2,9) | sqrt((2-5)^2 + (9-4)^2) | 5.83 |
| eagle: (7,6) | sqrt((7-5)^2 + (6-4)^2) | 2.83 |
| eagle: (4,6) | sqrt((4-5)^2 + (6-4)^2) | 2.24 |
| eagle: (5,7) | sqrt((5-5)^2 + (7-4)^2) | 3 |
| dove: (8,2) | sqrt((8-5)^2 + (2-4)^2) | 3.61 |
| dove: (9,1) | sqrt((9-5)^2 + (1-4)^2) | 5 |
| swan: (1,1) | sqrt((1-5)^2 + (1-4)^2) | 5 |

If k=5, then the top 5 birds that are closest in distance to Bird X are:

  1. eagle: (4,6) : 2.24 m away
  2. parrot: (3,6) : 2.83 m
  3. eagle: (7,6) : 2.83 m
  4. parrot: (2,4) : 3 m
  5. eagle: (5,7): 3 m

Since three of the top 5 birds are eagles, Bird X is classified as an eagle.

KNN in Python

The bird classification problem above can be simulated in Python code, using K-Nearest Neighbors to predict the class of Bird X.

First, write the bird data into a variable. I used a 2-dimensional list for ease of access to the elements later.

# Each row contains the bird type, X-coordinate, and Y-coordinate, in order.
data = [
    ['parrot', 2, 4],
    ['parrot', 3, 6],
    ['parrot', 2, 9],
    ['eagle', 7, 6],
    ['eagle', 4, 6],
    ['eagle', 5, 7],
    ['dove', 8, 2],
    ['dove', 9, 1],
    ['swan', 1, 1]
]

Then, define the distance function. In our example, we’re using the Euclidean distance formula to solve for the distance between two points on the graph.

import math

def get_distance(xA, yA, xB, yB):
    """ Return the Euclidean distance between two points, rounded to 2 decimal places.

        xA: x-coordinate of the first point
        yA: y-coordinate of the first point
        xB: x-coordinate of the second point
        yB: y-coordinate of the second point
    """
    distance = math.sqrt((xA-xB)**2 + (yA-yB)**2)

    return round(distance, 2)

Define the current location of Bird X and calculate the distance between that location and the location of each of the birds in the data list using the distance function defined above. Each distance will be stored in a list.

""" Step 1: Measure the distance from the new data point (Bird X) to each of the known data samples. """

unknown_bird_loc = [5, 4]   # Bird X's location
distances = []

for i, bird in enumerate(data):
    distance = get_distance(unknown_bird_loc[0], unknown_bird_loc[1], bird[1], bird[2])
    
    # Save i since we'll be sorting the distances list later,
    # and we want a reference to each element's index in data.
    distances.append((i, distance))

At this point the distances are arranged in the same way the birds are arranged in data. The next step is to sort the distances in ascending order, so that we can get the top k birds that are closest to the location of Bird X.

""" Step 2: Sort the distances in ascending order. """

# Make sure that distances are sorted by the 2nd element in the tuple
def get_sort_key(distance_info):
    return distance_info[1]
    
sorted_distances = sorted(distances, key=get_sort_key)

sorted_distances should now contain a list of tuples (i, d), where i is a reference to the index in data, and d is the computed distance from the location of Bird X.

If we print out the contents of sorted_distances, the first element should be (4, 2.24), which means the bird at index 4 in data is 2.24 meters away from Bird X.

We can now define the value of k, then get the top k elements in the list.

""" Step 3: Specify an integer value k. """
k = 5
k_nearest_types = []

""" Step 4: Select the top k data samples that are closest to the new data point. """
for i in range(k):
    index_in_data = sorted_distances[i][0]

    # Using the original index, get the type of this bird
    bird_type = data[index_in_data][0]
    
    k_nearest_types.append(bird_type)

The types of the k nearest birds are now stored in the k_nearest_types list. Python has a statistics module that can find the mode of a list for us.

""" Step 6: If you are performing classification, then get the mode of the values of the k closest data samples. The mode is the output of the algorithm. """

from statistics import mode

# Store the result under its own name so the imported mode() function is not overwritten.
predicted_type = mode(k_nearest_types)
print(f'The mode is {predicted_type}, therefore, Bird X is {predicted_type}')

After running the script, we get the following output:

The mode is eagle, therefore, Bird X is eagle

That’s a simple implementation of KNN on the bird classification problem using Python. You can try modifying the X and Y coordinates in data or the location of Bird X to see how the prediction changes as the location of each bird changes.
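To make that kind of experimenting easier, the steps can be wrapped in a single function. This is a sketch rather than the script above verbatim: the name `classify_bird` is made up for illustration, it uses `math.dist` (Python 3.8+) instead of `get_distance`, and it skips the rounding.

```python
import math
from statistics import mode

# Same bird data as in the script above: type, X-coordinate, Y-coordinate.
data = [
    ['parrot', 2, 4], ['parrot', 3, 6], ['parrot', 2, 9],
    ['eagle', 7, 6], ['eagle', 4, 6], ['eagle', 5, 7],
    ['dove', 8, 2], ['dove', 9, 1], ['swan', 1, 1],
]

def classify_bird(data, location, k):
    """Return the most common bird type among the k birds nearest to location."""
    # Sort the known birds by their distance from the query location.
    ranked = sorted(data, key=lambda bird: math.dist(location, bird[1:]))
    # Collect the types of the k closest birds and take the mode.
    return mode([bird[0] for bird in ranked[:k]])

print(classify_bird(data, [5, 4], k=5))  # eagle
print(classify_bird(data, [8, 1], k=3))  # dove
```

Moving the query to (8, 1) puts it right next to the two doves, so with k=3 they outvote the single eagle that makes up the rest of the neighborhood.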

When KNN is not so smart

KNN being labeled as a “basic” algorithm means that it comes with huge limitations to how much a computer can learn. There are three main problems with KNN:

  1. KNN’s prediction is limited to previously known data. In the bird example, the algorithm can only predict one of the four bird types in its knowledge. What if Bird X was actually an entirely new bird, like a pigeon? A pigeon that lands in the middle of a flock of eagles surely is not an eagle! KNN is often referred to as “lazy learning” because it does not build a model from the data ahead of time. It merely stores the known samples and answers with the closest thing in its knowledge.
  2. You need to be careful with the value of k. Deciding k is a challenge in itself. If the value of k is too low, the prediction rests on only a handful of neighbors, so it becomes more susceptible to noise in the data. If the value of k is too high, the algorithm takes in samples that are far away and less relevant, which can drown out the closest neighbors. Either extreme often causes poor prediction accuracy.

Consider a scenario where Bird X is actually an eagle, but its closest neighbor is a parrot. If k is set too low, like k=1, the algorithm will mistakenly predict that it’s a parrot because it failed to look further and see the eagles.

In contrast, if k is set too high, like k=8, the algorithm might get distracted by data that is further away, which effectively stretches the mean or mode. In such a scenario, it could predict that Bird X is a parrot simply because it found more parrots within the wider area.

  3. KNN becomes computationally expensive as the amount of data grows. Since the algorithm computes the distance from the new point to every known data point, each prediction gets slower as the data set gets larger. Imagine if we had to consider the distance of Bird X from all other birds in town!
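The sensitivity to k described in the second point can be seen with the five neighbors from the park example, where the single closest bird was a parrot. Here is a small sketch, assuming the neighbors are already sorted by distance; the helper name `vote` is illustrative.

```python
from collections import Counter

# Neighbors of Bird X from the park example, closest first.
ranked_neighbors = ['parrot', 'eagle', 'eagle', 'dove', 'eagle']

def vote(ranked, k):
    """Return the most common type among the first k neighbors."""
    return Counter(ranked[:k]).most_common(1)[0][0]

for k in (1, 3, 5):
    print(k, vote(ranked_neighbors, k))
# 1 parrot
# 3 eagle
# 5 eagle
```

With k=1 the lone parrot decides the outcome; from k=3 onward the eagles outvote it.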

This has been a soft introduction to the KNN algorithm.
