diff --git a/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/images/explanation_example.png b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/images/explanation_example.png new file mode 100644 index 0000000..a955cc6 Binary files /dev/null and b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/images/explanation_example.png differ diff --git a/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/index.md b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/index.md index 3e726d6..a3705d3 100644 --- a/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/index.md +++ b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/index.md @@ -23,11 +23,17 @@ tags: - [Model-Agnostic](#model-agnostic) - [Explanation](#explanation) - [Usage Examples](#usage-examples) + - [Requirements and Setup](#requirements-and-setup) - [ELI5 and Sci-kit Learn](#eli5-and-sci-kit-learn) + - [Why SVM and LSA?](#why-svm-and-lsa) + - [Training the model](#training-the-model) + - [Getting some predictions](#getting-some-predictions) + - [Getting an explanation](#getting-an-explanation) - [ELI5 and Transformers/Huggingface](#eli5-and-transformershuggingface) + - [Why transformers?](#why-transformers) - [Loading The Model](#loading-the-model) - [Defining the Interface with ELI5](#defining-the-interface-with-eli5) - - [Getting an explanation](#getting-an-explanation) + - [Getting an explanation](#getting-an-explanation-1) - [ELI5 and a Remotely Hosted Model / API](#eli5-and-a-remotely-hosted-model--api) @@ -124,15 +130,136 @@ As we saw at the beginning of the post, the explanations that are produced by LI # Usage Examples +## Requirements and Setup + +In order to get any of the examples below running you will need a relatively recent version of Python 3 and the [eli5](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) library installed. You will probably want to run the example code in a [Jupyter Notebook](https://jupyter.org/) so that you can see the pretty graphical explanations. + +If you're not sure about which version of Python to install, you might want to have a quick look at [my opinionated guide to Python environment setup](/2021/04/01/opinionated-guide-to-virtualenvs/). + +All of these examples will work fine on machines without GPUs, although the [transformer model](#eli5-and-transformershuggingface) is a little slow running on CPU (it takes about 60 seconds to run on my 2020 Dell XPS w/ i7, 16GB RAM). + ## ELI5 and Sci-kit Learn +[Scikit-Learn](https://scikit-learn.org/stable/) is one of the most widely used machine learning libraries among data scientists. In this first example we're going to train a model in sci-kit learn and then use ELI5 to get an explanation for it. Make sure you have your Python environment set up and [scikit-learn](https://scikit-learn.org/stable/) installed. + +If you recognise the following example, that's because it is also the example that [ELI5 use in their documentation](https://eli5.readthedocs.io/en/latest/tutorials/black-box-text-classifiers.html#example-problem-lsa-svm-for-20-newsgroups-dataset), but I've added some commentary on what's happening in the code snippets. 
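+
+If you're starting from a completely fresh environment, the installs for these examples boil down to a couple of pip commands, roughly as below (transformers and torch are only needed for the transformer example later on):
+
+```python
+# run these in a Jupyter notebook cell, or drop the leading "!" to run them in a normal shell
+!pip install eli5 scikit-learn
+
+# only needed for the transformers/huggingface example further down
+!pip install transformers torch
+```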
+ +We are going to train a [Support Vector Machine (SVM)](https://en.wikipedia.org/wiki/Support-vector_machine) model to predict which newsgroup an email came from, thanks to the [20 newsgroup](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset) dataset. SVMs with a linear kernel do have feature coefficients which could be used to provide global feature importance. However, to make it harder we will be using an [RBF](https://en.wikipedia.org/wiki/Radial_basis_function_kernel) kernel and [Latent Semantic Analysis](https://en.wikipedia.org/wiki/Latent_semantic_analysis) because that's the setup used in the example and it's a combination that cannot be explained simply without LIME. + +### Why SVM and LSA? + +So why did they use RBF and LSA? Is it a contrived example just to show off LIME? + +Well, LSA is often used as a way to get more performance from an underlying [BoW](https://en.wikipedia.org/wiki/Bag-of-words_model) model by reducing dimensionality and combining commonly co-occurring words into a single feature (rather than having one feature per word). With LSA we might be able to do a better job of capturing some of the general 'topics' and themes that occur across a whole document rather than just tracking words and key phrases (n-grams). This could help with scenarios like the spammer above where LSA could put co-occurrences of 'widow', 'million' and 'usd' in one dimension and 'term sheet', 'million', 'usd' in another dimension, giving the SVM a bit of context for the words 'million' and 'usd'. + +RBF is an SVM kernel that can separate data that is not linearly separable and there's a great explanation of this [here](https://www.kdnuggets.com/2016/06/select-support-vector-machine-kernels.html). RBF is often cited as a [reasonable first choice](https://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf) of kernel for SVMs. However, NLP practitioners will generally [recommend a linear kernel for text classification](https://www.svm-tutorial.com/2014/10/svm-linear-kernel-good-text-classification/) as in practice, and in my experience, text is usually linearly separable. That said, it will always depend on your dataset, so do some visualisation during exploratory analysis to see if an RBF kernel is appropriate. + + +### Training the model + +First we are going to use scikit-learn's built-in [fetch_20newsgroups](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_20newsgroups.html#sklearn.datasets.fetch_20newsgroups) helper function to download some example emails from 4 newsgroups. There could reasonably be some serious overlap between the atheism and christian boards so this might be where LSA and our RBF kernel come in handy. + +```python +from sklearn.datasets import fetch_20newsgroups + +categories = ['alt.atheism', 'soc.religion.christian', + 'comp.graphics', 'sci.med'] +twenty_train = fetch_20newsgroups( + subset='train', + categories=categories, + shuffle=True, + random_state=42, + remove=('headers', 'footers'), +) +twenty_test = fetch_20newsgroups( + subset='test', + categories=categories, + shuffle=True, + random_state=42, + remove=('headers', 'footers'), +) + +``` +
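+If you want a quick sanity check on what `fetch_20newsgroups` has given us before going any further, you can poke at the standard scikit-learn dataset fields - `data`, `target` and `target_names` - which the rest of the example relies on. Something like this works:
+
+```python
+# the four class names, in the same order as the probabilities we get back later
+print(twenty_train.target_names)
+
+# how many emails we have to train and test on
+print(len(twenty_train.data), len(twenty_test.data))
+
+# the first 200 characters of the first training email
+print(twenty_train.data[0][:200])
+```
+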
+In the next code snippet we train the model. The [TFIDFVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html) splits the texts into tokens and builds a bag-of-words representation of the text, with the addition of [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) information to help us filter out words that don't give us any information. + +The [TruncatedSVD](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html) object is applied to the TFIDF vectorizer to give us our latent signals/categories. + +Then the [SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) is fed the output of the SVD/LSA component. + +Each component is linked together into a [Pipeline](https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html) object that basically provides syntactic sugar for us later and avoids us having to manually define an interface for ELI5 to call in order to use our model. + +Finally we call `pipe.fit()` on the training data to actually feed the pipeline and train the model, and `pipe.score()` on the test set to give us a top-line accuracy (if we were doing a thorough job we should probably also look at [other appropriate metrics](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)). + +`random_state` is simply a number that is used to seed the pseudo-random number generator which +scikit-learn uses for pseudo-random operations. Setting random state explicitly is a good habit to get +into in order to preserve the reproducibility of your models. + +Another key parameter set here is `probability=True` on the SVM. This will allow us to get the probability distributions +across the classes that LIME will need later. If you don't set this then `predict_proba()` will fail at the next step. + + +```python +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.svm import SVC +from sklearn.decomposition import TruncatedSVD +from sklearn.pipeline import Pipeline, make_pipeline + +vec = TfidfVectorizer(min_df=3, stop_words='english', + ngram_range=(1, 2)) +svd = TruncatedSVD(n_components=100, n_iter=7, random_state=42) +lsa = make_pipeline(vec, svd) + +clf = SVC(C=150, gamma=2e-2, probability=True) +pipe = make_pipeline(lsa, clf) +pipe.fit(twenty_train.data, twenty_train.target) +pipe.score(twenty_test.data, twenty_test.target) + +``` + +### Getting some predictions + +Now that the model is trained, it is possible to run it on unseen data and get a prediction. In the tutorial +the ELI5 authors provide a pretty printing function that shows the probability distribution of the labels for +a given example. + +```python +def print_prediction(doc): + y_pred = pipe.predict_proba([doc])[0] + for target, prob in zip(twenty_train.target_names, y_pred): + print("{:.3f} {}".format(prob, target)) + +doc = twenty_test.data[0] +print_prediction(doc) +``` + +This is basically just predicting the classes for the given document, which is the first doc in the test set, and then combining the probabilities in the prediction (`y_pred`) with the class names (`twenty_train.target_names`). + +### Getting an explanation + +Getting an explanation out of this model is relatively simple at this point. We simply import the [TextExplainer](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) class from ELI5 and `fit()` it to the document (the first one in the test set as per the above snippet). 
The TextExplainer will use the SVC pipeline `pipe` to make predictions for a bunch of perturbed examples and train its own model. The `show_prediction` method will then give a visualisation of the explanation. The `target_names=` parameter is used to pass the class names from our dataset to the text explainer so that they can be displayed nicely. + +```python +import eli5 +from eli5.lime import TextExplainer + +te = TextExplainer(random_state=42) +te.fit(doc, pipe.predict_proba) +te.show_prediction(target_names=twenty_train.target_names) +``` + + ## ELI5 and Transformers/Huggingface -[Transformers](https://huggingface.co/docs/transformers/index) is an open source library provided by HuggingFace which provides an easy to use wrapper around PyTorch and Tensorflow specifically to make it easy to use transformer-based NLP models like BERT, RoBERTa etc. In order to use ELI5 with Transformers from huggingface, we need to have Python3, [transformers](https://huggingface.co/docs/transformers/index) and a recent version of [pytorch](https://pytorch.org/) installed. You will probably want to run this code in a [Jupyter Notebook](https://jupyter.org/) so that you can see the pretty graphical explanations. Of course you'll also need [eli5](https://eli5.readthedocs.io/en/latest/autodocs/lime.html#eli5.lime.lime.TextExplainer) library installed too. +[Transformers](https://huggingface.co/docs/transformers/index) is an open source library from HuggingFace which provides an easy-to-use wrapper around PyTorch and Tensorflow specifically to make it easy to use transformer-based NLP models like BERT, RoBERTa etc. In order to use ELI5 with Transformers from huggingface, we need to have Python 3, [transformers](https://huggingface.co/docs/transformers/index) and a recent version of [pytorch](https://pytorch.org/) installed. This example will work on a machine without a GPU, provided you aren't planning on training your transformer model from scratch. I am using [this sentiment model](https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment) which evaluates the sentiment/rating of reviews from 1 to 5 in English, Dutch, German, French or Spanish. +### Why transformers? +Transformer-based models are, at the time of writing, **the in thing** for NLP models - they are a type of deep neural network that has contextual understanding of full sentences. If you're not familiar with them, [this article](https://towardsdatascience.com/transformers-89034557de14) offers a fairly good introduction. + +There are good reasons for not using transformers - first and foremost, they are very computationally expensive to train and somewhat computationally expensive during inference (as you will see if you run both the above SVM experiment and the below transformer experiment). If you find that a less powerful (both in terms of understanding and in terms of power consumption) model works for your use case then using that instead is probably a good move - it'll save you headaches later if you need to scale up your inference operation. ### Loading The Model @@ -197,13 +324,19 @@ def model_adapter(texts: List[str]): ### Getting an explanation -The last piece in the puzzle is to actually run the model and get our explanation. Firstly we initialize our explainer object +The last piece in the puzzle is to actually run the model and get our explanation. Firstly we initialize our explainer object. 
+`n_samples` gives the number of perturbed examples that LIME should generate in order to train the local model (more samples +should give a more faithful local explanation at the cost of more compute/taking longer). Note that as above, we manually set +`random_state` for reproducibility. +Next we pass the text that we'd like to get an explanation for and the model_adapter function into `fit()` - this will trigger +ELI5 to train a LIME model using our transformer model which could take a few seconds or minutes depending on what sort of machine +spec you have. -Here we pass in the text that we'd like to get an explanation for. `n_samples` gives the number of perturbed examples that LIME should generate in order -to train the local model (more samples should give a more faithful local explanation at the cost of more compute/taking longer). -Random state is simply a number that is used to seed Python's pseudo-random number generator which LIME uses to randomly decide what -samples to pick. Setting random state explicitly is a good habit to get into in order to preserve the reproducibility of your models. +Finally, we render the explanation using `te.explain_prediction()`. We pass `target_names=list(model.config.id2label.values())` which +tells the `TextExplainer` what the class names from the bert model are (class names are stored in `config.id2label` by convention in +[Huggingface transformer configurations](https://huggingface.co/docs/transformers/main_classes/configuration) but this function will accept +any list of strings that is the same length as the number of classes in the model). ```python from eli5.lime import TextExplainer @@ -214,5 +347,8 @@ food was exceptional. The waiters were so polite.""", model_adapter) te.explain_prediction(target_names=list(model.config.id2label.values())) ``` +Et voila! Hopefully you will get some output that looks like the below: + +{{
}} ## ELI5 and a Remotely Hosted Model / API \ No newline at end of file diff --git a/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/test.ipynb b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/test.ipynb index b6d4f44..c9937de 100644 --- a/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/test.ipynb +++ b/brainsteam/content/posts/2022/01/13-01-painless-explainability-for-text-models-with-eli5/test.ipynb @@ -59,6 +59,529 @@ "data = query(\"This is very nice\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scikit Learn ELI5 Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install scikit-learn" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.datasets import fetch_20newsgroups\n", + "\n", + "categories = ['alt.atheism', 'soc.religion.christian',\n", + " 'comp.graphics', 'sci.med']\n", + "twenty_train = fetch_20newsgroups(\n", + " subset='train',\n", + " categories=categories,\n", + " shuffle=True,\n", + " random_state=42,\n", + " remove=('headers', 'footers'),\n", + ")\n", + "twenty_test = fetch_20newsgroups(\n", + " subset='test',\n", + " categories=categories,\n", + " shuffle=True,\n", + " random_state=42,\n", + " remove=('headers', 'footers'),\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.8901464713715047" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.feature_extraction.text import TfidfVectorizer\n", + "from sklearn.svm import SVC\n", + "from sklearn.decomposition import TruncatedSVD\n", + "from sklearn.pipeline import Pipeline, make_pipeline\n", + "\n", + "vec = TfidfVectorizer(min_df=3, stop_words='english',\n", + " ngram_range=(1, 2))\n", + "svd = TruncatedSVD(n_components=100, n_iter=7, random_state=42)\n", + "lsa = make_pipeline(vec, svd)\n", + "\n", + "clf = SVC(C=150, gamma=2e-2, probability=True)\n", + "pipe = make_pipeline(lsa, clf)\n", + "pipe.fit(twenty_train.data, twenty_train.target)\n", + "pipe.score(twenty_test.data, twenty_test.target)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.001 alt.atheism\n", + "0.001 comp.graphics\n", + "0.995 sci.med\n", + "0.003 soc.religion.christian\n" + ] + } + ], + "source": [ + "def print_prediction(doc):\n", + " y_pred = pipe.predict_proba([doc])[0]\n", + " for target, prob in zip(twenty_train.target_names, y_pred):\n", + " print(\"{:.3f} {}\".format(prob, target))\n", + "\n", + "doc = twenty_test.data[0]\n", + "print_prediction(doc)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

\n", + " \n", + " \n", + " y=alt.atheism\n", + " \n", + "\n", + "\n", + " \n", + " (probability 0.000, score -8.648)\n", + "\n", + "top features\n", + "

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + " Contribution?\n", + " Feature
\n", + " -0.398\n", + " \n", + " <BIAS>\n", + "
\n", + " -8.249\n", + " \n", + " Highlighted in text (sum)\n", + "
\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "

\n", + " as i recall from my bout with kidney stones, there isn't any\n", + "medication that can do anything about them except relieve the pain.\n", + "\n", + "either they pass, or they have to be broken up with sound, or they have\n", + "to be extracted surgically.\n", + "\n", + "when i was in, the x-ray tech happened to mention that she'd had kidney\n", + "stones and children, and the childbirth hurt less.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

\n", + " \n", + " \n", + " y=comp.graphics\n", + " \n", + "\n", + "\n", + " \n", + " (probability 0.000, score -8.687)\n", + "\n", + "top features\n", + "

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + " Contribution?\n", + " Feature
\n", + " -0.283\n", + " \n", + " <BIAS>\n", + "
\n", + " -8.404\n", + " \n", + " Highlighted in text (sum)\n", + "
\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "

\n", + " as i recall from my bout with kidney stones, there isn't any\n", + "medication that can do anything about them except relieve the pain.\n", + "\n", + "either they pass, or they have to be broken up with sound, or they have\n", + "to be extracted surgically.\n", + "\n", + "when i was in, the x-ray tech happened to mention that she'd had kidney\n", + "stones and children, and the childbirth hurt less.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

\n", + " \n", + " \n", + " y=sci.med\n", + " \n", + "\n", + "\n", + " \n", + " (probability 0.996, score 6.821)\n", + "\n", + "top features\n", + "

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + " Contribution?\n", + " Feature
\n", + " +6.883\n", + " \n", + " Highlighted in text (sum)\n", + "
\n", + " -0.061\n", + " \n", + " <BIAS>\n", + "
\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "

\n", + " as i recall from my bout with kidney stones, there isn't any\n", + "medication that can do anything about them except relieve the pain.\n", + "\n", + "either they pass, or they have to be broken up with sound, or they have\n", + "to be extracted surgically.\n", + "\n", + "when i was in, the x-ray tech happened to mention that she'd had kidney\n", + "stones and children, and the childbirth hurt less.\n", + "

\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "

\n", + " \n", + " \n", + " y=soc.religion.christian\n", + " \n", + "\n", + "\n", + " \n", + " (probability 0.004, score -5.612)\n", + "\n", + "top features\n", + "

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + " Contribution?\n", + " Feature
\n", + " -0.326\n", + " \n", + " <BIAS>\n", + "
\n", + " -5.286\n", + " \n", + " Highlighted in text (sum)\n", + "
\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "

\n", + " as i recall from my bout with kidney stones, there isn't any\n", + "medication that can do anything about them except relieve the pain.\n", + "\n", + "either they pass, or they have to be broken up with sound, or they have\n", + "to be extracted surgically.\n", + "\n", + "when i was in, the x-ray tech happened to mention that she'd had kidney\n", + "stones and children, and the childbirth hurt less.\n", + "

\n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import eli5\n", + "from eli5.lime import TextExplainer\n", + "\n", + "te = TextExplainer(random_state=42)\n", + "te.fit(doc, pipe.predict_proba)\n", + "te.show_prediction(target_names=twenty_train.target_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Transformers ELI5 example\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install transformers\n", + "!pip install torch==1.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html" + ] + }, { "cell_type": "code", "execution_count": 20, @@ -122,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 84, "metadata": {}, "outputs": [ { @@ -138,9 +661,9 @@ "text/plain": [ "TextExplainer(char_based=False,\n", " clf=SGDClassifier(alpha=0.001, loss='log', penalty='elasticnet',\n", - " random_state=RandomState(MT19937) at 0x7FE441F8F050),\n", - " n_samples=1000, random_state=42,\n", - " sampler=MaskingTextSamplers(random_state=RandomState(MT19937) at 0x7FE441F8F050,\n", + " random_state=RandomState(MT19937) at 0x7FE441F8F490),\n", + " random_state=42,\n", + " sampler=MaskingTextSamplers(random_state=RandomState(MT19937) at 0x7FE441F8F490,\n", " sampler_params=None,\n", " token_pattern='(?u)\\\\b\\\\w+\\\\b',\n", " weights=array([0.7, 0.3])),\n", @@ -149,19 +672,19 @@ " token_pattern='(?u)\\\\b\\\\w+\\\\b'))" ] }, - "execution_count": 82, + "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "te = TextExplainer(n_samples=1000, random_state=42)\n", - "te.fit(\"I was sad and disappointed by this comment\", model_adapter)" + "te = TextExplainer(n_samples=5000, random_state=42)\n", + "te.fit(\"The restaurant was amazing, the quality of their food was exceptional. The waiters were so polite.\", model_adapter)" ] }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 85, "metadata": {}, "outputs": [ { @@ -214,7 +737,7 @@ "\n", "\n", " \n", - " (probability 0.322, score -0.640)\n", + " (probability 0.001, score -6.834)\n", "\n", "top features\n", "

\n", @@ -238,22 +761,22 @@ "\n", " \n", " \n", - " \n", + " \n", " \n", - " -0.055\n", + " -0.468\n", " \n", " \n", - " Highlighted in text (sum)\n", + " <BIAS>\n", " \n", " \n", "\n", " \n", - " \n", + " \n", " \n", - " -0.585\n", + " -6.366\n", " \n", " \n", - " <BIAS>\n", + " Highlighted in text (sum)\n", " \n", " \n", "\n", @@ -267,7 +790,7 @@ "\n", "\n", "

\n", - " i was sad and disappointed by this comment\n", + " the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.\n", "

\n", "\n", " \n", @@ -284,7 +807,7 @@ "\n", "\n", " \n", - " (probability 0.514, score 0.209)\n", + " (probability 0.002, score -6.109)\n", "\n", "top features\n", "

\n", @@ -304,23 +827,13 @@ " \n", " \n", " \n", - " \n", - " \n", - " +0.917\n", - " \n", - " \n", - " Highlighted in text (sum)\n", - " \n", - " \n", - "\n", - " \n", " \n", "\n", " \n", " \n", - " \n", + " \n", " \n", - " -0.708\n", + " -0.400\n", " \n", " \n", " <BIAS>\n", @@ -328,6 +841,16 @@ " \n", "\n", " \n", + " \n", + " \n", + " -5.709\n", + " \n", + " \n", + " Highlighted in text (sum)\n", + " \n", + " \n", + "\n", + " \n", "\n", " \n", " \n", @@ -337,7 +860,7 @@ "\n", "\n", "

\n", - " i was sad and disappointed by this comment\n", + " the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.\n", "

\n", "\n", " \n", @@ -354,7 +877,7 @@ "\n", "\n", " \n", - " (probability 0.161, score -1.565)\n", + " (probability 0.013, score -4.219)\n", "\n", "top features\n", "

\n", @@ -378,9 +901,9 @@ "\n", " \n", " \n", - " \n", + " \n", " \n", - " -0.603\n", + " -0.397\n", " \n", " \n", " <BIAS>\n", @@ -388,9 +911,9 @@ " \n", "\n", " \n", - " \n", + " \n", " \n", - " -0.962\n", + " -3.822\n", " \n", " \n", " Highlighted in text (sum)\n", @@ -407,7 +930,7 @@ "\n", "\n", "

\n", - " i was sad and disappointed by this comment\n", + " the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.\n", "

\n", "\n", " \n", @@ -424,7 +947,7 @@ "\n", "\n", " \n", - " (probability 0.002, score -6.296)\n", + " (probability 0.218, score -1.159)\n", "\n", "top features\n", "

\n", @@ -448,22 +971,22 @@ "\n", " \n", " \n", - " \n", + " \n", " \n", - " -0.369\n", + " -0.539\n", " \n", " \n", - " <BIAS>\n", + " Highlighted in text (sum)\n", " \n", " \n", "\n", " \n", - " \n", + " \n", " \n", - " -5.927\n", + " -0.621\n", " \n", " \n", - " Highlighted in text (sum)\n", + " <BIAS>\n", " \n", " \n", "\n", @@ -477,7 +1000,7 @@ "\n", "\n", "

\n", - " i was sad and disappointed by this comment\n", + " the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.\n", "

\n", "\n", " \n", @@ -494,7 +1017,7 @@ "\n", "\n", " \n", - " (probability 0.001, score -6.650)\n", + " (probability 0.766, score 1.648)\n", "\n", "top features\n", "

\n", @@ -514,26 +1037,26 @@ " \n", " \n", " \n", - " \n", - "\n", - " \n", - " \n", - " \n", + " \n", " \n", - " -0.360\n", + " +2.169\n", " \n", " \n", - " <BIAS>\n", + " Highlighted in text (sum)\n", " \n", " \n", "\n", " \n", - " \n", + " \n", + "\n", + " \n", + " \n", + " \n", " \n", - " -6.289\n", + " -0.521\n", " \n", " \n", - " Highlighted in text (sum)\n", + " <BIAS>\n", " \n", " \n", "\n", @@ -547,7 +1070,7 @@ "\n", "\n", "

\n", - " i was sad and disappointed by this comment\n", + " the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.\n", "

\n", "\n", " \n", @@ -589,10 +1112,10 @@ "\n" ], "text/plain": [ - "Explanation(estimator=\"SGDClassifier(alpha=0.001, loss='log', penalty='elasticnet',\\n random_state=RandomState(MT19937) at 0x7FE441F8FC00)\", description=None, error=None, method='linear model', is_regression=False, targets=[TargetExplanation(target='1 star', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='sad and', weight=0.5346186851714406, std=None, value=1.0), FeatureWeight(feature='by this', weight=0.42658779616112424, std=None, value=1.0), FeatureWeight(feature='and disappointed', weight=0.26653522108846317, std=None, value=1.0), FeatureWeight(feature='i was', weight=0.19021551478926602, std=None, value=1.0), FeatureWeight(feature='this comment', weight=0.17398835561819753, std=None, value=1.0), FeatureWeight(feature='sad', weight=0.154955236219833, std=None, value=1.0), FeatureWeight(feature='disappointed', weight=0.11045707013459205, std=None, value=1.0)], neg=[FeatureWeight(feature='', weight=-0.5847697578742233, std=None, value=1.0), FeatureWeight(feature='i', weight=-0.4710052143644072, std=None, value=1.0), FeatureWeight(feature='by', weight=-0.3865256334829903, std=None, value=1.0), FeatureWeight(feature='and', weight=-0.3741197710655346, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.21584403881881575, std=None, value=1.0), FeatureWeight(feature='this', weight=-0.1771012983233238, std=None, value=1.0), FeatureWeight(feature='disappointed by', weight=-0.13037261951049633, std=None, value=1.0), FeatureWeight(feature='comment', weight=-0.09160800881871993, std=None, value=1.0), FeatureWeight(feature='was sad', weight=-0.06625548348972331, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.3216432779568045, score=-0.6402439465653179, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='i was sad and disappointed by this comment', spans=[('i', [(0, 1)], -0.4710052143644072), ('was', [(2, 5)], -0.21584403881881575), ('sad', [(6, 9)], 0.154955236219833), ('and', [(10, 13)], -0.3741197710655346), ('disappointed', [(14, 26)], 0.11045707013459205), ('by', [(27, 29)], -0.3865256334829903), ('this', [(30, 34)], -0.1771012983233238), ('comment', [(35, 42)], -0.09160800881871993), ('i was', [(0, 1), (2, 5)], 0.19021551478926602), ('was sad', [(2, 5), (6, 9)], -0.06625548348972331), ('sad and', [(6, 9), (10, 13)], 0.5346186851714406), ('and disappointed', [(10, 13), (14, 26)], 0.26653522108846317), ('disappointed by', [(14, 26), (27, 29)], -0.13037261951049633), ('by this', [(27, 29), (30, 34)], 0.42658779616112424), ('this comment', [(30, 34), (35, 42)], 0.17398835561819753)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature='', weight=-0.5847697578742233, std=None, value=1.0), FeatureWeight(feature=, weight=-0.05547418869109433, std=None, value=None)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='2 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='disappointed', weight=0.7328982078589731, std=None, value=1.0), FeatureWeight(feature='and disappointed', weight=0.5320969197071604, std=None, value=1.0), FeatureWeight(feature='was sad', weight=0.2131501740286752, std=None, value=1.0), FeatureWeight(feature='sad', weight=0.204538558077756, std=None, value=1.0), FeatureWeight(feature='by this', weight=0.1664729577435313, std=None, value=1.0), FeatureWeight(feature='this comment', weight=0.1331403830562311, std=None, value=1.0), FeatureWeight(feature='i was', 
weight=0.0992793624540375, std=None, value=1.0), FeatureWeight(feature='disappointed by', weight=0.0757687718883966, std=None, value=1.0)], neg=[FeatureWeight(feature='', weight=-0.7081592777076793, std=None, value=1.0), FeatureWeight(feature='this', weight=-0.4236960149446214, std=None, value=1.0), FeatureWeight(feature='and', weight=-0.36806936626741465, std=None, value=1.0), FeatureWeight(feature='comment', weight=-0.2259694258591071, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.2024479929487215, std=None, value=1.0), FeatureWeight(feature='by', weight=-0.014466480265356482, std=None, value=1.0), FeatureWeight(feature='i', weight=-0.005985952449988492, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.5142968951266649, score=0.20855082437187222, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='i was sad and disappointed by this comment', spans=[('i', [(0, 1)], -0.005985952449988492), ('was', [(2, 5)], -0.2024479929487215), ('sad', [(6, 9)], 0.204538558077756), ('and', [(10, 13)], -0.36806936626741465), ('disappointed', [(14, 26)], 0.7328982078589731), ('by', [(27, 29)], -0.014466480265356482), ('this', [(30, 34)], -0.4236960149446214), ('comment', [(35, 42)], -0.2259694258591071), ('i was', [(0, 1), (2, 5)], 0.0992793624540375), ('was sad', [(2, 5), (6, 9)], 0.2131501740286752), ('and disappointed', [(10, 13), (14, 26)], 0.5320969197071604), ('disappointed by', [(14, 26), (27, 29)], 0.0757687718883966), ('by this', [(27, 29), (30, 34)], 0.1664729577435313), ('this comment', [(30, 34), (35, 42)], 0.1331403830562311)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[FeatureWeight(feature=, weight=0.9167101020795517, std=None, value=None)], neg=[FeatureWeight(feature='', weight=-0.7081592777076793, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='3 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='disappointed by', weight=0.49534031715083565, std=None, value=1.0), FeatureWeight(feature='was sad', weight=0.38185475125317725, std=None, value=1.0), FeatureWeight(feature='this comment', weight=0.3040274759378381, std=None, value=1.0), FeatureWeight(feature='i was', weight=0.08740115905472043, std=None, value=1.0), FeatureWeight(feature='sad and', weight=0.08027723553189965, std=None, value=1.0)], neg=[FeatureWeight(feature='', weight=-0.6027694124901809, std=None, value=1.0), FeatureWeight(feature='disappointed', weight=-0.5150103773886033, std=None, value=1.0), FeatureWeight(feature='sad', weight=-0.44967947400070096, std=None, value=1.0), FeatureWeight(feature='this', weight=-0.41934779999200184, std=None, value=1.0), FeatureWeight(feature='and', weight=-0.38879109152041874, std=None, value=1.0), FeatureWeight(feature='and disappointed', weight=-0.20839201881188554, std=None, value=1.0), FeatureWeight(feature='comment', weight=-0.17453889925444666, std=None, value=1.0), FeatureWeight(feature='by', weight=-0.08144982607176687, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.07386605125736567, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.16114092285641266, score=-1.5649440118588995, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='i was sad and disappointed by this comment', spans=[('was', [(2, 5)], -0.07386605125736567), ('sad', [(6, 9)], -0.44967947400070096), ('and', [(10, 13)], -0.38879109152041874), ('disappointed', [(14, 26)], -0.5150103773886033), ('by', [(27, 29)], 
-0.08144982607176687), ('this', [(30, 34)], -0.41934779999200184), ('comment', [(35, 42)], -0.17453889925444666), ('i was', [(0, 1), (2, 5)], 0.08740115905472043), ('was sad', [(2, 5), (6, 9)], 0.38185475125317725), ('sad and', [(6, 9), (10, 13)], 0.08027723553189965), ('and disappointed', [(10, 13), (14, 26)], -0.20839201881188554), ('disappointed by', [(14, 26), (27, 29)], 0.49534031715083565), ('this comment', [(30, 34), (35, 42)], 0.3040274759378381)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-0.9621745993687185, std=None, value=None), FeatureWeight(feature='', weight=-0.6027694124901809, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='4 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='i was', weight=0.3280158349131319, std=None, value=1.0), FeatureWeight(feature='by this', weight=0.24485117988897515, std=None, value=1.0), FeatureWeight(feature='this comment', weight=0.22136455004546846, std=None, value=1.0)], neg=[FeatureWeight(feature='disappointed', weight=-2.2818466043358065, std=None, value=1.0), FeatureWeight(feature='sad', weight=-1.5584253852713434, std=None, value=1.0), FeatureWeight(feature='this', weight=-0.6638038328645915, std=None, value=1.0), FeatureWeight(feature='and disappointed', weight=-0.5323899161000051, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.46923892676653156, std=None, value=1.0), FeatureWeight(feature='by', weight=-0.37154981102885365, std=None, value=1.0), FeatureWeight(feature='', weight=-0.3690717498285274, std=None, value=1.0), FeatureWeight(feature='comment', weight=-0.330310147529379, std=None, value=1.0), FeatureWeight(feature='i', weight=-0.27051122014568324, std=None, value=1.0), FeatureWeight(feature='was sad', weight=-0.18528806395587175, std=None, value=1.0), FeatureWeight(feature='and', weight=-0.05813554259772132, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.0017141517689746404, score=-6.2963396355767385, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='i was sad and disappointed by this comment', spans=[('i', [(0, 1)], -0.27051122014568324), ('was', [(2, 5)], -0.46923892676653156), ('sad', [(6, 9)], -1.5584253852713434), ('and', [(10, 13)], -0.05813554259772132), ('disappointed', [(14, 26)], -2.2818466043358065), ('by', [(27, 29)], -0.37154981102885365), ('this', [(30, 34)], -0.6638038328645915), ('comment', [(35, 42)], -0.330310147529379), ('i was', [(0, 1), (2, 5)], 0.3280158349131319), ('was sad', [(2, 5), (6, 9)], -0.18528806395587175), ('and disappointed', [(10, 13), (14, 26)], -0.5323899161000051), ('by this', [(27, 29), (30, 34)], 0.24485117988897515), ('this comment', [(30, 34), (35, 42)], 0.22136455004546846)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-5.92726788574821, std=None, value=None), FeatureWeight(feature='', weight=-0.3690717498285274, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='5 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='by this', weight=0.5199374968512099, std=None, value=1.0), FeatureWeight(feature='i was', weight=0.11021409204987892, std=None, value=1.0)], neg=[FeatureWeight(feature='disappointed', weight=-2.542609677519791, std=None, value=1.0), FeatureWeight(feature='sad', weight=-1.6205444349003744, std=None, value=1.0), FeatureWeight(feature='and disappointed', 
weight=-0.6823758319028586, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.5884377438193445, std=None, value=1.0), FeatureWeight(feature='by', weight=-0.47836975878854626, std=None, value=1.0), FeatureWeight(feature='disappointed by', weight=-0.4713742071188074, std=None, value=1.0), FeatureWeight(feature='', weight=-0.36028536795303406, std=None, value=1.0), FeatureWeight(feature='comment', weight=-0.23648279521049556, std=None, value=1.0), FeatureWeight(feature='was sad', weight=-0.2026448613238884, std=None, value=1.0), FeatureWeight(feature='sad and', weight=-0.09655848179203014, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.0012047522911433645, score=-6.649531571428082, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='i was sad and disappointed by this comment', spans=[('was', [(2, 5)], -0.5884377438193445), ('sad', [(6, 9)], -1.6205444349003744), ('disappointed', [(14, 26)], -2.542609677519791), ('by', [(27, 29)], -0.47836975878854626), ('comment', [(35, 42)], -0.23648279521049556), ('i was', [(0, 1), (2, 5)], 0.11021409204987892), ('was sad', [(2, 5), (6, 9)], -0.2026448613238884), ('sad and', [(6, 9), (10, 13)], -0.09655848179203014), ('and disappointed', [(10, 13), (14, 26)], -0.6823758319028586), ('disappointed by', [(14, 26), (27, 29)], -0.4713742071188074), ('by this', [(27, 29), (30, 34)], 0.5199374968512099)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-6.289246203475048, std=None, value=None), FeatureWeight(feature='', weight=-0.36028536795303406, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None)], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=None)" + "Explanation(estimator=\"SGDClassifier(alpha=0.001, loss='log', penalty='elasticnet',\\n random_state=RandomState(MT19937) at 0x7FE441F8F5A0)\", description=None, error=None, method='linear model', is_regression=False, targets=[TargetExplanation(target='1 star', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='were so', weight=0.4412142792788093, std=None, value=1.0), FeatureWeight(feature='food was', weight=0.26752107341372183, std=None, value=1.0), FeatureWeight(feature='restaurant was', weight=0.26656608799204584, std=None, value=1.0), FeatureWeight(feature='waiters were', weight=0.20671345225968554, std=None, value=1.0), FeatureWeight(feature='their food', weight=0.1922554265078954, std=None, value=1.0), FeatureWeight(feature='of their', weight=0.17089439768401465, std=None, value=1.0), FeatureWeight(feature='so polite', weight=0.16971856137145266, std=None, value=1.0), FeatureWeight(feature='the restaurant', weight=0.0939572854002312, std=None, value=1.0), FeatureWeight(feature='the quality', weight=0.062358167271686325, std=None, value=1.0), FeatureWeight(feature='was amazing', weight=0.04621754955053698, std=None, value=1.0), FeatureWeight(feature='quality of', weight=0.036982483291333176, std=None, value=1.0)], neg=[FeatureWeight(feature='exceptional', weight=-1.3083460073861424, std=None, value=1.0), FeatureWeight(feature='amazing', weight=-1.2489714603173, std=None, value=1.0), FeatureWeight(feature='quality', weight=-0.9791490505974166, std=None, value=1.0), FeatureWeight(feature='restaurant', weight=-0.7171043876100932, std=None, value=1.0), FeatureWeight(feature='food', weight=-0.6679038114428723, std=None, value=1.0), FeatureWeight(feature='their', weight=-0.5886911093430724, std=None, value=1.0), 
FeatureWeight(feature='waiters', weight=-0.562453762890923, std=None, value=1.0), FeatureWeight(feature='the', weight=-0.5612386460991848, std=None, value=3.0), FeatureWeight(feature='polite', weight=-0.4909223929633989, std=None, value=1.0), FeatureWeight(feature='so', weight=-0.4746325526619182, std=None, value=1.0), FeatureWeight(feature='', weight=-0.4681751172694132, std=None, value=1.0), FeatureWeight(feature='were', weight=-0.3105891970589915, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.20789884677179024, std=None, value=2.0), FeatureWeight(feature='of', weight=-0.12298547156907864, std=None, value=1.0), FeatureWeight(feature='was exceptional', weight=-0.07927481966138608, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.0009819658649170206, score=-6.833937869621568, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.', spans=[('the', [(0, 3)], -0.5612386460991848), ('restaurant', [(4, 14)], -0.7171043876100932), ('was', [(15, 18)], -0.20789884677179024), ('amazing', [(19, 26)], -1.2489714603173), ('the', [(28, 31)], -0.5612386460991848), ('quality', [(32, 39)], -0.9791490505974166), ('of', [(40, 42)], -0.12298547156907864), ('their', [(43, 48)], -0.5886911093430724), ('food', [(49, 53)], -0.6679038114428723), ('was', [(54, 57)], -0.20789884677179024), ('exceptional', [(58, 69)], -1.3083460073861424), ('the', [(71, 74)], -0.5612386460991848), ('waiters', [(75, 82)], -0.562453762890923), ('were', [(83, 87)], -0.3105891970589915), ('so', [(88, 90)], -0.4746325526619182), ('polite', [(91, 97)], -0.4909223929633989), ('the restaurant', [(0, 3), (4, 14)], 0.0939572854002312), ('restaurant was', [(4, 14), (15, 18)], 0.26656608799204584), ('was amazing', [(15, 18), (19, 26)], 0.04621754955053698), ('the quality', [(28, 31), (32, 39)], 0.062358167271686325), ('quality of', [(32, 39), (40, 42)], 0.036982483291333176), ('of their', [(40, 42), (43, 48)], 0.17089439768401465), ('their food', [(43, 48), (49, 53)], 0.1922554265078954), ('food was', [(49, 53), (54, 57)], 0.26752107341372183), ('was exceptional', [(54, 57), (58, 69)], -0.07927481966138608), ('waiters were', [(75, 82), (83, 87)], 0.20671345225968554), ('were so', [(83, 87), (88, 90)], 0.4412142792788093), ('so polite', [(88, 90), (91, 97)], 0.16971856137145266)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-6.365762752352155, std=None, value=None), FeatureWeight(feature='', weight=-0.4681751172694132, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='2 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='were so', weight=0.3687584351354406, std=None, value=1.0), FeatureWeight(feature='food was', weight=0.2818213047878279, std=None, value=1.0), FeatureWeight(feature='restaurant was', weight=0.23694730908198494, std=None, value=1.0), FeatureWeight(feature='waiters were', weight=0.21363431305973518, std=None, value=1.0), FeatureWeight(feature='quality of', weight=0.20779481726471224, std=None, value=1.0), FeatureWeight(feature='the waiters', weight=0.20231925231831172, std=None, value=1.0), FeatureWeight(feature='their food', weight=0.1811418380466741, std=None, value=1.0), FeatureWeight(feature='of their', weight=0.1753614820666619, std=None, value=1.0), FeatureWeight(feature='the quality', weight=0.15707226483127093, std=None, value=1.0)], 
neg=[FeatureWeight(feature='amazing', weight=-1.5783751972217899, std=None, value=1.0), FeatureWeight(feature='exceptional', weight=-1.4790309783432105, std=None, value=1.0), FeatureWeight(feature='quality', weight=-0.9347689615679214, std=None, value=1.0), FeatureWeight(feature='food', weight=-0.6327993132635227, std=None, value=1.0), FeatureWeight(feature='the', weight=-0.5788250220187662, std=None, value=3.0), FeatureWeight(feature='restaurant', weight=-0.5602856269944528, std=None, value=1.0), FeatureWeight(feature='waiters', weight=-0.4789001349410294, std=None, value=1.0), FeatureWeight(feature='their', weight=-0.46587342328228626, std=None, value=1.0), FeatureWeight(feature='polite', weight=-0.4606029092132801, std=None, value=1.0), FeatureWeight(feature='', weight=-0.4001754632372141, std=None, value=1.0), FeatureWeight(feature='were', weight=-0.2543025592541633, std=None, value=1.0), FeatureWeight(feature='so', weight=-0.09871439184564407, std=None, value=1.0), FeatureWeight(feature='exceptional the', weight=-0.09697973820214975, std=None, value=1.0), FeatureWeight(feature='of', weight=-0.08167691136359295, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.0324062121189669, std=None, value=2.0)], pos_remaining=0, neg_remaining=0), proba=0.002025322078557236, score=-6.108865826275371, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.', spans=[('the', [(0, 3)], -0.5788250220187662), ('restaurant', [(4, 14)], -0.5602856269944528), ('was', [(15, 18)], -0.0324062121189669), ('amazing', [(19, 26)], -1.5783751972217899), ('the', [(28, 31)], -0.5788250220187662), ('quality', [(32, 39)], -0.9347689615679214), ('of', [(40, 42)], -0.08167691136359295), ('their', [(43, 48)], -0.46587342328228626), ('food', [(49, 53)], -0.6327993132635227), ('was', [(54, 57)], -0.0324062121189669), ('exceptional', [(58, 69)], -1.4790309783432105), ('the', [(71, 74)], -0.5788250220187662), ('waiters', [(75, 82)], -0.4789001349410294), ('were', [(83, 87)], -0.2543025592541633), ('so', [(88, 90)], -0.09871439184564407), ('polite', [(91, 97)], -0.4606029092132801), ('restaurant was', [(4, 14), (15, 18)], 0.23694730908198494), ('the quality', [(28, 31), (32, 39)], 0.15707226483127093), ('quality of', [(32, 39), (40, 42)], 0.20779481726471224), ('of their', [(40, 42), (43, 48)], 0.1753614820666619), ('their food', [(43, 48), (49, 53)], 0.1811418380466741), ('food was', [(49, 53), (54, 57)], 0.2818213047878279), ('exceptional the', [(58, 69), (71, 74)], -0.09697973820214975), ('the waiters', [(71, 74), (75, 82)], 0.20231925231831172), ('waiters were', [(75, 82), (83, 87)], 0.21363431305973518), ('were so', [(83, 87), (88, 90)], 0.3687584351354406)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-5.708690363038157, std=None, value=None), FeatureWeight(feature='', weight=-0.4001754632372141, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='3 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='so', weight=0.3732887992270877, std=None, value=1.0), FeatureWeight(feature='was amazing', weight=0.2973068556130115, std=None, value=1.0), FeatureWeight(feature='food was', weight=0.2966172170449891, std=None, value=1.0), FeatureWeight(feature='their food', weight=0.2183481994654202, std=None, value=1.0), FeatureWeight(feature='the quality', weight=0.2157576281628976, 
std=None, value=1.0), FeatureWeight(feature='waiters were', weight=0.19123854139546625, std=None, value=1.0), FeatureWeight(feature='restaurant was', weight=0.15656173982130717, std=None, value=1.0), FeatureWeight(feature='the restaurant', weight=0.1479724362568333, std=None, value=1.0), FeatureWeight(feature='quality of', weight=0.14264912543429953, std=None, value=1.0), FeatureWeight(feature='of their', weight=0.12312764745359929, std=None, value=1.0), FeatureWeight(feature='the waiters', weight=0.09515209208809647, std=None, value=1.0), FeatureWeight(feature='was', weight=0.07169380833545562, std=None, value=2.0)], neg=[FeatureWeight(feature='amazing', weight=-1.5430523131452285, std=None, value=1.0), FeatureWeight(feature='exceptional', weight=-1.3418851039304842, std=None, value=1.0), FeatureWeight(feature='quality', weight=-0.6076663429775776, std=None, value=1.0), FeatureWeight(feature='so polite', weight=-0.5499925348438075, std=None, value=1.0), FeatureWeight(feature='the', weight=-0.46821278223214635, std=None, value=3.0), FeatureWeight(feature='', weight=-0.3973113230934471, std=None, value=1.0), FeatureWeight(feature='food', weight=-0.32245188389535806, std=None, value=1.0), FeatureWeight(feature='waiters', weight=-0.3220949693051254, std=None, value=1.0), FeatureWeight(feature='restaurant', weight=-0.2779002036079693, std=None, value=1.0), FeatureWeight(feature='were', weight=-0.22670218335080206, std=None, value=1.0), FeatureWeight(feature='their', weight=-0.22404977692950792, std=None, value=1.0), FeatureWeight(feature='polite', weight=-0.17712570659886498, std=None, value=1.0), FeatureWeight(feature='were so', weight=-0.07174581836680817, std=None, value=1.0), FeatureWeight(feature='of', weight=-0.018921718697862692, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.01323435654544038, score=-4.219398570676526, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the restaurant was amazing, the quality of their food was exceptional. 
the waiters were so polite.', spans=[('the', [(0, 3)], -0.46821278223214635), ('restaurant', [(4, 14)], -0.2779002036079693), ('was', [(15, 18)], 0.07169380833545562), ('amazing', [(19, 26)], -1.5430523131452285), ('the', [(28, 31)], -0.46821278223214635), ('quality', [(32, 39)], -0.6076663429775776), ('of', [(40, 42)], -0.018921718697862692), ('their', [(43, 48)], -0.22404977692950792), ('food', [(49, 53)], -0.32245188389535806), ('was', [(54, 57)], 0.07169380833545562), ('exceptional', [(58, 69)], -1.3418851039304842), ('the', [(71, 74)], -0.46821278223214635), ('waiters', [(75, 82)], -0.3220949693051254), ('were', [(83, 87)], -0.22670218335080206), ('so', [(88, 90)], 0.3732887992270877), ('polite', [(91, 97)], -0.17712570659886498), ('the restaurant', [(0, 3), (4, 14)], 0.1479724362568333), ('restaurant was', [(4, 14), (15, 18)], 0.15656173982130717), ('was amazing', [(15, 18), (19, 26)], 0.2973068556130115), ('the quality', [(28, 31), (32, 39)], 0.2157576281628976), ('quality of', [(32, 39), (40, 42)], 0.14264912543429953), ('of their', [(40, 42), (43, 48)], 0.12312764745359929), ('their food', [(43, 48), (49, 53)], 0.2183481994654202), ('food was', [(49, 53), (54, 57)], 0.2966172170449891), ('the waiters', [(71, 74), (75, 82)], 0.09515209208809647), ('waiters were', [(75, 82), (83, 87)], 0.19123854139546625), ('were so', [(83, 87), (88, 90)], -0.07174581836680817), ('so polite', [(88, 90), (91, 97)], -0.5499925348438075)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature=, weight=-3.8220872475830787, std=None, value=None), FeatureWeight(feature='', weight=-0.3973113230934471, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='4 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='was amazing', weight=0.2406325077168544, std=None, value=1.0), FeatureWeight(feature='their food', weight=0.15157723853870006, std=None, value=1.0), FeatureWeight(feature='the quality', weight=0.1399818513805465, std=None, value=1.0), FeatureWeight(feature='the restaurant', weight=0.13787874014492787, std=None, value=1.0), FeatureWeight(feature='restaurant was', weight=0.12684331354544345, std=None, value=1.0), FeatureWeight(feature='the waiters', weight=0.12309277037766059, std=None, value=1.0), FeatureWeight(feature='waiters were', weight=0.09734080413820671, std=None, value=1.0), FeatureWeight(feature='food was', weight=0.08692520104048276, std=None, value=1.0), FeatureWeight(feature='was exceptional', weight=0.08062590525808609, std=None, value=1.0), FeatureWeight(feature='exceptional the', weight=0.009074556454012837, std=None, value=1.0), FeatureWeight(feature='amazing the', weight=0.006158831321800694, std=None, value=1.0)], neg=[FeatureWeight(feature='amazing', weight=-0.6215482013799933, std=None, value=1.0), FeatureWeight(feature='', weight=-0.620885173127225, std=None, value=1.0), FeatureWeight(feature='exceptional', weight=-0.2908827666417324, std=None, value=1.0), FeatureWeight(feature='the', weight=-0.2644996283089753, std=None, value=3.0), FeatureWeight(feature='so polite', weight=-0.15732513067254023, std=None, value=1.0), FeatureWeight(feature='so', weight=-0.15620445986313092, std=None, value=1.0), FeatureWeight(feature='were', weight=-0.10913166927840638, std=None, value=1.0), FeatureWeight(feature='quality', weight=-0.10333919110091677, std=None, value=1.0), FeatureWeight(feature='waiters', weight=-0.03571463874678352, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), 
proba=0.21802021364159152, score=-1.1593991392029817, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.', spans=[('the', [(0, 3)], -0.2644996283089753), ('amazing', [(19, 26)], -0.6215482013799933), ('the', [(28, 31)], -0.2644996283089753), ('quality', [(32, 39)], -0.10333919110091677), ('exceptional', [(58, 69)], -0.2908827666417324), ('the', [(71, 74)], -0.2644996283089753), ('waiters', [(75, 82)], -0.03571463874678352), ('were', [(83, 87)], -0.10913166927840638), ('so', [(88, 90)], -0.15620445986313092), ('the restaurant', [(0, 3), (4, 14)], 0.13787874014492787), ('restaurant was', [(4, 14), (15, 18)], 0.12684331354544345), ('was amazing', [(15, 18), (19, 26)], 0.2406325077168544), ('amazing the', [(19, 26), (28, 31)], 0.006158831321800694), ('the quality', [(28, 31), (32, 39)], 0.1399818513805465), ('their food', [(43, 48), (49, 53)], 0.15157723853870006), ('food was', [(49, 53), (54, 57)], 0.08692520104048276), ('was exceptional', [(54, 57), (58, 69)], 0.08062590525808609), ('exceptional the', [(58, 69), (71, 74)], 0.009074556454012837), ('the waiters', [(71, 74), (75, 82)], 0.12309277037766059), ('waiters were', [(75, 82), (83, 87)], 0.09734080413820671), ('so polite', [(88, 90), (91, 97)], -0.15732513067254023)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[], neg=[FeatureWeight(feature='', weight=-0.620885173127225, std=None, value=1.0), FeatureWeight(feature=, weight=-0.5385139660757567, std=None, value=None)], pos_remaining=0, neg_remaining=0)), heatmap=None), TargetExplanation(target='5 stars', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='amazing', weight=1.4166616343102316, std=None, value=1.0), FeatureWeight(feature='exceptional', weight=1.1116493929667317, std=None, value=1.0), FeatureWeight(feature='so polite', weight=0.7183995264652492, std=None, value=1.0), FeatureWeight(feature='quality', weight=0.4603260408381861, std=None, value=1.0), FeatureWeight(feature='were so', weight=0.2490215751420647, std=None, value=1.0), FeatureWeight(feature='of their', weight=0.0823679786389739, std=None, value=1.0), FeatureWeight(feature='food', weight=0.055860706726513255, std=None, value=1.0), FeatureWeight(feature='waiters', weight=0.02617745978942477, std=None, value=1.0), FeatureWeight(feature='food was', weight=0.01654520820628123, std=None, value=1.0)], neg=[FeatureWeight(feature='', weight=-0.5209352118853326, std=None, value=1.0), FeatureWeight(feature='was', weight=-0.48869249339855975, std=None, value=2.0), FeatureWeight(feature='so', weight=-0.42322661890010355, std=None, value=1.0), FeatureWeight(feature='was amazing', weight=-0.2535461214477351, std=None, value=1.0), FeatureWeight(feature='of', weight=-0.21389664754772872, std=None, value=1.0), FeatureWeight(feature='their food', weight=-0.17140724010730823, std=None, value=1.0), FeatureWeight(feature='the quality', weight=-0.14962242819710417, std=None, value=1.0), FeatureWeight(feature='amazing the', weight=-0.07179925678814633, std=None, value=1.0), FeatureWeight(feature='polite', weight=-0.06184756173850257, std=None, value=1.0), FeatureWeight(feature='the', weight=-0.06032069433534206, std=None, value=3.0), FeatureWeight(feature='exceptional the', weight=-0.04214936922012958, std=None, value=1.0), FeatureWeight(feature='waiters were', weight=-0.031429441689239994, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.7657381418694937, 
score=1.6481364378284238, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document='the restaurant was amazing, the quality of their food was exceptional. the waiters were so polite.', spans=[('the', [(0, 3)], -0.06032069433534206), ('was', [(15, 18)], -0.48869249339855975), ('amazing', [(19, 26)], 1.4166616343102316), ('the', [(28, 31)], -0.06032069433534206), ('quality', [(32, 39)], 0.4603260408381861), ('of', [(40, 42)], -0.21389664754772872), ('food', [(49, 53)], 0.055860706726513255), ('was', [(54, 57)], -0.48869249339855975), ('exceptional', [(58, 69)], 1.1116493929667317), ('the', [(71, 74)], -0.06032069433534206), ('waiters', [(75, 82)], 0.02617745978942477), ('so', [(88, 90)], -0.42322661890010355), ('polite', [(91, 97)], -0.06184756173850257), ('was amazing', [(15, 18), (19, 26)], -0.2535461214477351), ('amazing the', [(19, 26), (28, 31)], -0.07179925678814633), ('the quality', [(28, 31), (32, 39)], -0.14962242819710417), ('of their', [(40, 42), (43, 48)], 0.0823679786389739), ('their food', [(43, 48), (49, 53)], -0.17140724010730823), ('food was', [(49, 53), (54, 57)], 0.01654520820628123), ('exceptional the', [(58, 69), (71, 74)], -0.04214936922012958), ('waiters were', [(75, 82), (83, 87)], -0.031429441689239994), ('were so', [(83, 87), (88, 90)], 0.2490215751420647), ('so polite', [(88, 90), (91, 97)], 0.7183995264652492)], preserve_density=False, vec_name=None)], other=FeatureWeights(pos=[FeatureWeight(feature=, weight=2.169071649713757, std=None, value=None)], neg=[FeatureWeight(feature='', weight=-0.5209352118853326, std=None, value=1.0)], pos_remaining=0, neg_remaining=0)), heatmap=None)], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None, image=None)" ] }, - "execution_count": 79, + "execution_count": 85, "metadata": {}, "output_type": "execute_result" } @@ -659,6 +1182,33 @@ "pd.DataFrame(data)" ] }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.datasets import fetch_20newsgroups\n", + "\n", + "categories = ['alt.atheism', 'soc.religion.christian',\n", + " 'comp.graphics', 'sci.med']\n", + " \n", + "twenty_train = fetch_20newsgroups(\n", + " subset='train',\n", + " categories=categories,\n", + " shuffle=True,\n", + " random_state=42,\n", + " remove=('headers', 'footers'),\n", + ")\n", + "twenty_test = fetch_20newsgroups(\n", + " subset='test',\n", + " categories=categories,\n", + " shuffle=True,\n", + " random_state=42,\n", + " remove=('headers', 'footers'),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null,