Supervised Learning — When is a model “good enough”?

Supervised learning is about learning a mapping from input data to output data. A common application is classification, where the task is to predict the class of a given data point.

In the training phase, a supervised learning algorithm uses known examples to learn a pattern by which the correct labels (e.g. classes) can be assigned to the data. This pattern is stored in a model.

To assess the quality of the model, additional data (test data) is used. The classes predicted for the test data are compared with the actual classes. To evaluate how well the model performs, we use metrics, and different use cases require different metrics.

In this blog post we will look at three popular metrics (accuracy, recall and precision) using three examples. For simplicity, we limit ourselves to models that predict one of two classes (binary classification problems).

The results of the model on the test data are represented as a confusion matrix. In a confusion matrix, the classes predicted by the model (Predicted, cf. figure) are placed column-wise and the actual classes (Actual, cf. figure) are placed row-wise. To learn more about the different metrics, we focus on the following three examples:

  • Example 1: A model has been trained to recognize whether a picture shows a cat or a dog. Of the 48 cat images, the model correctly predicts 41 as cats but misclassifies 7 as dogs. Of the 52 dog images, 49 are correctly predicted as dogs, while 3 are wrongly classified as cats.
Confusion matrix for cat detection
  • Example 2: In the second example, a model was trained to predict whether a given person is sick or healthy. Out of a total of 10 sick persons, 2 were correctly predicted as sick, 8 wrongly as healthy. All 990 healthy persons were correctly classified as healthy.
Confusion matrix for disease detection
  • Example 3: The third example is a spam filter that classifies mails as “spam” or “no spam”. 8 spam mails are correctly classified as spam. On the other hand, 5 regular mails are wrongly classified as spam. 2 spam mails even slip through the filter and are classified as “no spam”. Finally, 21 regular mails are correctly classified as “no spam”.
Confusion matrix for spam filter

Additional note: Metrics and confusion matrices often use terms like “true positive” or “false negative”. These terms refer to the four colored fields in the confusion matrices above. In binary classification problems, it is often easier to speak of “yes” (e.g. “spam”) and “no” (e.g. “no spam”) instead of naming the two classes. This directly results in the following four terms:

  • True positive (TP): Cases where the model predicted yes and the actual class was also yes.
  • True negative (TN): Cases where the model predicted no and the actual class was also no.
  • False positive (FP): Cases where the model predicted yes and the actual class was no.
  • False negative (FN): Cases where the model predicted no and the actual class was yes.
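The four outcomes can be counted directly by comparing actual and predicted labels pair by pair. The following Python sketch assumes string labels and an explicitly chosen positive (“yes”) class; the function name `count_outcomes` and the label lists that rebuild Example 3 (in an arbitrary, hypothetical order) are illustrative and not taken from the original post.

```python
def count_outcomes(actual, predicted, positive):
    """Tally TP, TN, FP and FN for a binary classification problem.

    `positive` is the label treated as "yes"; any other label counts as "no".
    """
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    tp = tn = fp = fn = 0
    for a, p in zip(actual, predicted):
        if p == positive and a == positive:
            tp += 1  # predicted yes, actually yes
        elif p != positive and a != positive:
            tn += 1  # predicted no, actually no
        elif p == positive:
            fp += 1  # predicted yes, actually no
        else:
            fn += 1  # predicted no, actually yes
    return {"TP": tp, "TN": tn, "FP": fp, "FN": fn}


# Example 3 rebuilt as label lists (hypothetical order; only the counts matter).
actual = ["spam"] * 8 + ["no spam"] * 5 + ["spam"] * 2 + ["no spam"] * 21
predicted = ["spam"] * 8 + ["spam"] * 5 + ["no spam"] * 2 + ["no spam"] * 21

outcomes = count_outcomes(actual, predicted, positive="spam")
print(outcomes)  # {'TP': 8, 'TN': 21, 'FP': 5, 'FN': 2}
```

Which class counts as positive is a choice: swapping `positive` to "no spam" would exchange TP with TN and FP with FN.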

Accuracy

The accuracy indicates how many of all cases were correctly classified, i.e.:

$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$

This allows us to calculate the accuracy for the examples above:

  • Example 1: The number of correctly predicted cats (41) and correctly predicted dogs (49) is set in relation to the total number of images (100): $\frac{41 + 49}{100} = 0.90$, i.e. 90%.
  • Example 2: The number of correctly predicted sick persons (2) and correctly predicted healthy persons (990) is set in relation to the total number of examined persons (1000): $\frac{2 + 990}{1000} = 0.992$, i.e. 99.2%.
  • Example 3: The number of mails correctly classified as spam (8) and mails correctly classified as regular (21) is set in relation to the total number of mails examined (36): $\frac{8 + 21}{36} \approx 0.806$, i.e. about 81%.
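The accuracy values for the three examples can be reproduced with a few lines of Python. This is a minimal sketch, assuming the counts are stored in a dictionary keyed by TP, TN, FP and FN, with cat, sick and spam as the positive classes; the dictionary layout and the function name `accuracy` are illustrative choices, not part of the original post.

```python
# Counts read off the three examples (positive class: cat, sick, spam).
EXAMPLES = {
    "cats vs. dogs": {"TP": 41, "TN": 49, "FP": 3, "FN": 7},
    "disease": {"TP": 2, "TN": 990, "FP": 0, "FN": 8},
    "spam filter": {"TP": 8, "TN": 21, "FP": 5, "FN": 2},
}


def accuracy(c):
    """Share of all cases that were classified correctly."""
    return (c["TP"] + c["TN"]) / (c["TP"] + c["TN"] + c["FP"] + c["FN"])


for name, counts in EXAMPLES.items():
    print(f"{name}: {accuracy(counts):.1%}")

# Expected output:
# cats vs. dogs: 90.0%
# disease: 99.2%
# spam filter: 80.6%
```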

Accuracy is the most general metric. For recognizing cats, 90% accuracy is quite good. Our spam classifier is about 81% accurate. Our model for recognizing sick persons, however, delivers an astounding accuracy of 99.2%. Sounds quite promising at first … but wait! Let’s take another look at the exact numbers. Only 10 of all 1000 persons were sick, and our model failed to detect 8 of them. Nevertheless, it has a higher accuracy than our cat detector. So accuracy does not tell us the whole truth, especially when the classes (as in the case of sick and healthy) are unequally distributed. Therefore, let’s look at two other metrics that cope better with this problem.
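To make the problem concrete: a hypothetical model that ignores its input and always predicts “healthy” never detects a single sick person, yet it still reaches 99% accuracy on the same 1000 persons. The sketch below uses the class distribution from Example 2; the “always healthy” baseline is an illustration added here, not a model from the original post.

```python
# Actual classes from Example 2: 10 sick persons, 990 healthy persons.
actual = ["sick"] * 10 + ["healthy"] * 990

# Hypothetical baseline that predicts "healthy" for everyone.
baseline = ["healthy"] * len(actual)

# Count correct predictions and correctly detected sick persons.
correct = sum(a == p for a, p in zip(actual, baseline))
detected_sick = sum(a == p == "sick" for a, p in zip(actual, baseline))

print(f"accuracy: {correct / len(actual):.1%}")  # accuracy: 99.0%
print(f"sick persons detected: {detected_sick}")  # sick persons detected: 0
```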

Recall (or hit rate)

Recall (also called hit rate or sensitivity) measures how many of the actual positive cases were correctly predicted as positive:

$$\text{Recall} = \frac{TP}{TP + FN}$$

We now apply this formula to determine the recall (hit rate) for our examples:

  • Example 1: The number of correctly predicted cats (41) is related to the number of cats in the entire data set (48): $\frac{41}{48} \approx 0.854$, i.e. about 85%.
  • Example 2: The number of correctly predicted sick persons (2) is divided by the number of sick persons (10): $\frac{2}{10} = 0.20$, i.e. 20%.
  • Example 3: The number of correctly classified spam mails (8) is set in relation to the number of all spam mails (10): $\frac{8}{10} = 0.80$. With our spam filter, 80% of the spam mails are “hit”, 20% remain undetected.
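The recall values above can be checked the same way. This minimal sketch assumes the same illustrative count dictionary (positive classes: cat, sick, spam); the function name `recall` is a hypothetical helper, not from the original post.

```python
# Counts read off the three examples (positive class: cat, sick, spam).
EXAMPLES = {
    "cats vs. dogs": {"TP": 41, "TN": 49, "FP": 3, "FN": 7},
    "disease": {"TP": 2, "TN": 990, "FP": 0, "FN": 8},
    "spam filter": {"TP": 8, "TN": 21, "FP": 5, "FN": 2},
}


def recall(c):
    """Share of actual positives the model found: TP / (TP + FN)."""
    return c["TP"] / (c["TP"] + c["FN"])


for name, counts in EXAMPLES.items():
    print(f"{name}: {recall(counts):.1%}")

# Expected output:
# cats vs. dogs: 85.4%
# disease: 20.0%
# spam filter: 80.0%
```

Note that TN and FP do not enter the formula at all, which is why the 990 correctly classified healthy persons no longer mask the 8 missed sick persons.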

If we look at recall as a metric, we can see how poorly our disease detection model works. Only 20% of all sick people were correctly predicted. For a dangerous and highly infectious disease, this is unacceptable! Our spam filter, on the other hand, seems to work satisfactorily with a recall of 80%. But wait! Five mails were wrongly classified as spam and never landed in our inbox. What if one of these five mails is the important confirmation for a new job or an important message from an old school friend? Let’s look at another metric for such cases.

Precision

Precision indicates how precise the model’s positive predictions are, i.e. how many of the cases predicted as positive actually are positive (for example, how many of the mails classified as spam really are spam):

$$\text{Precision} = \frac{TP}{TP + FP}$$

Thus, unlike recall (hit rate), precision does not measure how many of the data points that actually belong to the positive class are “hit”, but how often a positive prediction is correct.

  • Example 1: The number of cats correctly classified (41) is related to the total number of images predicted by the model to be cats (44): $\frac{41}{44} \approx 0.932$, i.e. about 93%.
  • Example 2: The number of persons correctly classified as sick (2) is related to the total number of persons classified as sick (2): $\frac{2}{2} = 1.00$, i.e. 100%.
  • Example 3: The number of mails correctly classified as spam (8) is set in relation to the number of mails (13) for which spam was predicted: $\frac{8}{13} \approx 0.615$. If our model classifies an e-mail as spam, it is correct in about 62% of the cases. However, it is not very “precise”, as 38% of these predictions are wrong.
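A matching sketch for precision, again assuming the illustrative count dictionary (positive classes: cat, sick, spam). The helper name `precision` is hypothetical; raising an error when the model makes no positive predictions is a design choice made here, since the formula divides by zero in that case.

```python
# Counts read off the three examples (positive class: cat, sick, spam).
EXAMPLES = {
    "cats vs. dogs": {"TP": 41, "TN": 49, "FP": 3, "FN": 7},
    "disease": {"TP": 2, "TN": 990, "FP": 0, "FN": 8},
    "spam filter": {"TP": 8, "TN": 21, "FP": 5, "FN": 2},
}


def precision(c):
    """Share of positive predictions that are correct: TP / (TP + FP)."""
    predicted_positive = c["TP"] + c["FP"]
    if predicted_positive == 0:
        # The model never predicted "yes", so precision is undefined.
        raise ValueError("precision is undefined without positive predictions")
    return c["TP"] / predicted_positive


for name, counts in EXAMPLES.items():
    print(f"{name}: {precision(counts):.1%}")

# Expected output:
# cats vs. dogs: 93.2%
# disease: 100.0%
# spam filter: 61.5%
```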

With precision we have found a meaningful metric for our spam filter: It is not very “precise” and 38% of all spam classifications are wrong. So we should improve the model. For our disease detection, we got a precision of 100%. However, we already know that our model classified only 2 sick persons correctly, but did not detect 8 sick persons. Therefore, precision is not an appropriate metric in the case of our disease detection.

Overall, metrics are a very helpful tool for assessing the quality of a model in supervised learning. Depending on the task, one metric might be more suitable than another. Accuracy is useful when the classes are roughly equally distributed. Recall is useful when missing a positive case is costly (e.g. failing to detect sick people). Precision is useful when incorrectly classifying data points as positive is costly (e.g. when a mail that is important to us is classified as spam). In addition, there are many other metrics, such as the F1-score, which combines precision and recall.
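The F1-score is commonly defined as the harmonic mean of precision and recall, $F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$, which simplifies to $\frac{2\,TP}{2\,TP + FP + FN}$. The following minimal sketch applies that standard definition to the three examples, assuming the same illustrative count dictionary as before; the helper name `f1` is hypothetical.

```python
# Counts read off the three examples (positive class: cat, sick, spam).
EXAMPLES = {
    "cats vs. dogs": {"TP": 41, "TN": 49, "FP": 3, "FN": 7},
    "disease": {"TP": 2, "TN": 990, "FP": 0, "FN": 8},
    "spam filter": {"TP": 8, "TN": 21, "FP": 5, "FN": 2},
}


def f1(c):
    """Harmonic mean of precision and recall, written as 2TP / (2TP + FP + FN)."""
    return 2 * c["TP"] / (2 * c["TP"] + c["FP"] + c["FN"])


for name, counts in EXAMPLES.items():
    print(f"{name}: {f1(counts):.1%}")

# Expected output:
# cats vs. dogs: 89.1%
# disease: 33.3%
# spam filter: 69.6%
```

Because the harmonic mean is pulled toward the smaller of the two values, the disease model's 100% precision cannot hide its 20% recall: its F1-score is only about 33%.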

However, metrics are always only a proxy for what we actually want to achieve with our model. If we rely solely on metrics, other important goals might be neglected, such as avoiding or reducing negative consequences for certain groups of people (see also Thomas & Uminsky, 2020).

Researcher | Developer | Speaker | Educator