Understand Recall, Precision & ROC of your classifier, intuitively

If you work on Machine Learning projects, you might encounter these terminologies quite frequently.

In this post, I will explain them intuitively, and draw the ROC curve step by step in Python.

Accuracy

Accuracy = (TP + TN)/(TP + TN + FP + FN)

Accuracy alone is a bad measure for classification tasks. A predictive model may have high accuracy, but be useless.

For example, suppose the classifier is doing fraud detection and there are 100 samples: 99 of the 100 are benign transactions, and 1 is fraudulent.

Suppose the classifier is a one-liner that simply returns False for any input:

def isFraud(transaction):
    return False

You will still get 99% accuracy. This is called the Accuracy Paradox.
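To see the paradox in numbers, here is a quick check (a minimal sketch using the sample counts from the fraud example above):

# 99 benign transactions and 1 fraudulent one; the classifier predicts False for everything
y_true = [False] * 99 + [True]
y_pred = [False] * 100

tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)          # 0
tn = sum(1 for t, p in zip(y_true, y_pred) if not t and not p)  # 99
accuracy = float(tp + tn) / len(y_true)
print(accuracy)  # 0.99 -- yet the classifier never catches a single fraud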

Recall and Precision

Recall, formally defined as,

Recall = True positive / (True positive + False negative)

Or, to remember their meaning without thinking about true positive/false positive/false negative jargon, I conceptualize them as follows:

Say, if you are asked,

“Can you list the most recent 10 countries you have visited?”

So you recall.

And the recall rate is the ratio of the number of countries you can correctly recall to the number of all correct answers (the countries you actually visited).

If you can recall all 10 countries, you have a 1.0 recall rate (100%). If you can recall 7 countries correctly, you have a 0.7 recall rate.

However, you might be wrong in some answers.

For example, you answered 15 countries: 10 of them are correct and 5 are wrong. This means you still have a 1.0 recall rate, but you are not so precise.

This is where Precision comes into the picture.

precision = True positive / (True positive + False positive)

Precision is the ratio of the number of countries you correctly recalled to the number of all countries you recalled (the mix of right and wrong answers).

So it is not hard to imagine that when you increase the recall rate of your model, your precision rate will usually drop, and vice versa.
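Sticking with the countries example (15 answers, 10 of them correct, against the 10 countries actually visited), here is a minimal sketch of the two rates (the country names are made up for illustration):

visited  = set(["country%d" % i for i in range(10)])         # the 10 correct answers
answered = visited | set(["wrong%d" % i for i in range(5)])  # 10 right + 5 wrong guesses

true_positive = len(visited & answered)                  # 10
recall = float(true_positive) / len(visited)             # 10 / 10 = 1.0
precision = float(true_positive) / len(answered)         # 10 / 15 ~= 0.67
print("recall = %.2f, precision = %.2f" % (recall, precision))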

F1-score

We now know that a model could be good at recall, but that does not mean it is also good at precision.

So the F1-score is introduced to evaluate the model. It is the harmonic mean of precision and recall: F1 = 2 * precision * recall / (precision + recall).

The drawback of the F1-score is that a model with high recall and low precision and another model with low recall and high precision can end up with F1-scores that are very close to each other (or even identical), so the single number hides which kind of mistakes your model makes.
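For example (a minimal sketch), a model with precision 0.3 / recall 0.9 and a model with precision 0.9 / recall 0.3 behave very differently, yet score exactly the same F1:

def f1(precision, recall):
    # harmonic mean of precision and recall
    return 2.0 * precision * recall / (precision + recall)

print(f1(0.3, 0.9))  # 0.45
print(f1(0.9, 0.3))  # 0.45 -- identical; the F1-score alone cannot tell the two models apart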

ROC curve & AUC

Now you have some idea about the Accuracy, the Recall and the Precision of a model.

By using these rates – specifically the True Positive Rate (which is the recall) and the False Positive Rate – the ROC curve evaluates the performance of a model across all possible thresholds.

Let’s take a look at a concrete example where I draw ROC curves from scratch.

For example, suppose you have already trained a classifier. By feeding it 20 labelled samples, you get a list of scores from the classifier, like below,

# 0.9 means the classifier is 90% sure that this sample is True. 
y_score = [0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.3, 0.1]

And their labels are below, where 1 represents the False label and 2 represents the True label.

y_true =  [2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1]

Now you have all the ingredients ready to get the ROC curve. Behind the scenes, the roc_curve() function will sort the score list in descending order, as in the snippet below from sklearn's source code:

# sort scores and corresponding truth values
desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]

Then for each score, the function will treat the specific score as the threshold, and apply it to all samples.

In our example, if you take the score 0.6 as the threshold, then the first 4 samples (namely the samples scored 0.9, 0.8, 0.7, and 0.6) will be classified as Positive, and the remaining 16 samples (whose scores are less than 0.6) are classified as Negative. Thus, we can compute the TPR and FPR of the model at this threshold and draw one point on the plot.

By repeating this for each score, we get 20 candidate points. By connecting them, together with (0, 0) and (1, 1), we get the ROC curve.
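To make the thresholding step concrete, here is a minimal sketch (not sklearn's actual implementation) that turns one threshold into one (FPR, TPR) point. Note that sklearn's roc_curve() drops some intermediate thresholds by default, which is why its output below has fewer than 20 entries.

y_true  = [2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1]
y_score = [0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505,
           0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.3, 0.1]

def roc_point(threshold, y_true, y_score, pos_label=2):
    # everything scored >= threshold is predicted Positive
    tp = sum(1 for t, s in zip(y_true, y_score) if s >= threshold and t == pos_label)
    fp = sum(1 for t, s in zip(y_true, y_score) if s >= threshold and t != pos_label)
    p = sum(1 for t in y_true if t == pos_label)  # 10 real positives
    n = len(y_true) - p                           # 10 real negatives
    return float(fp) / n, float(tp) / p           # (FPR, TPR)

# threshold 0.6: the first 4 samples are predicted Positive -> TP = 3, FP = 1
print(roc_point(0.6, y_true, y_score))            # (0.1, 0.3)

# one point per score, from the most confident threshold downwards
points = [roc_point(s, y_true, y_score) for s in sorted(y_score, reverse=True)]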

Compute the ROC and AUC by sklearn in Python

# 1 represents False, 2 represents True
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

y_true =  [2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1]
y_score = [0.9, 0.8, 0.7, 0.6, 0.55, 0.54, 0.53, 0.52, 0.51, 0.505, 0.4, 0.39, 0.38, 0.37, 0.36, 0.35, 0.34, 0.33, 0.3, 0.1]

fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=2)
>>> fpr
array([ 0. ,  0. ,  0.1,  0.1,  0.3,  0.3,  0.4,  0.4,  0.5,  0.5,  0.8, 0.8,  0.9,  0.9,  1. ])
>>> tpr
array([ 0.1,  0.2,  0.2,  0.5,  0.5,  0.6,  0.6,  0.7,  0.7,  0.8,  0.8, 0.9,  0.9,  1. ,  1. ])
>>> roc_auc = auc(fpr, tpr)
>>> roc_auc
0.68000000000000005

Plot the ROC

plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()

And now you have got the ROC chart of this specific classifier. The AUC (Area Under the Curve) of this classifier is 0.68.

[Image: ROC curve for classifier #1]

Is the above classifier algorithm good? It is hard to draw a conclusion at this point, but in general, the greater the AUC, the better your classifier performs overall.

Let’s say that you have modified your classifier, for example, by adding or pruning some of your features.

Your new y_score looks like below,

y_score = [0.9, 0.7, 0.7, 0.5, 0.55, 0.44, 0.53, 0.42, 0.51, 0.405, 0.4, 0.29, 0.38, 0.27, 0.36, 0.25, 0.34, 0.23, 0.3, 0.001] 

By repeating the ROC algorithm described above, we will get a new ROC curve with an AUC improved to 0.74.
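If you rerun the earlier sklearn snippet with the new y_score (and the same y_true labels as before), the improved AUC falls straight out:

fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=2)
roc_auc = auc(fpr, tpr)  # comes out around 0.74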

In this case, we can say that, in general, your new algorithm outperformed the previous one.

[Image: ROC curve for classifier #2]

Which threshold to choose to put into production?

Usually, the Kolmogorov–Smirnov (KS) statistic is used: you pick the threshold at the point where the value of (TPR - FPR) is maximized.

The KS pick is un-biased between the two kinds of error. But most of the time, your classifier will have a preference.

For example, an email spam classifier would prefer a low False Positive Rate, because a higher FPR means your classifier is more likely to mistakenly filter out non-spam emails (i.e. important emails from your boss). In that case, you need to tailor your own decision function, for example by giving FPR more weight.

A malicious URL classifier would prefer higher recall (TPR), simply because you don’t want to let a bad URL pass – even though that higher recall usually comes with a higher FPR. It is a matter of trade-off.
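Here is a minimal sketch of both ideas, reusing the fpr, tpr and thresholds arrays that roc_curve() returns for the example above; the weight w is a made-up knob for illustration:

import numpy as np
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=2)

# the un-biased pick: maximize (TPR - FPR), i.e. the KS statistic
ks_idx = np.argmax(tpr - fpr)
print("KS threshold: %.3f (TPR = %.2f, FPR = %.2f)"
      % (thresholds[ks_idx], tpr[ks_idx], fpr[ks_idx]))

# an FPR-averse pick, e.g. for a spam filter: give false positives extra weight
w = 3.0  # hypothetical weight -- how much more a false positive hurts than a missed positive
averse_idx = np.argmax(tpr - w * fpr)
print("FPR-averse threshold: %.3f (TPR = %.2f, FPR = %.2f)"
      % (thresholds[averse_idx], tpr[averse_idx], fpr[averse_idx]))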

Bloom Filter -- a probabilistic data structure as an alternative to HashSet

As a software engineer, I trust the HashSet. If the HashSet tells me “I definitely do not contain this key”, I take its word for it – unless, well, my RAM gets bit flips from “Atmospheric Neutrons” (aka cosmic rays).

As an alternative to the HashSet, the Bloom Filter is a similar data structure that does the same membership-testing job, but in a much more space-efficient manner.

Motivation – Malicious Website Blocking Problem

Let’s say we use crawlers to scan websites’ content to ensure it is not compromised by malware. When a certain URL is believed to be compromised, we add the URL to a blacklist – much like the Google Chrome browser does.

[Image: Chrome browser using a bloom filter to block malicious websites]

We can implement this blacklist using a hashtable. This table could live either on the client side (i.e. the Chrome browser) or on the server side (e.g. an online URL-shortening service provider). This approach is conceptually easy to understand, simple to implement in a few lines of code, and gives O(1) time complexity for URL look-ups.

But it comes with the following drawbacks,

  • Not memory efficient – a hashset containing the Alexa top 1 million domains takes roughly 300MB in memory. For a URL-shortening provider like bitly, the blacklist could take about 2.9GB in memory (according to its blog post).
  • We don’t want to expose the raw data to the client side – the blacklist itself is a treasure-trove asset, and you don’t want to hand it to hackers.

With the Bloom Filter, the 300MB hashtable can be reduced to 1.71MB at a 0.001 false positive rate, as calculated by a bloom filter calculator.
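Those numbers come straight from the standard bloom filter sizing formulas; here is a quick back-of-the-envelope check (n is the number of items, p the target false positive rate):

import math

def bloom_parameters(n, p):
    # optimal number of bits: m = -n * ln(p) / (ln 2)^2
    m = -n * math.log(p) / (math.log(2) ** 2)
    # optimal number of hash functions: k = (m / n) * ln 2
    k = (m / n) * math.log(2)
    return int(math.ceil(m)), int(round(k))

m, k = bloom_parameters(1000000, 0.001)
print("%d bits, %d hash functions" % (m, k))  # 14377588 bits, 10 hash functions
print("%.2f MB" % (m / 8.0 / (1024 * 1024)))  # 1.71 MB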

And because, under the hood, the bloom filter is just a bit array of 1s and 0s, it does not expose the actual contents of the blacklist the way a raw hashtable would.

Introducing Bloom Filter

A Bloom filter is a space-efficient probabilistic data structure, conceived by Burton Howard Bloom in 1970, that is used to test whether an element is a member of a set. False positive matches are possible, but false negatives are not – in other words, a query returns either “possibly in the set” or “definitely not in the set”. Elements can be added to the set, but not removed (though this can be addressed with a “counting” filter); the more elements that are added to the set, the larger the probability of false positives.

Quote from Bloom Filter, Wikipedia

A Bloom Filter is a bit array of N bits, where N is the size of the bit array. It has another parameter, k, the number of hash functions. These hash functions are used to set bits in the bit array. When inserting an element x into the filter, the bits at the k indices h1(x), h2(x), …, hk(x) are set, where the bit positions are determined by the hash functions. Note that adding more hash functions initially drives the false positive rate down, but each extra hash function also makes inserts and lookups slower and fills up the bit array more quickly, so there is an optimal k for a given array size.

In order to check membership in the Bloom Filter, we check whether all of those k bits are set, very much like how we insert an item into the filter. If all of the bits are set, the item is probably in the Bloom Filter; if any of the bits is not set, the item is definitely not in the Bloom Filter.

Check this link for a detailed tutorial on bloom filters.

An example bloom filter containing 1-million urls

For a bloom filter containing 1 million URLs, with a 0.001 false positive rate:

  • If the bloom filter tells you “it is definitely NOT in the bloom filter”, it is 100% certain that the item is not in the set (because of the no-false-negative guarantee).
  • If the bloom filter tells you “it might be in the bloom filter”, there is a false positive rate of 0.001.

What are the false positives?

A false positive error, or in short a false positive, commonly called a “false alarm”, is a result that indicates a given condition exists, when it does not.

For example, if the bloom filter says “www.facebook.com” is NOT in the blacklist (obviously, it should not be), you can give this URL a green light and let it pass. If the bloom filter says “www.youtube.com” might be in the blacklist, you take this with a grain of salt – bearing in mind that there is a 0.1% chance of a false positive. So you need to send the URL “www.youtube.com” for further examination, for example a DB check, or even a more expensive realtime malware scan.

If your further check confirms that “www.youtube.com” is NOT a malicious website, it was a “false positive”, or “false alarm”. If, somehow, “www.youtube.com” is marked as a malicious website in your DB, it is a “true positive”.

Does false positive matter?

In this case, the false positives are acceptable, as long as we keep the rate low by allocating more space to the bloom filter.

Indeed, to handle the occasional false positives, we do spend extra time and resources on checks in the backend.

It is a trade-off decision to make. More importantly, not a single malicious website slips through our hands, because of the “no false negative” guarantee.

Implementing the bloom filter above

Setting aside thread safety, entry removal, and scalability needs, a bloom filter can be implemented in a few lines of code.

Below is a Python 2.7 implementation of a bloom filter containing 1 million URLs, with the false positive rate set to 0.001. It uses mmh3 (MurmurHash3) as the hash function; by feeding mmh3 a different seed each time, every seeded call can be treated as a distinct hash function.

import mmh3
import sys
import time
from bitarray import bitarray
import csv

class BloomFilter(set):

    def __init__(self, size, hash_count):
        super(BloomFilter, self).__init__()
        self.bit_array = bitarray(size)
        self.bit_array.setall(0)
        self.size = size
        self.hash_count = hash_count

    def __len__(self):
        return self.size

    def __iter__(self):
        return iter(self.bit_array)

    def add(self, item):
        for ii in range(self.hash_count):
            index = mmh3.hash(item, ii) % self.size
            self.bit_array[index] = 1

        return self

    def __contains__(self, item):
        # the item is possibly present only if every probed bit is set
        for ii in range(self.hash_count):
            index = mmh3.hash(item, ii) % self.size
            if self.bit_array[index] == 0:
                return False
        return True

class HashSetLookup(object):
    def __init__(self):
        self._hs = set()

    def add(self, word):
        self._hs.add(word)

    def check(self, word):
        if word in self._hs:
            return True
        else:
            return False

bf = BloomFilter(14377588, 10) #parameters generated from the bloom filter calculator
hs = HashSetLookup()

count = 0
for row in csv.reader(open("top1m.csv")):
    bf.add(row[1])
    hs.add(row[1])
    count += 1
print "{} items added!".format(count)

print "Size of the Bloom Filter in MB: ", 14377588*0.6/100000
print "Lookup using Bloom Filter..."
start = time.time()
not_seen_count = 0
for row in csv.reader(open("2m_incoming_traffic.csv")):
    #print "look up:", row[1]                                                                                                     
    if row[1] not in bf:
        not_seen_count += 1
    #    if ".dummy" not in row[1]:                                                                                               
    #        print row[1]                     

print "False Positive count: ", not_seen_count

end = time.time()
print "time elapsed for 2m incoming url traffic: ", (end - start), "s"
print "not seen url count: ", not_seen_count
print "false positive count: ", 999200 - not_seen_count

print "Size of the hashset in MB: ", sys.getsizeof(hs._hs)/100000
print "Lookup using hashSet..."
start = time.time()
not_seen_count = 0
for row in csv.reader(open("2m_incoming_traffic.csv")):
    #print "look up:", row[1]                                                                                                     
    if not hs.check(row[1]):
        not_seen_count += 1
end = time.time()
print "time elapsed for 2m incoming url traffic: ", (end - start), "s"
print "not seem url count: ", not_seen_count

Output,

mbp:python-bloomfilter weihan$ python bloom_filter.py
999200 items added!

##########lookup 2-million url via bloom filter#########
Size of the Bloom Filter in MB:  86.265528
Lookup using Bloom Filter...
time elapsed for 2m incoming url traffic:  12.2198560238 s
not seen url count:  998209
false positive count:  991

##########lookup 2-million url via hashset#########
Size of the hashset in MB:  335
Lookup using hashSet...
time elapsed for 2m incoming url traffic:  1.70107579231 s
not seen url count:  999200

From the output above, we can observe three facts,

  • There are 991 false positive cases reported.
  • The bloom filter is about 86MB in size, while the hashset with the same contents costs 335MB in RAM.
  • The bloom filter takes more time per lookup – the hashset computes only one hash per item, while the bloom filter computes multiple hash functions per item.

Full-fledged implementations in different languages

In production, for various reasons, you might want to write your own bloom filter implementation to fully fit your needs – for example, speed requirements, distributed deployment, or removal requirements.

There are open source implementations for every language, but the following implementations are pretty good in my experience:

Note: you can run the Python code under PyPy if speed is crucial in your application.

The two flashlights paradox and taking advantage of Python's GC when doing data stream caching

When you switch from a language with manual memory management, such as C or C++, to a garbage-collected language, your job as a programmer is made much easier by the fact that your objects are automatically reclaimed when you’re through with them. It seems almost like magic when you first experience it. It can easily lead to the impression that you don’t have to think about memory management, but this isn’t quite true.

– quote from Effective Java by Joshua Bloch.

Ignoring memory management in Python will not often lead you into situations where things go very badly wrong – but it still can. For example, if you keep passing observer (callback) methods to observable objects without explicitly destroying the references (by nulling them out) once the callbacks are no longer needed, a memory leak can build up over time.

Keeping memory management in mind will give you an edge when dealing with “bigger data” tasks.

In this post, I am going to show you how to take the advantage of Garbage-Collection (GC) in Python.

Stream Data Caching Problem

Sometimes, when stream data arrives, we want to store it and preserve its natural order – for example, keep the records in order of their timestamps.

Putting them into a LinkedList as data nodes is the common and natural implementation.

But once we put them into the list, every time we need to access or retrieve one of the stored data nodes, it takes O(n) time to walk to it.

Usually, people build an auxiliary HashMap that records a reference to every data node in the LinkedList. This way, we gain O(1) time complexity to search for and access any node in the LinkedList.

This benefit comes with a cost. As the hashmap grows, the collision rate increases, and the hashmap’s performance eventually degenerates towards that of a LinkedList (or a BST, in Java 8). As the chart below from Ralf Sternberg’s blog shows, the probability of a hashCode collision reaches about 50% when there are 100,000 distinct objects in the hashmap.

[Image: hashCode collision probability, from Ralf Sternberg’s blog]

So we sometimes have to reduce the size of the in-memory cache by, say, shrinking the linkedlist to keep only the most recent one hour of data.

For the linkedlist, shrinking can be done in O(1) time by resetting the HEAD to point at a newer data node. The discarded data nodes will then be garbage-collected automatically.

But since we also cache these nodes in a hashmap, the data nodes discarded by the linkedlist are still referenced by the cache, which prevents the GC from collecting them. So we need to loop over the discarded nodes and remove each one from the hashmap, which takes O(n) time, where n is the number of discarded data nodes. This suddenly becomes the performance bottleneck – not only because of the extra O(n) cost, but more importantly because shrinking the hashmap blocks the handling of newly arrived data nodes.
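Here is a minimal, self-contained sketch of that manual cleanup (the Node class and the names are made up for illustration; the actual demo code appears further below):

class Node(object):
    def __init__(self, name):
        self.name = name
        self.nxt = None

# build HEAD -> one -> two -> three -> four, and cache every node in a plain dict
head = Node("HEAD")
cache = {}
cur = head
for name in ["one", "two", "three", "four"]:
    node = Node(name)
    cache[name] = node
    cur.nxt = node
    cur = node

# shrinking the list itself is O(1): just reset the head's next pointer
keep = cache["four"]
dropped = head.nxt
head.nxt = keep

# ...but with a plain dict we must also evict the dropped nodes by hand: O(n)
node = dropped
while node is not keep:
    del cache[node.name]
    node = node.nxt

print(sorted(cache.keys()))  # ['four']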

Can we do better?

The Two Flashlights Paradox

Okay, I confess that I coined the phrase Two Flashlights Paradox. It comes from a scene in Stephen Chow’s movie From Beijing with Love (1994).

In the movie, the rocket scientist Tat Man-sai invents a special flashlight that never lights up by itself – it only lights up when it senses the beam of another flashlight.

[Image: It won’t light] [Image: Only until another flashlight lights]

It was a joke and a completely useless invention in the movie. But hey, isn’t this exactly what we need to solve our streaming data caching problem?

– We want to keep data in the hashmap (the joke flashlight) only while the data is still referenced by the outside linkedlist (the other flashlight).

Introducing the weakref.WeakValueDictionary class in Python.

As its name suggests, the WeakValueDictionary holds its values through “weak refs” instead of the strong refs (the normal refs) that a normal dictionary uses. Like the joke flashlight in the movie, the WeakValueDictionary will keep a “key, value” pair only as long as the value is also referenced by a strong ref outside of the WeakValueDictionary.

Once the outside strong ref is gone, both the value and its key are removed from the WeakValueDictionary – immediately, in CPython, thanks to reference counting.
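Here is a minimal illustration of that behaviour (assuming CPython, where reference counting reclaims the object as soon as the last strong reference disappears):

import weakref

class Data(object):
    pass

cache = weakref.WeakValueDictionary()
obj = Data()           # 'obj' is the outside strong reference
cache["key"] = obj

print("key" in cache)  # True  -- the outside strong ref is still alive
del obj                # drop the last strong reference
print("key" in cache)  # False -- the entry vanished together with the value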

Let’s do the following experiment to see how it works,

We will call the demo code below twice,

In the first call, we feed it a normal Python dict() to do the caching.

In the second call, we feed it a weakref.WeakValueDictionary() to do the caching.

import gc
from pprint import pprint
import weakref

gc.set_debug(gc.DEBUG_LEAK)

class ExpensiveObject(object):
    def __init__(self, name):
        self.name = name
        self.nxt = None
    def __repr__(self):
        return "ExpensiveObject(%s)" % self.name
    def __del__(self):
        print "(Deleting %s)" % self
        
def demo(cache_factory):

    dummyHead = ExpensiveObject("HEAD")
    cur = dummyHead
    print "CACHE TYPE: ", cache_factory
    cache = cache_factory() # the cache is initiated using cache_factory given
    for name in [ "one", "two", "three", "four" ]:
        o = ExpensiveObject(name)
        cache[name] = o
        cur.nxt = o
        cur = o
        del o # de-ref the o

    print "Before Shrinking, LinkedList contains:"
    p = dummyHead
    stringBuilder = []
    stringBuilder.append(p.name)
    p = p.nxt
    while (p):
        stringBuilder.append("->" + p.name)
        p = p.nxt
    print "".join(stringBuilder)

    print "Before Shrinking, cache contains:", cache.keys()
    for name, value in cache.items():
        print "  %s = %s" % (name, value)
        del value # decref

    # Shrink the LL
    print "Shrinking LL by reseting the head:"
    cur = dummyHead
    cur.nxt = cache["four"]
    # no explicit gc.collect() needed: CPython's reference counting reclaims unreachable nodes right away

    print "After Shrinking, LinkedList contains:"
    p = dummyHead
    stringBuilder = []
    stringBuilder.append(p.name)
    p = p.nxt
    while (p):
        stringBuilder.append("->" + p.name)
        p = p.nxt
    print "".join(stringBuilder)


    print "After Shrinking, cache contains:", cache.keys()
    for name, value in cache.items():
        print "  %s = %s" % (name, value)

    print "Demo done. \nCleaning the following objects from the method stack."

print "\n##########Cachine Data Using Dict()############"
demo(dict)
print "\n##########Cachine Data Using WeakValueDictionary()############"
demo(weakref.WeakValueDictionary)

Let’s check the output,

First for caching using Python’s Dict().

##########Caching Data Using Dict()############
CACHE TYPE:  <type 'dict'>
Before Shrinking, LinkedList contains:
HEAD->one->two->three->four
Before Shrinking, cache contains: ['four', 'three', 'two', 'one']

Shrinking LL by resetting the head:

After Shrinking, LinkedList contains:
HEAD->four
After Shrinking, cache contains: ['four', 'three', 'two', 'one']

Demo done. 
Cleaning the following objects from the method stack.
(Deleting ExpensiveObject(HEAD))
(Deleting ExpensiveObject(one))
(Deleting ExpensiveObject(two))
(Deleting ExpensiveObject(three))
(Deleting ExpensiveObject(four))

From the output above, we can see that even though the linkedlist is cut short (the “one”, “two”, “three” nodes are removed from the LL), the discarded data nodes are still alive in memory, not garbage-collected by Python’s GC. That is because they are still referenced as values in the cache. In this case, we would have to remove them from the cache manually.

Now let’s check the output for caching using Python’s WeakValueDictionary()

##########Caching Data Using WeakValueDictionary()############
CACHE TYPE:  weakref.WeakValueDictionary
Before Shrinking, LinkedList contains:
HEAD->one->two->three->four
Before Shrinking, cache contains: ['four', 'three', 'two', 'one']

Shrinking LL by resetting the head:
(Deleting ExpensiveObject(one))
(Deleting ExpensiveObject(two))
(Deleting ExpensiveObject(three))

After Shrinking, LinkedList contains:
HEAD->four
After Shrinking, cache contains: ['four']

Demo done. 
Cleaning the following objects from the method stack.
(Deleting ExpensiveObject(HEAD))
(Deleting ExpensiveObject(four))

From the output above, we can see the linkedlist is cut short, and the discarded data nodes are immediately garbage-collected by Python’s GC, because the weakref.WeakValueDictionary() does not keep them once the outside strong reference is gone.

This way we do not need to shrink the cache manually, and the time complexity of cache shrinking drops to O(1).

A concrete toy project example

In the link below, I solved an interesting logger system design problem using both approaches above. Check it out here.