update post content
This commit is contained in:
parent
86caa7e10a
commit
4db50a370a
Binary file not shown.
After Width: | Height: | Size: 149 KiB |
|
@ -26,14 +26,14 @@ tags:
|
|||
- [Requirements and Setup](#requirements-and-setup)
|
||||
- [ELI5 and Sci-kit Learn](#eli5-and-sci-kit-learn)
|
||||
- [Why SVM and LSA?](#why-svm-and-lsa)
|
||||
- [Training the model](#training-the-model)
|
||||
- [Getting some predictions](#getting-some-predictions)
|
||||
- [Getting an explanation](#getting-an-explanation)
|
||||
- [Training the Model](#training-the-model)
|
||||
- [Getting Some Predictions](#getting-some-predictions)
|
||||
- [Getting an Explanation](#getting-an-explanation)
|
||||
- [ELI5 and Transformers/Huggingface](#eli5-and-transformershuggingface)
|
||||
- [Why transformers?](#why-transformers)
|
||||
- [Why Transformers?](#why-transformers)
|
||||
- [Loading The Model](#loading-the-model)
|
||||
- [Defining the Interface with ELI5](#defining-the-interface-with-eli5)
|
||||
- [Getting an explanation](#getting-an-explanation-1)
|
||||
- [Getting an Explanation](#getting-an-explanation-1)
|
||||
- [ELI5 and a Remotely Hosted Model / API](#eli5-and-a-remotely-hosted-model--api)
|
||||
|
||||
|
||||
|
@ -155,7 +155,7 @@ Well LSA is often used as a way to get more performance from an underlying [BoW]
|
|||
RBF is a SVM kernel that can separate data that is not linearly seperable and there's a great explanation of this [here](https://www.kdnuggets.com/2016/06/select-support-vector-machine-kernels.html). RBF is often cited as a [reasonable first choice](https://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf) of kernel for SVMs. However, NLP practitioners will generally [recommend a linear kernel for text classification](https://www.svm-tutorial.com/2014/10/svm-linear-kernel-good-text-classification/) as in practice, and in my experience, text is usually linearly separable. However it will always depend on dataset so do some visualisation during exploratory analysis to see if an RBF kernel is appropriate.
|
||||
|
||||
|
||||
### Training the model
|
||||
### Training the Model
|
||||
|
||||
First we are going to use scikit-learn's built in [fetch_20newsgroups](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_20newsgroups.html#sklearn.datasets.fetch_20newsgroups) helper function to download some example emails from 4 newsgroups. There could reasonably be some serious overlap between the atheism and christian boards so this might be where LSA and our RBF kernel come in handy.
|
||||
|
||||
|
@ -217,7 +217,7 @@ pipe.score(twenty_test.data, twenty_test.target)
|
|||
|
||||
```
|
||||
|
||||
### Getting some predictions
|
||||
### Getting Some Predictions
|
||||
|
||||
Now that the model is trained it is possible to run it on unseen data and get a prediction. In the tutorial
|
||||
the ELI5 authors provide a pretty printing function that shows the probability distribution of the labels for
|
||||
|
@ -235,7 +235,7 @@ print_prediction(doc)
|
|||
|
||||
This is basically just predicting the classes for the given document, which is the first doc in the test set, and then combining the probabilities in the prediction (`y_pred`) with the class names (`twenty_train.target_names`).
|
||||
|
||||
### Getting an explanation
|
||||
### Getting an Explanation
|
||||
|
||||
Getting an explanation of out this model is relatively simple at this point. We simply import the [TextExplainer](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) class from ELI5 and `fit()` it to the document (the first one in the test set as per the above snippet). The TextExplainer will use the SVC pipeline `pipe` to make predictions for a bunch of perturbed examples and train its own model. The `show_predictions` function will then give a visualisation of the explanation. The `target_names=` parameter is used to pass the class names from our dataset to the text explainer so that they can be displayed nicely.
|
||||
|
||||
|
@ -248,6 +248,10 @@ te.fit(doc, pipe.predict_proba)
|
|||
te.show_prediction(target_names=twenty_train.target_names)
|
||||
```
|
||||
|
||||
Et voila! Hopefully you will get some output that looks like the below:
|
||||
|
||||
{{<figure src="images/explanation_svm.png" caption="The output of the explain functon should look something like this">}}
|
||||
|
||||
|
||||
## ELI5 and Transformers/Huggingface
|
||||
|
||||
|
@ -255,7 +259,7 @@ te.show_prediction(target_names=twenty_train.target_names)
|
|||
|
||||
This example will work on a machine without a GPU provided you aren't planning on training your transformer model from scratch. I am using [this sentiment model](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) which evaluates the sentiment/rating of reviews from 1 to 5 in English, Dutch, German, French or Spanish.
|
||||
|
||||
### Why transformers?
|
||||
### Why Transformers?
|
||||
|
||||
Transformer-based models are, at the time of writing, **the in thing** for NLP models - they are a type of deep neural network that has contextual understanding of full sentences. If you're not familiar with them [this article](https://towardsdatascience.com/transformers-89034557de14) offers a fairly good introduction.
|
||||
|
||||
|
@ -322,7 +326,7 @@ def model_adapter(texts: List[str]):
|
|||
|
||||
```
|
||||
|
||||
### Getting an explanation
|
||||
### Getting an Explanation
|
||||
|
||||
The last piece in the puzzle is to actually run the model and get our explanation. Firstly we initialize our explainer object.
|
||||
`n_samples` gives the number of perturbed examples that LIME should generate in order to train the local model (more samples
|
||||
|
|
|
@ -1,64 +1,5 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import requests"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<module 'requests' from '/home/james/miniconda3/envs/pgesg/lib/python3.7/site-packages/requests/__init__.py'>"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"requests"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"API_TOKEN=\"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\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"API_URL = \"https://api-inference.huggingface.co/models/cardiffnlp/twitter-roberta-base-sentiment\"\n",
|
||||
"headers = {\"Authorization\": f\"Bearer {API_TOKEN}\"}\n",
|
||||
"\n",
|
||||
"def query(payload):\n",
|
||||
" data = json.dumps(payload)\n",
|
||||
" response = requests.request(\"POST\", API_URL, headers=headers, data=data)\n",
|
||||
" return json.loads(response.content.decode(\"utf-8\"))\n",
|
||||
"\n",
|
||||
"data = query(\"This is very nice\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -1209,12 +1150,189 @@
|
|||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Remote API Explanation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"import requests"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"API_TOKEN=\"YOUR API KEY HERE\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"MODEL=\"nlptown/bert-base-multilingual-uncased-sentiment\"\n",
|
||||
"\n",
|
||||
"API_URL = f\"https://api-inference.huggingface.co/models/{MODEL}\"\n",
|
||||
"headers = {\"Authorization\": f\"Bearer {API_TOKEN}\"}\n",
|
||||
"\n",
|
||||
"def query(payload):\n",
|
||||
" data = json.dumps(payload)\n",
|
||||
" response = requests.request(\"POST\", API_URL, headers=headers, data=data)\n",
|
||||
" return json.loads(response.content.decode(\"utf-8\"))\n",
|
||||
"\n",
|
||||
"def result_to_df(result):\n",
|
||||
" rows = []\n",
|
||||
" \n",
|
||||
" for result_row in result:\n",
|
||||
" row = {}\n",
|
||||
" for lbl_score in result_row:\n",
|
||||
" row[lbl_score['label']] = lbl_score['score']\n",
|
||||
"\n",
|
||||
" rows.append(row)\n",
|
||||
" \n",
|
||||
" return pd.DataFrame(rows)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = query(\"This is very nice\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def remote_model_adapter(texts: List[str]):\n",
|
||||
" \n",
|
||||
" all_scores = []\n",
|
||||
"\n",
|
||||
" for text in texts:\n",
|
||||
" \n",
|
||||
" data = query(text)\n",
|
||||
" all_scores.extend(result_to_df(data).values)\n",
|
||||
"\n",
|
||||
" return softmax(np.array(all_scores), axis=1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>1 star</th>\n",
|
||||
" <th>2 stars</th>\n",
|
||||
" <th>3 stars</th>\n",
|
||||
" <th>4 stars</th>\n",
|
||||
" <th>5 stars</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.003129</td>\n",
|
||||
" <td>0.003055</td>\n",
|
||||
" <td>0.017689</td>\n",
|
||||
" <td>0.194169</td>\n",
|
||||
" <td>0.781958</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" 1 star 2 stars 3 stars 4 stars 5 stars\n",
|
||||
"0 0.003129 0.003055 0.017689 0.194169 0.781958"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data = result_to_df(query('this is so much fun'))\n",
|
||||
"data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/james/miniconda3/envs/pgesg/lib/python3.7/site-packages/sklearn/base.py:213: FutureWarning: From version 0.24, get_params will raise an AttributeError if a parameter cannot be retrieved as an instance attribute. Previously it would return None.\n",
|
||||
" FutureWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"TextExplainer(char_based=False,\n",
|
||||
" clf=SGDClassifier(alpha=0.001, loss='log', penalty='elasticnet',\n",
|
||||
" random_state=RandomState(MT19937) at 0x7FE4409B88D0),\n",
|
||||
" n_samples=20, random_state=42,\n",
|
||||
" sampler=MaskingTextSamplers(random_state=RandomState(MT19937) at 0x7FE4409B88D0,\n",
|
||||
" sampler_params=None,\n",
|
||||
" token_pattern='(?u)\\\\b\\\\w+\\\\b',\n",
|
||||
" weights=array([0.7, 0.3])),\n",
|
||||
" token_pattern='(?u)\\\\b\\\\w+\\\\b',\n",
|
||||
" vec=CountVectorizer(ngram_range=(1, 2),\n",
|
||||
" token_pattern='(?u)\\\\b\\\\w+\\\\b'))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"te = TextExplainer(n_samples=20, random_state=42)\n",
|
||||
"te.fit(\"The restaurant was amazing, the quality of their food was exceptional. The waiters were so polite.\", remote_model_adapter)\n",
|
||||
"te.show_prediction(target_names=list(data.columns))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
Loading…
Reference in New Issue