Классификация текстов

In [3]:
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
In [4]:
from sklearn.datasets import fetch_20newsgroups
import numpy as np
import gensim.parsing.preprocessing as gp
from sklearn import feature_extraction, metrics
from sklearn import naive_bayes, linear_model, svm
from sklearn.preprocessing import Binarizer

Набор данных 20 newsgroups состоит из множества usenet-постов из 20 тем. Задача заключается в опредении, к какой теме относится пост. Из постов удалены заголовки, подписи и цитаты (на семинаре мы этого не делали, поэтому сейчас результаты будут пореалистичнее). Набор данных встроен в sklearn.

In [5]:
train_data = fetch_20newsgroups(subset='train',remove=('headers', 'footers', 'quotes'))
test_data = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

Выведем пример текста.

In [6]:
text = train_data.data[0]
print(text)
I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Так называются искомые темы:

In [7]:
print(train_data.target_names)
['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']

В этом туториале мы будем использовать bag-of-words представления и их производные. В базовом виде каждый текст $t$ представляется в виде вектора:

\begin{equation*} \vec{v}(t) = [c(w_1), c(w_2), ..., c(w_{|V|}] \end{equation*} где $c(w_i)$ обозначает количество, сколько раз уникальное слово $w_i$ встретилось в тексте (счетчик слова $w_i$), а $|V|$ - общее количество уникальных слов (размер словаря). Словарь наполняется словами из всех текстов и опционально фильтруется. Bag-of-words векторы в подавляющем большинстве случаев разрежены - то есть, практически все их элементы равны нулю.

Если для каждого слова использовать one-hot-кодировку, то bag-of-words - это векторная сумма последовательности кодировок слов из текста.

Чтобы получить BpW-представление, нужно извлечь собственно слова из текста, т.е. провести токенизацию текста. Вообще говоря, слово - условное понятие. Под ним можно подразумевать слова, знаки препинания, группы слов (например, все нецензурные слова можно считать вместе, а не по отдельности) и вообще произвольные счетные признаки.

В обработке естественных языков преобработка (нормализация) текстов играет ключевую роль. Под ней подразумевается различная фильтрация лишних деталей, общие преобразования. Мы воспользуемся встроенной преобработкой из библиотеки gensim и оценим эффекты некоторых стадий.

Переведем текст в единый регистр, удалим html теги, пунктуацию, числа и стоп-слова (часто встречающиеся слова, которые без контекста практически не имеют смысла (пример https://gist.github.com/sebleier/554280))

In [8]:
def gensim_preprocessing1(documents):
    filters = [lambda s: s.lower(), gp.strip_tags, gp.strip_punctuation, gp.strip_numeric, gp.remove_stopwords]
    return [gp.preprocess_string(doc,filters) for doc in documents]
In [10]:
tokens_train = gensim_preprocessing1(train_data.data)

Выведем полученный список токенов из первого текста

In [15]:
print(tokens_train[0])
['wondering', 'enlighten', 'car', 'saw', 'day', 'door', 'sports', 'car', 'looked', 'late', 's', 'early', 's', 'called', 'bricklin', 'doors', 'small', 'addition', 'bumper', 'separate', 'rest', 'body', 'know', 'tellme', 'model', 'engine', 'specs', 'years', 'production', 'car', 'history', 'info', 'funky', 'looking', 'car', 'e', 'mail']
In [14]:
tokens_test = gensim_preprocessing1(test_data.data)

Теперь построим из списков токенов векторы bow. Для этого воспользуемся классом CountVectorizer из sklearn. Вообще, sklearn предоставляет свою ограниченную токенизацию и преобработку, но поскольку мы сделали её сами, заменим соответствующие шаги на ничего не делающие. Также для скорости ограничим словарь 30к самых частых слов (отметим, что мы выбросили стоп-слова).

In [16]:
count_vectorizer = feature_extraction.text.CountVectorizer(preprocessor=lambda x:x,
                                                           tokenizer=lambda x:x,max_features=30000)
# tfidf = feature_extraction.text.TfidfTransformer()
# binarizer = Binarizer()

Натренируем векторизатор (т.е. дадим ему тексты, из которых он выяснит 30к самых частых слов и назначит им номера) и преобразуем тренировочные данные в векторы.

In [17]:
X_train = count_vectorizer.fit_transform(tokens_train)
In [18]:
feature_names = count_vectorizer.get_feature_names()
In [19]:
X_test = count_vectorizer.transform(tokens_test)
#X_test = binarizer.transform(X_test)
In [21]:
print(X_train.shape)
(11314, 30000)

Выведем вектор первого текста. Поскольку в нем 30000 элементов, хранить их все было бы очень затратно. Поэтому хранятся в памяти только ненулевые элементы и их номера. Все векторы вместе образуют sparse-матрицу (X_train и X_test)

In [22]:
print(X_train[0])
  (0, 15657)	1
  (0, 7962)	1
  (0, 15267)	1
  (0, 10416)	1
  (0, 12800)	1
  (0, 11872)	1
  (0, 20502)	1
  (0, 29710)	1
  (0, 24679)	1
  (0, 8457)	1
  (0, 16756)	1
  (0, 14116)	1
  (0, 3016)	1
  (0, 22156)	1
  (0, 23524)	1
  (0, 3463)	1
  (0, 287)	1
  (0, 24329)	1
  (0, 7579)	1
  (0, 3666)	1
  (0, 7975)	1
  (0, 22818)	2
  (0, 14502)	1
  (0, 15265)	1
  (0, 24800)	1
  (0, 7578)	1
  (0, 6339)	1
  (0, 23019)	1
  (0, 3786)	4
  (0, 8483)	1
  (0, 29002)	1

Используя хранимое в CountVectorizer отображение номеров на слова (feature_names), выведем счетчики слов первого текста

In [23]:
for i in X_train[0].indices:
    print(feature_names[i], X_train[0,i])
mail 1
e 1
looking 1
funky 1
info 1
history 1
production 1
years 1
specs 1
engine 1
model 1
know 1
body 1
rest 1
separate 1
bumper 1
addition 1
small 1
doors 1
called 1
early 1
s 2
late 1
looked 1
sports 1
door 1
day 1
saw 1
car 4
enlighten 1
wondering 1

Натренируем на полученных векторах наивный Байесовский классификатор. Используемая реализация NB использует следующую формулу для определения класса: \begin{equation*} P(C|w_1, ..., w_{|V|}) = Z P(C)P(w_1, ..., w_{|V|}|C) = Z P(C)\prod_{i = 1,~C(w_i) \ne 0}^{|V|} P(w_i|C) \end{equation*} Z - нормализующая константа, чтобы вероятности классов суммировались в 1. $P(w_i|C)$ - вероятность появления слова $w_i$ в тексте этого класса, она оценивается как: \begin{equation*} P(w_i|C) = \frac{\sum{c(w_i) + \alpha}}{\sum_j [\sum{c(w_j) + \alpha}]} \end{equation*} где $\sum{c(w_i)}$ общее количество раз, которое данное слово встретилось во всех текстах класса $C$, $\alpha$ - сглаживающая константа, благодаря которой у нас нет нулевых вероятностей.

In [24]:
multi_nb = naive_bayes.MultinomialNB()

multi_nb.fit(X_train, train_data.target)
m_pred = multi_nb.predict(X_test)
print("Multinb: ", metrics.accuracy_score(test_data.target, m_pred))
print(metrics.classification_report(test_data.target, m_pred, target_names = test_data.target_names))

# confusion_matrix = metrics.confusion_matrix(test_data.target, m_pred)
# for i, row in enumerate(confusion_matrix):
#     print(test_data.target_names[i])
#     for j, col in enumerate(row):
#         print(test_data.target_names[j], ":", col, end=' ')
#     print('')
Multinb:  0.6512214551248009
                          precision    recall  f1-score   support

             alt.atheism       0.53      0.48      0.50       319
           comp.graphics       0.56      0.72      0.63       389
 comp.os.ms-windows.misc       0.18      0.01      0.01       394
comp.sys.ibm.pc.hardware       0.48      0.73      0.58       392
   comp.sys.mac.hardware       0.64      0.62      0.63       385
          comp.windows.x       0.68      0.78      0.72       395
            misc.forsale       0.82      0.70      0.75       390
               rec.autos       0.78      0.71      0.74       396
         rec.motorcycles       0.83      0.70      0.76       398
      rec.sport.baseball       0.90      0.81      0.85       397
        rec.sport.hockey       0.58      0.89      0.71       399
               sci.crypt       0.70      0.74      0.72       396
         sci.electronics       0.65      0.52      0.58       393
                 sci.med       0.82      0.77      0.79       396
               sci.space       0.77      0.73      0.75       394
  soc.religion.christian       0.53      0.85      0.66       398
      talk.politics.guns       0.56      0.68      0.61       364
   talk.politics.mideast       0.76      0.75      0.76       376
      talk.politics.misc       0.46      0.44      0.45       310
      talk.religion.misc       0.40      0.13      0.20       251

               micro avg       0.65      0.65      0.65      7532
               macro avg       0.63      0.64      0.62      7532
            weighted avg       0.64      0.65      0.63      7532

Результаты не очень, попробуем добавить стеммизацию (грубую обработку слов по морфологическим правилам, которая сводит большинство форм одного слова в одну), а также удалить все короткие слова.

In [29]:
def gensim_preprocessing2(documents):
    filters = [lambda s: s.lower(), gp.strip_tags, gp.strip_punctuation, 
               gp.strip_numeric, gp.remove_stopwords, gp.strip_short, gp.stem_text]
    return [gp.preprocess_string(doc,filters) for doc in documents]
In [30]:
tokens_stem_train = gensim_preprocessing2(train_data.data)
tokens_stem_test = gensim_preprocessing2(test_data.data)
In [31]:
print(tokens_stem_train[0])
['wonder', 'enlighten', 'car', 'saw', 'dai', 'door', 'sport', 'car', 'look', 'late', 'earli', 'call', 'bricklin', 'door', 'small', 'addit', 'bumper', 'separ', 'rest', 'bodi', 'know', 'tellm', 'model', 'engin', 'spec', 'year', 'product', 'car', 'histori', 'info', 'funki', 'look', 'car', 'mail']
In [34]:
cv_stem = feature_extraction.text.CountVectorizer(preprocessor=lambda x:x,
                                                           tokenizer=lambda x:x,max_features=30000)
X_train_stem = cv_stem.fit_transform(tokens_stem_train)
X_test_stem = cv_stem.transform(tokens_stem_test)
In [36]:
multi_nb_stem = naive_bayes.MultinomialNB()

multi_nb_stem.fit(X_train_stem, train_data.target)
m_pred = multi_nb_stem.predict(X_test_stem)
print("Multinb: ", metrics.accuracy_score(test_data.target, m_pred))
print(metrics.classification_report(test_data.target, m_pred, target_names = test_data.target_names))
Multinb:  0.6533457249070632
                          precision    recall  f1-score   support

             alt.atheism       0.54      0.45      0.49       319
           comp.graphics       0.60      0.71      0.65       389
 comp.os.ms-windows.misc       0.81      0.20      0.32       394
comp.sys.ibm.pc.hardware       0.52      0.67      0.59       392
   comp.sys.mac.hardware       0.65      0.61      0.63       385
          comp.windows.x       0.66      0.77      0.71       395
            misc.forsale       0.82      0.66      0.73       390
               rec.autos       0.77      0.73      0.75       396
         rec.motorcycles       0.84      0.69      0.76       398
      rec.sport.baseball       0.91      0.77      0.83       397
        rec.sport.hockey       0.58      0.90      0.70       399
               sci.crypt       0.63      0.77      0.69       396
         sci.electronics       0.63      0.49      0.55       393
                 sci.med       0.81      0.79      0.80       396
               sci.space       0.76      0.74      0.75       394
  soc.religion.christian       0.52      0.85      0.65       398
      talk.politics.guns       0.55      0.68      0.61       364
   talk.politics.mideast       0.74      0.76      0.75       376
      talk.politics.misc       0.45      0.43      0.44       310
      talk.religion.misc       0.40      0.12      0.19       251

               micro avg       0.65      0.65      0.65      7532
               macro avg       0.66      0.64      0.63      7532
            weighted avg       0.67      0.65      0.64      7532

Результат практически не изменился. Добавим ещё 30000 атрибутов

In [38]:
cv_stem2 = feature_extraction.text.CountVectorizer(preprocessor=lambda x:x,
                                                           tokenizer=lambda x:x,max_features=60000)
X_train_stem2 = cv_stem2.fit_transform(tokens_stem_train)
X_test_stem2 = cv_stem2.transform(tokens_stem_test)

multi_nb_stem2 = naive_bayes.MultinomialNB()

multi_nb_stem2.fit(X_train_stem2, train_data.target)
m_pred = multi_nb_stem2.predict(X_test_stem2)
print("Multinb: ", metrics.accuracy_score(test_data.target, m_pred))
print(metrics.classification_report(test_data.target, m_pred, target_names = test_data.target_names))
Multinb:  0.6427243759957515
                          precision    recall  f1-score   support

             alt.atheism       0.60      0.38      0.46       319
           comp.graphics       0.61      0.71      0.66       389
 comp.os.ms-windows.misc       0.89      0.14      0.25       394
comp.sys.ibm.pc.hardware       0.53      0.66      0.59       392
   comp.sys.mac.hardware       0.70      0.57      0.63       385
          comp.windows.x       0.60      0.79      0.68       395
            misc.forsale       0.84      0.62      0.71       390
               rec.autos       0.80      0.72      0.76       396
         rec.motorcycles       0.88      0.66      0.75       398
      rec.sport.baseball       0.93      0.73      0.82       397
        rec.sport.hockey       0.57      0.91      0.71       399
               sci.crypt       0.55      0.80      0.65       396
         sci.electronics       0.66      0.50      0.57       393
                 sci.med       0.79      0.79      0.79       396
               sci.space       0.75      0.75      0.75       394
  soc.religion.christian       0.49      0.88      0.63       398
      talk.politics.guns       0.55      0.64      0.59       364
   talk.politics.mideast       0.65      0.78      0.71       376
      talk.politics.misc       0.45      0.44      0.44       310
      talk.religion.misc       0.45      0.07      0.12       251

               micro avg       0.64      0.64      0.64      7532
               macro avg       0.66      0.63      0.61      7532
            weighted avg       0.67      0.64      0.63      7532

Взвесим слова при помощи idf-весов (обратная документная частота). Воспользуемся схемой $idf_i = log \frac{N}{df_i}$, где $df_i$ - количество документов, в которых встретилось слово, а $N$ общее количество документов. Таким образом, слова более уникальные для документов имеют больший вес, общераспространенные - меньший. После это мы получаем tf-idf представление вида. Представление также нормализовано по длине, чтобы сгладить эффект более длинных текстов: \begin{equation*} \vec{v}(t) = [idf_1*c(w_1), idf_2*c(w_2), ..., idf_{|V|}*c(w_{|V|})] \end{equation*}

In [43]:
tfidf = feature_extraction.text.TfidfTransformer()
X_train_idf = tfidf.fit_transform(X_train_stem)
X_test_idf = tfidf.transform(X_test_stem)

nb3 = naive_bayes.MultinomialNB()
nb3.fit(X_train_idf, train_data.target)
nb3_pred = nb3.predict(X_test_idf)
print("Multinb: ", metrics.accuracy_score(test_data.target, nb3_pred))
print(metrics.classification_report(test_data.target, nb3_pred, target_names = test_data.target_names))
Multinb:  0.6708709506107275
                          precision    recall  f1-score   support

             alt.atheism       0.75      0.20      0.31       319
           comp.graphics       0.68      0.69      0.68       389
 comp.os.ms-windows.misc       0.66      0.54      0.59       394
comp.sys.ibm.pc.hardware       0.57      0.69      0.62       392
   comp.sys.mac.hardware       0.72      0.66      0.69       385
          comp.windows.x       0.78      0.76      0.77       395
            misc.forsale       0.77      0.71      0.74       390
               rec.autos       0.81      0.73      0.77       396
         rec.motorcycles       0.83      0.72      0.77       398
      rec.sport.baseball       0.93      0.79      0.85       397
        rec.sport.hockey       0.57      0.93      0.71       399
               sci.crypt       0.60      0.79      0.68       396
         sci.electronics       0.70      0.51      0.59       393
                 sci.med       0.83      0.77      0.80       396
               sci.space       0.78      0.76      0.77       394
  soc.religion.christian       0.39      0.92      0.54       398
      talk.politics.guns       0.55      0.73      0.63       364
   talk.politics.mideast       0.83      0.78      0.80       376
      talk.politics.misc       0.85      0.31      0.45       310
      talk.religion.misc       1.00      0.01      0.02       251

               micro avg       0.67      0.67      0.67      7532
               macro avg       0.73      0.65      0.64      7532
            weighted avg       0.72      0.67      0.66      7532

Натренируем также логистическую регрессию на этих же данных.

In [114]:
logr = linear_model.LogisticRegression()
logr.fit(X_train_idf, train_data.target)
logr_pred = logr.predict(X_test_idf)
print("Logr: ", metrics.accuracy_score(test_data.target, logr_pred))
print(metrics.classification_report(test_data.target, logr_pred, target_names = test_data.target_names))
/home/ivan/.pyenv/versions/3.6.4/envs/general36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.
  FutureWarning)
/home/ivan/.pyenv/versions/3.6.4/envs/general36/lib/python3.6/site-packages/sklearn/linear_model/logistic.py:459: FutureWarning: Default multi_class will be changed to 'auto' in 0.22. Specify the multi_class option to silence this warning.
  "this warning.", FutureWarning)
Logr:  0.6808284652150823
                          precision    recall  f1-score   support

             alt.atheism       0.51      0.47      0.49       319
           comp.graphics       0.66      0.69      0.68       389
 comp.os.ms-windows.misc       0.65      0.59      0.62       394
comp.sys.ibm.pc.hardware       0.64      0.62      0.63       392
   comp.sys.mac.hardware       0.72      0.67      0.69       385
          comp.windows.x       0.80      0.70      0.75       395
            misc.forsale       0.70      0.76      0.73       390
               rec.autos       0.78      0.71      0.74       396
         rec.motorcycles       0.49      0.78      0.60       398
      rec.sport.baseball       0.79      0.80      0.80       397
        rec.sport.hockey       0.89      0.88      0.89       399
               sci.crypt       0.87      0.69      0.77       396
         sci.electronics       0.54      0.60      0.57       393
                 sci.med       0.75      0.80      0.77       396
               sci.space       0.73      0.73      0.73       394
  soc.religion.christian       0.62      0.79      0.70       398
      talk.politics.guns       0.58      0.66      0.62       364
   talk.politics.mideast       0.85      0.75      0.79       376
      talk.politics.misc       0.57      0.45      0.50       310
      talk.religion.misc       0.50      0.19      0.27       251

               micro avg       0.68      0.68      0.68      7532
               macro avg       0.68      0.67      0.67      7532
            weighted avg       0.69      0.68      0.68      7532

Для каждого класса также выведем наиболее важные в положительном и отрицательном смысле слова, с точки зрения логистической регресии. На каждый класс натренирована линейная функция и выбирается тот класс, для которого соответствующая функция вернула максимальное значение (см. softmax-регрессия).

In [115]:
print(logr.coef_.shape)
(20, 30000)
In [117]:
words = cv_stem.get_feature_names()
for i in range(20):
    indices = logr.coef_[i].argsort()
    least = indices[:10]
    most = indices[-10:]
    
    print(train_data.target_names[i])
    print('Most important(+)', ' '.join([words[j] for j in most]))
    print('Most important(-)', ' '.join([words[j] for j in least]))
alt.atheism
Most important(+) post delet punish bobbi motto moral atheism religion atheist islam
Most important(-) window thank us want game work mail christ drive christian
comp.graphics
Most important(+) format anim cview tiff program algorithm file polygon imag graphic
Most important(-) peopl drive kei car window right god win monitor believ
comp.os.ms-windows.misc
Most important(+) risc problem microsoft font cica win max driver file window
Most important(-) year car bit game time power server state peopl sale
comp.sys.ibm.pc.hardware
Most important(+) gatewai vlb motherboard irq id drive card bu monitor scsi
Most important(-) mac peopl car appl year window case god includ offer
comp.sys.mac.hardware
Most important(+) lciii nubu problem duo powerbook simm quadra centri appl mac
Most important(-) window do car control id god file game com year
comp.windows.x
Most important(+) displai applic sun client mit window xterm server motif widget
Most important(-) do card drive mac god driver car game lot post
misc.forsale
Most important(+) price interest new email condit includ ship sell offer sale
Most important(-) think know help read peopl problem file appreci team sure
rec.autos
Most important(+) wheel road toyota wagon engin dealer oil auto ford car
Most important(-) bike game file card program god team christian plai softwar
rec.motorcycles
Most important(+) biker harlei dog rider bmw helmet motorcycl dod ride bike
Most important(-) card window game us program kei believ team christian file
rec.sport.baseball
Most important(+) philli bat brave stadium cub pitcher hit year pitch basebal
Most important(-) hockei window us car peopl file work nhl playoff problem
rec.sport.hockey
Most important(+) coach player season leaf nhl playoff plai game team hockei
Most important(-) run us work window drive pitch file god problem program
sci.crypt
Most important(+) clinton privaci phone govern chip secur nsa clipper encrypt kei
Most important(-) window drive god thank car problem card christian help game
sci.electronics
Most important(+) scope detector wire signal ground amp power voltag electron circuit
Most important(-) window file peopl mac bike think govern year god encrypt
sci.med
Most important(+) cancer effect pain treatment food patient medic diseas msg doctor
Most important(-) god window car game drive christian file govern card space
sci.space
Most important(+) satellit earth rocket shuttl spacecraft moon nasa launch orbit space
Most important(-) window car drive game god card chip kei run bike
soc.religion.christian
Most important(+) resurrect jesu marriag faith cathol christ sin god christian church
Most important(-) window game car run file us drive right atheism team
talk.politics.guns
Most important(+) law handheld crimin nra jmd batf fbi firearm weapon gun
Most important(-) god window game armenian kei problem program christian encrypt israel
talk.politics.mideast
Most important(+) kill turkei greek palestinian turkish jew arab armenian isra israel
Most important(-) look us thank window game god jesu work gun drive
talk.politics.misc
Most important(+) gai state trial govern clinton homosexu drug presid libertarian tax
Most important(-) god christian gun window chip thank kei team game armenian
talk.religion.misc
Most important(+) tyre mormon god rosicrucian object jesu moral kent koresh christian
Most important(-) thank need edu window problem drive file car work game