LIME (Local Interpretable Model-agnostic Explanations) basic example


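# !pip install -q lime
# __future__ imports must come before any other statement when run as a script
from __future__ import print_function
import lime
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.feature_extraction.text  # used below for TfidfVectorizer
import sklearn.metrics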

Fetching data, training a classifier

  • Uses the 20 newsgroups dataset
  • For this test, only two classes (atheism, christian) are used
from sklearn.datasets import fetch_20newsgroups
categories = ['alt.atheism', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
class_names = ['atheism', 'christian']
print(type(newsgroups_train))
<class 'sklearn.utils.Bunch'>
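  • How is the sklearn.utils.Bunch class used? It is a dictionary-like container whose keys are also available as attributes, which is why newsgroups_train.data works below.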
print(type(newsgroups_train.data))
<class 'list'>
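On the Bunch question: a Bunch behaves like a dictionary whose keys can also be read as attributes, so newsgroups_train.data and newsgroups_train['data'] refer to the same list. The sketch below illustrates the idea with a hypothetical stand-in class, MiniBunch, and made-up data. It is not scikit-learn's implementation, only a small model of the access pattern.

```python
class MiniBunch(dict):
    """Hypothetical stand-in: a dict whose keys can also be read as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)


# Made-up miniature of the kind of object fetch_20newsgroups returns.
toy = MiniBunch(
    data=["first post text", "second post text"],
    target=[0, 1],
    target_names=["alt.atheism", "soc.religion.christian"],
)

# Attribute access and key access return the same object.
same_object = toy.data is toy["data"]
# Map the integer label of the second document back to its group name.
second_label = toy.target_names[toy.target[1]]
print(sorted(toy.keys()), same_object, second_label)
```

With the real dataset, the same pattern gives newsgroups_train.target_names[newsgroups_train.target[i]] as the group name of document i.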
vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=False)
train_vectors = vectorizer.fit_transform(newsgroups_train.data)
test_vectors = vectorizer.transform(newsgroups_test.data)
print(type(train_vectors))
<class 'scipy.sparse.csr.csr_matrix'>
print(train_vectors.shape)
print(test_vectors.shape)
(1079, 23035)
(717, 23035)
rf = sklearn.ensemble.RandomForestClassifier(n_estimators=500)
rf.fit(train_vectors, newsgroups_train.target)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=None,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)
pred = rf.predict(test_vectors)
sklearn.metrics.f1_score(newsgroups_test.target, pred, average='binary')
0.9230769230769231
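  • What does the f1_score value mean? It is the harmonic mean of precision and recall, ranging from 0 (worst) to 1 (best).

  • That is a surprisingly high score…???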
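To make the f1_score question concrete: with average='binary', scikit-learn scores the class labelled 1 (christian here) and combines its precision and recall into one number. The following is a minimal sketch of that arithmetic. The function name f1_from_counts and the counts are made up for illustration and are not taken from the run above.

```python
def f1_from_counts(tp, fp, fn):
    """F1 for the positive class from true-positive, false-positive and false-negative counts."""
    precision = tp / (tp + fp)  # share of predicted positives that are correct
    recall = tp / (tp + fn)     # share of actual positives that were found
    return 2 * precision * recall / (precision + recall)


# Hypothetical counts: 8 correct positives, 2 false alarms, 1 miss.
tp, fp, fn = 8, 2, 1
score = f1_from_counts(tp, fp, fn)
print(round(score, 3))  # 0.842
```

Because it is a harmonic mean, the score is pulled toward the lower of precision and recall, so a value above 0.9 requires both to be high.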

Explaining predictions using lime

from lime import lime_text
from sklearn.pipeline import make_pipeline
c = make_pipeline(vectorizer, rf)
print(c.predict_proba([newsgroups_test.data[0]]))
[[0.286 0.714]]
from lime.lime_text import LimeTextExplainer
explainer = LimeTextExplainer(class_names=class_names)
idx = 77
exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6)
print('Document id: %d' % idx)
print('Probability(christian) =', c.predict_proba([newsgroups_test.data[idx]])[0,1])
print('True class: %s' % class_names[newsgroups_test.target[idx]])
Document id: 77
Probability(christian) = 0.138
True class: atheism
  • The classifier got this example right (it predicted atheism).
  • The explanation is presented below as a list of weighted features.
exp.as_list()
[('Posting', -0.1179373397649475),
 ('Host', -0.09636158172180675),
 ('NNTP', -0.08182990032857361),
 ('Re', -0.07071158188318627),
 ('Keith', -0.06514820655831027),
 ('In', -0.058947446546993056)]
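These weights can be read as a local linear approximation of the classifier around this one document. A negative weight means the word lowers the predicted probability of christian. As a rough sketch of how to read them, the snippet below sums the weights of 'Posting' and 'Host' to estimate how far the prediction should move if both words were removed. The variable names are illustrative only.

```python
# Weights copied from the exp.as_list() output above.
weights = [
    ('Posting', -0.1179373397649475),
    ('Host', -0.09636158172180675),
    ('NNTP', -0.08182990032857361),
    ('Re', -0.07071158188318627),
    ('Keith', -0.06514820655831027),
    ('In', -0.058947446546993056),
]

removed = {'Posting', 'Host'}
# Removing a word with a negative weight should raise P(christian) by about |weight|.
expected_shift = -sum(w for word, w in weights if word in removed)
print(round(expected_shift, 3))  # 0.214
```

The check that follows reports an actual change of 0.122. The direction matches the estimate, but the size differs because the linear model only approximates the random forest near this document.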
print('Original prediction:', rf.predict_proba(test_vectors[idx])[0,1])
tmp = test_vectors[idx].copy()
tmp[0,vectorizer.vocabulary_['Posting']] = 0
tmp[0,vectorizer.vocabulary_['Host']] = 0
print('Prediction removing some features:', rf.predict_proba(tmp)[0,1])
print('Difference:', rf.predict_proba(tmp)[0,1] - rf.predict_proba(test_vectors[idx])[0,1])
Original prediction: 0.138
Prediction removing some features: 0.26
Difference: 0.122

Visualizing explanations

import matplotlib
matplotlib.rcParams.update({'text.color': 'white'})
matplotlib.style.use('ggplot')
# 'normal' is not a font family matplotlib recognises; use a generic family instead
font = {'family': 'sans-serif',
        'weight': 'bold',
        'size': 25}
matplotlib.rc('font', **font)
%matplotlib inline
fig = exp.as_pyplot_figure()

[Figure: bar chart of the explanation weights produced by exp.as_pyplot_figure()]

exp.show_in_notebook(text=False)
exp.show_in_notebook(text=True)
