---
title: Painless Explainability for NLP/Text Models with LIME and ELI5
type: post
description: An introduction to LIME ML model explainability in the context of NLP usage and how to use the ELI5 library - a painless way to use LIME local explainability for almost any model.
resources:
- name: feature
  src: images/scrabble.jpg
date: 2022-01-13T07:47:11+00:00
url: /2022/01/13/painless-explainability-for-text-models-with-eli5
tags:
- machine-learning
- work
- explainability
---

# Contents

- [Contents](#contents)
- [Introduction](#introduction)
- [Understanding LIME](#understanding-lime)
  - [Local](#local)
  - [Interpretable](#interpretable)
  - [Model-Agnostic](#model-agnostic)
  - [Explanation](#explanation)
- [Usage Examples](#usage-examples)
  - [Requirements and Setup](#requirements-and-setup)
  - [ELI5 and Sci-kit Learn](#eli5-and-sci-kit-learn)
    - [Why SVM and LSA?](#why-svm-and-lsa)
    - [Training the Model](#training-the-model)
    - [Getting Some Predictions](#getting-some-predictions)
    - [Getting an Explanation](#getting-an-explanation)
  - [ELI5 and Transformers/Huggingface](#eli5-and-transformershuggingface)
    - [Why Transformers?](#why-transformers)
    - [Loading The Model](#loading-the-model)
    - [Defining the Interface with ELI5](#defining-the-interface-with-eli5)
    - [Getting an Explanation](#getting-an-explanation-1)
  - [ELI5 and a Remotely Hosted Model / API](#eli5-and-a-remotely-hosted-model--api)

# Introduction

Explainability of machine learning models is a hot topic right now - particularly in deep learning, where models are that bit harder to reason about and understand. These models are often called 'black boxes' because you put something in, you get something out, and you don't really know how that outcome was achieved.

The ability to explain a machine learning model's decisions in terms of its input features is useful from a debugging standpoint (identifying features with weird weights), and with legislation like [GDPR's Right to an Explanation](https://www.privacy-regulation.eu/en/r71.htm) it is becoming important in a commercial setting to be able to explain why models behave a certain way.

In this post I will give a simplified overview of how LIME works (I may take some small technical liberties and manufacture some contrived examples to demonstrate some of these mechanisms and phenomena - apologies). Then I'll briefly show how LIME can be applied, via ELI5, to a scikit-learn SVM-based newsgroup topic classifier and then to a Huggingface/PyTorch sentiment model.

{{
}}

# Understanding LIME

LIME stands for **L**ocal, **I**nterpretable **M**odel-agnostic **E**xplanations and is a technique proposed by [Ribeiro et al.](https://arxiv.org/abs/1602.04938) in 2016. The basic premise is that for a given input example (for an image classifier that means one image; for a text classifier, one unit of text such as a paragraph or a sentence; for a numerical model trained on tabular data, one row from that table), LIME can approximate how much of an effect each of the features extracted from the input has on the final output. For example: how important is a cluster of pixels in an image? How important are specific words/phrases in a sentence? How important is each column in that row of numbers? For a given example, both contributing and negating features are highlighted (reasons for and against that decision).

{{
}}

## Local

The local aspect of LIME is described in [the paper](https://arxiv.org/abs/1602.04938):

> ...Although it is often impossible for an explanation to be completely faithful unless it is the complete description of the model itself, for an explanation to be meaningful it must at least be locally faithful, i.e. it must correspond to how the model behaves in the vicinity of the instance being predicted...

This is a really important constraint of LIME: it offers excellent example-specific explanations that work well for pockets of similar data points, but these explanations can't necessarily be generalised to the whole of the model under examination. The authors of the paper also attempt to illustrate this limitation in a diagram:

{{
}}

This is especially important in tasks that are highly context dependent (like text classification). Here's a contrived example from a spam detection use case. Take the words "7 million usd", as in:

> Sir,
>
> I am a wealthy widow and if you help me I will pay you 7 million usd
>
> Best Regards

and also:

> Kevin,
>
> the new term sheet from the investors is in, they're offering 7 million usd for 5% equity,
>
> Brian Smith
> Head of Mergers & Acquisitions

In the first example, the words "7 million usd" contribute to the suspicion that this is a scam in the presence of "wealthy widow" and "help me". In the second example the words "7 million usd" aren't as important: they're words that you'd probably expect in a legitimate email about an investment opportunity from your colleague in Mergers.

The point I'm trying to make is that it's very difficult to come up with good general rules about which words are important without any context (and indeed, if you can do that, then you probably don't need machine learning; you can just build a rule-based system that checks for the presence or absence of words on a list). The overall decision function of "spam or not spam" is much more complicated than "these words are good and these words are bad", but for a certain set of "spammy" examples we can certainly say which words are more spammy and which words are less spammy. The same idea is at play in LIME.

Therefore, when we're using LIME, we should avoid saying things like "The model seems to consider the words 'million' and 'usd' spammy" and instead say things like "in cases similar to the widow email, it looks like the words 'million' and 'usd' contributed to the decision that this email was spam in the absence of any other redeeming words".
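
To make the "local" idea concrete, here is a deliberately simplified, NumPy-only sketch of the LIME recipe applied to the widow email. Everything in it is an assumption for illustration: `toy_spam_model` is a hypothetical keyword-based "black box" in which 'million' and 'usd' only count as spammy when 'widow' is also present, and `lime_text_weights` is my own minimal version of the idea (random word masking, an exponential proximity kernel and a weighted ridge regression). It is not ELI5's or the reference LIME library's code, which differ in details such as how samples are drawn and which local model is fitted.

```python
import numpy as np


def toy_spam_model(texts):
    """Hypothetical black box returning [P(not spam), P(spam)] per text.

    'million' and 'usd' only push towards spam when 'widow' is also present,
    mimicking the context dependence described above.
    """
    probs = []
    for text in texts:
        words = set(text.lower().split())
        p_spam = 0.1
        if "widow" in words:
            p_spam += 0.3 + 0.25 * ("million" in words) + 0.25 * ("usd" in words)
        probs.append([1.0 - p_spam, p_spam])
    return np.array(probs)


def lime_text_weights(text, predict_proba, target=1, n_samples=500,
                      kernel_width=0.25, alpha=1e-3, seed=42):
    """Simplified LIME for text: one (word, local weight) pair per token."""
    rng = np.random.RandomState(seed)
    words = text.split()
    d = len(words)

    # 1. Perturb: each row is a binary mask of which words are kept.
    masks = rng.randint(0, 2, size=(n_samples, d))
    masks[0] = 1  # the first sample is the unperturbed original
    samples = [" ".join(w for w, keep in zip(words, m) if keep) for m in masks]

    # 2. Ask the black box for the target-class probability of every sample.
    y = predict_proba(samples)[:, target]

    # 3. Weight samples by closeness to the original: the cosine similarity
    #    between a mask and the all-ones mask is sqrt(kept / d).
    distance = 1.0 - np.sqrt(masks.sum(axis=1) / d)
    sample_weights = np.exp(-(distance ** 2) / kernel_width ** 2)

    # 4. Fit the interpretable local model: weighted ridge regression on the
    #    masks (bag-of-words presence features) with an unpenalised intercept.
    X = np.hstack([masks, np.ones((n_samples, 1))])
    Xw = X * sample_weights[:, None]
    penalty = alpha * np.eye(d + 1)
    penalty[-1, -1] = 0.0
    coef = np.linalg.solve(X.T @ Xw + penalty, Xw.T @ y)
    return list(zip(words, coef[:d]))


email = "I am a wealthy widow and if you help me I will pay you 7 million usd"
weights = sorted(lime_text_weights(email, toy_spam_model),
                 key=lambda pair: pair[1], reverse=True)
for word, weight in weights[:3]:
    print(f"{word:>8} {weight:+.3f}")
```

Because the weights are fitted on samples clustered around this one email, they only describe the model's behaviour near the widow email: running the same function on Brian's term-sheet email (no 'widow') gives 'million' and 'usd' weights close to zero.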

## Interpretable

Some machine learning models, like [linear models](https://scikit-learn.org/stable/modules/linear_model.html) and [decision trees](https://scikit-learn.org/stable/modules/tree.html), are inherently interpretable. For linear models we can inspect the parameter coefficients (how big the weight of each feature is when calculating the decision boundary); for decision trees we can look at how early a feature appears in the tree (decision trees use [information gain](https://en.wikipedia.org/wiki/Information_gain_in_decision_trees) to put the features that tell us most about the final classification/decision near the top of the tree, so that they impact more data points).

LIME exploits these interpretable models in order to explain the local context around a given input example. We perturb (slightly change) the input example and use the black-box model under analysis to make predictions on the perturbed versions. As words are added to or removed from the input, the output from the black-box model changes slightly (in the [again contrived] example below, removing the word 'love' from the movie review reduces the probability that the review is positive).

{{
}}

These perturbed inputs, and the outputs that the black-box model produces for them, are then used as a training set for the local, interpretable model. For text models, LIME uses [Bag-of-Words](https://en.wikipedia.org/wiki/Bag-of-words_model) (BoW) representations of the perturbed inputs as the features for the local model. Since each word in the local BoW vocabulary has an associated coefficient (or position in the decision tree), we can use the local model's interpretable information to approximate the effect that the different words have on the bigger model.

## Model-Agnostic

LIME's model agnosticism is one of its most useful attributes. As long as you know how to encode the input data and your model can provide probability distributions over its outputs, you can produce local explanations for any type of model. This is because the explanation comes from the local model and its BoW features rather than from the black-box model itself. In the sections below I've provided some examples of how to use ELI5 with a few different types of models.

## Explanation

As we saw at the beginning of the post, the explanations that are produced by LIME for NLP models are usually

# Usage Examples

## Requirements and Setup

In order to get any of the examples below running you will need a relatively recent version of Python 3 with the [eli5](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) library installed. You will probably want to run the example code in a [Jupyter Notebook](https://jupyter.org/) so that you can see the pretty graphical explanations. If you're not sure which version of Python to install, you might want to have a quick look at [my opinionated guide to Python environment setup](/2021/04/01/opinionated-guide-to-virtualenvs/).
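
Every example below ultimately hands ELI5 the same thing described in the Model-Agnostic section: a function that takes a list of texts and returns a matrix of class probabilities, one row per text. Because explanation runs can be slow, it may be worth sanity-checking that contract before you start. The sketch below uses only NumPy; `check_predict_proba` and `keyword_model` are hypothetical names for illustration, not part of ELI5.

```python
import numpy as np


def check_predict_proba(predict_proba, texts, atol=1e-6):
    """Hypothetical helper: sanity-check a text classifier before explaining it.

    Returns the number of classes if the output looks like one probability
    distribution per input text, otherwise raises ValueError.
    """
    probs = np.asarray(predict_proba(list(texts)))
    if probs.ndim != 2 or probs.shape[0] != len(texts):
        raise ValueError(f"expected shape (n_texts, n_classes), got {probs.shape}")
    if np.any(probs < -atol) or np.any(probs > 1 + atol):
        raise ValueError("outputs must be probabilities in [0, 1]")
    if not np.allclose(probs.sum(axis=1), 1.0, atol=atol):
        raise ValueError("each row must sum to 1; did you forget a softmax?")
    return probs.shape[1]


def keyword_model(texts):
    """Stand-in 'black box' with two classes: [negative, positive]."""
    return np.array([[0.2, 0.8] if "love" in t.lower() else [0.7, 0.3]
                     for t in texts])


n_classes = check_predict_proba(keyword_model, ["I love this film", "Dull and slow"])
print(n_classes)
```

Raw logits (for example `[[2.0, -1.0]]`) fail the range check, which is a common mistake when wrapping neural models.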

All of these examples will work fine on machines without GPUs, although the [transformer model](#eli5-and-transformershuggingface) is a little slow running on CPU (it takes about 60 seconds to run on my 2020 Dell XPS w/ i7, 16GB RAM).

## ELI5 and Sci-kit Learn

[Scikit-Learn](https://scikit-learn.org/stable/) is one of the most widely used machine learning libraries among data scientists. In this first example we're going to train a model in scikit-learn and then use ELI5 to get an explanation for it. Make sure you have your Python environment set up and [scikit-learn](https://scikit-learn.org/stable/) installed.

If you recognise the following example, that's because it is also the example that [ELI5 use in their documentation](https://eli5.readthedocs.io/en/latest/tutorials/black-box-text-classifiers.html#example-problem-lsa-svm-for-20-newsgroups-dataset), but I've added some commentary on what's happening in the code snippets.

We are going to train a [Support Vector Machine (SVM)](https://en.wikipedia.org/wiki/Support-vector_machine) model to predict which newsgroup an email came from, using the [20 Newsgroups](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset) dataset. SVMs with a linear kernel do have feature coefficients which could be used to provide global feature importance. However, to make it harder we will be using an [RBF](https://en.wikipedia.org/wiki/Radial_basis_function_kernel) kernel and [Latent Semantic Analysis](https://en.wikipedia.org/wiki/Latent_semantic_analysis) (LSA), because that's the setup used in the ELI5 example and it's a combination that cannot easily be explained without something like LIME.

### Why SVM and LSA?

So why did they use RBF and LSA? Is it a contrived example just to show off LIME?

Well, LSA is often used as a way to get more performance from an underlying [BoW](https://en.wikipedia.org/wiki/Bag-of-words_model) model by reducing dimensionality and combining commonly co-occurring words into a single feature (rather than having one feature per word). With LSA we might be able to do a better job of capturing some of the general 'topics' and themes that occur across a whole document, rather than just tracking words and key phrases (n-grams). This could help with scenarios like the spam example above, where LSA could put co-occurrences of 'widow', 'million' and 'usd' in one dimension and 'term sheet', 'million', 'usd' in another, giving the SVM a bit of context for the words 'million' and 'usd'.

RBF is an SVM kernel that can separate data that is not linearly separable, and there's a great explanation of this [here](https://www.kdnuggets.com/2016/06/select-support-vector-machine-kernels.html). RBF is often cited as a [reasonable first choice](https://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf) of kernel for SVMs. However, NLP practitioners will generally [recommend a linear kernel for text classification](https://www.svm-tutorial.com/2014/10/svm-linear-kernel-good-text-classification/) as, in practice and in my experience, text is usually linearly separable. It always depends on the dataset, though, so do some visualisation during exploratory analysis to see if an RBF kernel is appropriate.

### Training the Model

First we are going to use scikit-learn's built-in [fetch_20newsgroups](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_20newsgroups.html#sklearn.datasets.fetch_20newsgroups) helper function to download some example emails from 4 newsgroups. There could reasonably be some serious overlap between the atheism and christian boards, so this might be where LSA and our RBF kernel come in handy.

```python
from sklearn.datasets import fetch_20newsgroups

# four of the twenty newsgroups; atheism and christian are likely to overlap
categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']

# remove=('headers', 'footers') strips metadata that classifiers can latch onto
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=('headers', 'footers'),
)
```

In the next code snippet we train the model. The [TfidfVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html) splits the texts into tokens and builds a bag-of-words representation of each text, weighting each term by [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) so that words which appear in almost every document (and so tell us very little) are down-weighted. The [TruncatedSVD](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html) step is applied to the output of the TF-IDF vectorizer to give us our latent signals/topics (this is the LSA part). Then the [SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) is fed the output of the SVD/LSA component.

Each component is chained together into a [Pipeline](https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html) object (here via `make_pipeline`). The pipeline is basically syntactic sugar: it gives us a single object whose `predict_proba()` accepts raw text, which saves us from having to manually define an interface for ELI5 to call our model.

Finally we call `pipe.fit()` on the training data to train the whole pipeline, and `pipe.score()` on the test set to get a top-line accuracy (if we were doing a thorough job we should also look at [other appropriate metrics](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)).

`random_state` is simply a number used to seed the pseudo-random number generator (a NumPy `RandomState`) that scikit-learn uses for random operations such as shuffling the data or the randomized SVD.
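
A tiny, NumPy-only illustration of what seeding buys you. The `shuffled` helper is hypothetical, not a scikit-learn function; scikit-learn's `check_random_state` turns an integer `random_state` into a `numpy.random.RandomState` in much the same way, so the same seed gives the same "random" order on every run.

```python
import numpy as np

docs = ["doc-a", "doc-b", "doc-c", "doc-d", "doc-e"]


def shuffled(items, random_state=None):
    """Hypothetical helper: return a shuffled copy of items.

    An int seeds a fresh RandomState; None draws fresh entropy from the OS.
    """
    rng = np.random.RandomState(random_state)
    return [items[i] for i in rng.permutation(len(items))]


run_1 = shuffled(docs, random_state=42)
run_2 = shuffled(docs, random_state=42)
print(run_1 == run_2)  # True: same seed, same order
print(sorted(run_1) == sorted(docs))  # True: still the same documents
```

Without a seed (`random_state=None`), two runs will usually produce different orders, which is exactly what makes results hard to reproduce.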

Setting `random_state` explicitly is a good habit to get into in order to preserve the reproducibility of your models. Another key parameter set here is `probability=True` on the SVM. This allows us to get the probability distributions across the classes that LIME will need later; if you don't set it, calling `predict_proba()` will fail at the next step.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline

# TF-IDF weighted bag of unigrams and bigrams, ignoring English stop words
# and terms that appear in fewer than 3 documents
vec = TfidfVectorizer(min_df=3, stop_words='english', ngram_range=(1, 2))
# LSA: project the TF-IDF features down to 100 latent dimensions
svd = TruncatedSVD(n_components=100, n_iter=7, random_state=42)
lsa = make_pipeline(vec, svd)

# RBF-kernel SVM (the SVC default kernel); probability=True enables predict_proba()
# and random_state seeds the internal cross-validation used to calibrate it
clf = SVC(C=150, gamma=2e-2, probability=True, random_state=42)
pipe = make_pipeline(lsa, clf)

pipe.fit(twenty_train.data, twenty_train.target)
pipe.score(twenty_test.data, twenty_test.target)
```

### Getting Some Predictions

Now that the model is trained, we can run it on unseen data and get a prediction. In the tutorial, the ELI5 authors provide a pretty-printing function that shows the probability distribution over the labels for a given example.

```python
def print_prediction(doc):
    # predict_proba expects a list of documents; take the single row of probabilities
    y_pred = pipe.predict_proba([doc])[0]
    for target, prob in zip(twenty_train.target_names, y_pred):
        print("{:.3f} {}".format(prob, target))

doc = twenty_test.data[0]
print_prediction(doc)
```

This just predicts the class probabilities for the given document (the first doc in the test set) and then pairs each probability in the prediction (`y_pred`) with its class name (`twenty_train.target_names`).

### Getting an Explanation

Getting an explanation out of this model is relatively simple at this point. We simply import the [TextExplainer](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) class from ELI5 and `fit()` it to the document (the first one in the test set, as per the above snippet).
The TextExplainer will use the SVC pipeline `pipe` to make predictions for a bunch of perturbed examples and train its own local model. The `show_prediction` method will then give a visualisation of the explanation. The `target_names=` parameter is used to pass the class names from our dataset to the text explainer so that they can be displayed nicely.

```python
import eli5
from eli5.lime import TextExplainer

# random_state makes the perturbed samples (and so the explanation) reproducible
te = TextExplainer(random_state=42)
# generate perturbed versions of doc, score them with our pipeline, fit the local model
te.fit(doc, pipe.predict_proba)
te.show_prediction(target_names=twenty_train.target_names)
```

Et voila! Hopefully you will get some output that looks like the below:

{{
}}

## ELI5 and Transformers/Huggingface

[Transformers](https://huggingface.co/docs/transformers/index) is an open source library from HuggingFace that wraps PyTorch and TensorFlow to make it easy to use transformer-based NLP models like BERT, RoBERTa etc.

In order to use ELI5 with Transformers, we need Python 3, [transformers](https://huggingface.co/docs/transformers/index) and a recent version of [pytorch](https://pytorch.org/) installed. This example will work on a machine without a GPU, provided you aren't planning on training your transformer model from scratch. I am using [this sentiment model](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment), which rates reviews from 1 to 5 stars in English, Dutch, German, French, Spanish or Italian.

### Why Transformers?

Transformer-based models are, at the time of writing, **the in thing** for NLP - they are a type of deep neural network that uses self-attention to build a contextual representation of each word from the rest of the sentence. If you're not familiar with them, [this article](https://towardsdatascience.com/transformers-89034557de14) offers a fairly good introduction.

There are good reasons for not using transformers - first and foremost, they are very computationally expensive to train and somewhat computationally expensive during inference (as you will see if you run both the above SVM experiment and the below transformer experiment). If you find that a less powerful (both in terms of understanding and in terms of power consumption) model works for your use case, then using that instead is probably a good move - it'll save you headaches later if you need to scale up your inference operation.
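
To put "somewhat computationally expensive during inference" into numbers for LIME specifically: LIME never looks inside the model, so every explanation costs one model prediction per perturbed sample, and ELI5's `TextExplainer` generates 5000 samples by default. Below is a back-of-the-envelope sketch; `explanation_cost` is a hypothetical helper, the batch size of 64 mirrors the adapter used later in this post, and `seconds_per_batch` is a figure you would have to time on your own hardware.

```python
import math


def explanation_cost(n_samples=5000, batch_size=64, seconds_per_batch=None):
    """Rough cost of one LIME explanation: one prediction per perturbed sample.

    Returns (number of model batches, estimated seconds or None).
    seconds_per_batch is a hypothetical, self-measured latency.
    """
    n_batches = math.ceil(n_samples / batch_size)
    est_seconds = None if seconds_per_batch is None else n_batches * seconds_per_batch
    return n_batches, est_seconds


print(explanation_cost())  # (79, None): 5000 texts in batches of 64
print(explanation_cost(n_samples=1000, seconds_per_batch=0.5))  # (16, 8.0)
```

Since the cost scales linearly with `n_samples`, lowering it is the simplest lever when a transformer makes explanations too slow, at the price of a noisier local model.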

### Loading The Model

The following snippet of code simply loads the model into memory and sets up the tokenizer, ready for use with new text examples.

```python
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
from typing import List

# the name of the model we want to evaluate, from huggingface.co/models
# (alternatively you could train your own)
MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
```

### Defining the Interface with ELI5

This snippet of code defines the all-important `model_adapter` function, which we use as the interface between PyTorch and ELI5. ELI5 expects to be able to pass in a list of perturbed texts and get back a set of probability distributions (a matrix of shape `[NUM_EXAMPLES, NUM_CLASSES]`). In our function we first encode the text into a BERT-compatible input format using the [tokenizer](https://huggingface.co/transformers/main_classes/tokenizer.html). Then we pass the encoded input to the model and receive some predictions. Finally we use `softmax()` to convert the raw *logits* generated by the model into the probability distributions that LIME expects to see.

You may be wondering about the for loop and the batches. ELI5 asks for predictions on 5000 samples at a time (by default), and that might be fine for a smaller, less powerful model, but with a transformer we can't fit all of those examples into memory at once. Therefore we split the samples into batches of 64 so that we don't end up running out of RAM.

```python
import torch

BATCH_SIZE = 64  # small enough to keep memory usage manageable on a CPU

def model_adapter(texts: List[str]) -> np.ndarray:
    all_scores = []
    for i in range(0, len(texts), BATCH_SIZE):
        # list() is defensive: it works whether we are given a list or an array of strings
        batch = list(texts[i:i + BATCH_SIZE])
        # use the BERT tokenizer to encode the batch into padded input tensors
        encoded_input = tokenizer(batch, return_tensors='pt', padding=True,
                                  truncation=True,
                                  max_length=model.config.max_position_embeddings - 2)
        # run the model; no_grad() skips gradient tracking since we only do inference
        with torch.no_grad():
            output = model(**encoded_input)
        # the model returns raw logits rather than a nice smooth
        # probability distribution, so we apply softmax ourselves here
        scores = output.logits.softmax(dim=-1).numpy()
        all_scores.extend(scores)
    return np.array(all_scores)
```

### Getting an Explanation

The last piece in the puzzle is to actually run the model and get our explanation.

Firstly we initialize our explainer object. `n_samples` gives the number of perturbed examples that LIME should generate in order to train the local model (more samples should give a more faithful local explanation at the cost of more compute time). Note that, as above, we manually set `random_state` for reproducibility.

Next we pass the text that we'd like to get an explanation for, and the `model_adapter` function, into `fit()`. This will trigger ELI5 to train a LIME model using our transformer model, which could take anywhere from a few seconds to a few minutes depending on your machine spec.

Finally, we render the explanation using `te.explain_prediction()`. We pass `target_names=list(model.config.id2label.values())`, which tells the `TextExplainer` what the class names from the BERT model are (class names are stored in `config.id2label` by convention in [Huggingface transformer configurations](https://huggingface.co/docs/transformers/main_classes/configuration), but this parameter will accept any list of strings that is the same length as the number of classes in the model).

```python
from eli5.lime import TextExplainer

te = TextExplainer(n_samples=5000, random_state=42)
te.fit("""The restaurant was amazing, the quality of their food was exceptional.
The waiters were so polite.""", model_adapter)
te.explain_prediction(target_names=list(model.config.id2label.values()))
```

Et voila! Hopefully you will get some output that looks like the below:

{{
}}

## ELI5 and a Remotely Hosted Model / API