Overview

This code builds very directly on https://paws-public.wmflabs.org/paws-public/User:Abdulwd/predict-mid-level-category.ipynb (accessed 19 Aug 2019). It makes two changes to that code:

  • Fixes a bug due to improper use of MultiLabelBinarizer where the class labels were being tokenized by character instead of WikiProject, so the labels the model was predicting were e.g., 'G', 'e', ..., 'i', 'a' as opposed to 'Geography.Oceania'. Coincidentally there were about as many unique characters as WikiProject mid-level categories and still sufficient connection between the individual characters and true labels that the bug was not immediately evident.
  • Improved labeling for model statistics (which is what led to discovering the bug)

Imports, parameters, etc.

import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_curve, auc
# Path to JSON data (one record per line) with Wikidata QIDs, claims,
# English Wikipedia title, and associated mid-level categories
data_json_fn = "./drafttopic_wditem_dataset.json"
# If True, Wikidata properties and values are treated as separate tokens;
# if False, each property/value statement is retained as a single token
prop_val_sep = True
# maximum number of unique property/value tokens kept by the vectorizer
max_features = 10000

Prepare data

# load data: newline-delimited JSON (lines=True), one article per record
data = pd.read_json(data_json_fn, lines=True)
# summarize columns/row counts as a sanity check
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 84297 entries, 0 to 84296
Data columns (total 4 columns):
QID                     84297 non-null object
claims                  84297 non-null object
mid_level_categories    84297 non-null object
title                   84297 non-null object
dtypes: object(4)
memory usage: 2.6+ MB
# Drop the columns kept only for human reference, then flatten each
# article's Wikidata claims into a single text string for vectorization.
data.drop(labels=["QID", "title"], axis=1, inplace=True)
# ' ' splits properties from values into separate tokens; ':' keeps each
# property:value statement as one token
join_char = ' ' if prop_val_sep else ':'
x_corpus = data["claims"].apply(
    lambda claims: " ".join(join_char.join(stmt) for stmt in claims))
y_corpus = data["mid_level_categories"]
# preview the flattened claim text
x_corpus.head()
0    P2184 Q193315 P61 Q92743 P61 Q62843 P910 Q4049...
1    P910 Q9313304 P131 Q18241891 P131 Q15123 P4552...
2    P559 Q4918 P559 Q168277 P131 Q31070 P1343 Q602...
3    P2579 Q23399 P910 Q7144831 P279 Q11028 P279 Q9...
4    P17 Q39 P4552 Q609634 P31 Q46831 P910 Q2524154...
Name: claims, dtype: object
# preview the multi-label targets (lists of mid-level categories)
y_corpus.head()
0    [History_And_Society.History and society, Cult...
1    [History_And_Society.History and society, STEM...
2    [Geography.Countries, STEM.Science, History_An...
3    [History_And_Society.History and society, Cult...
4              [Geography.Landforms, Geography.Europe]
Name: mid_level_categories, dtype: object
# Convert corpora into model-ready matrices: bag-of-words counts for the
# claims, and a 0/1 indicator matrix for the multi-label categories.
claim_vectorizer = CountVectorizer(max_features=max_features)
claim_vectorizer.fit(x_corpus)
x_train = claim_vectorizer.transform(x_corpus)

label_binarizer = MultiLabelBinarizer()
label_binarizer.fit(y_corpus)
y_train = label_binarizer.transform(y_corpus)

# hold out a validation set (default 25%), seeded for reproducibility
X_train_sub, X_validation_sub, y_train_sub, y_validation_sub = train_test_split(
    x_train, y_train, random_state=0)

Build/train model

# Fit the multi-label random forest.
# n_estimators is pinned to the current default (10) so the FutureWarning
# about the 0.22 default change (10 -> 100) is silenced without altering
# behavior; random_state=0 makes the fit reproducible and matches the seed
# used for train_test_split above.
rfc = RandomForestClassifier(n_estimators=10, random_state=0)
rfc.fit(X_train_sub, y_train_sub)
/srv/paws/lib/python3.6/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

Model Statistics

# Examine which Wikidata properties/values the forest relies on most.
n_top_features = 10
# indices of the highest-importance features, descending
top_idx = np.argsort(rfc.feature_importances_)[::-1][:n_top_features]
vocab = claim_vectorizer.get_feature_names()
# human-readable glosses for common Wikidata property/entity IDs
property_names = {'P17':'country',
                  'Q5':'human',
                  'P31':'instance_of',
                  'P131':'located in the administrative territorial entity',
                  'P19':'place of birth',
                  'Q6581097':'male',
                  'Q11173':'chemical compound',
                  'Q4167410':'Wikimedia disambiguation page',
                  'P175':'performer',
                  'P641':'sport',
                  'P106':'occupation',
                  'Q30':'United States of America',
                  'P735':'given name',
                  'P27':'country of citizenship',
                  'P21':'sex or gender',
                  'Q145':'United Kingdom',
                  'Q408':'Australia'}
# append the raw ID to each gloss, e.g. 'country (P17)'
property_names = {pid: '{0} ({1})'.format(gloss, pid)
                  for pid, gloss in property_names.items()}
for idx in top_idx:
    token = vocab[idx].upper()
    # fall back to the bare ID when no gloss is known
    print('{0:.4f}\t{1}'.format(rfc.feature_importances_[idx],
                                property_names.get(token, token)))
0.0403	sex or gender (P21)
0.0170	country (P17)
0.0140	place of birth (P19)
0.0120	United States of America (Q30)
0.0119	male (Q6581097)
0.0117	instance_of (P31)
0.0116	Wikimedia disambiguation page (Q4167410)
0.0094	performer (P175)
0.0086	United Kingdom (Q145)
0.0085	Australia (Q408)
# Gather validation-set predictions and tabulate per-class confusion
# counts. Both frames hold 0/1 indicators, one column per category.
predictions = rfc.predict(X_validation_sub)
preds = pd.DataFrame(predictions, columns=label_binarizer.classes_)
actuals = pd.DataFrame(y_validation_sub, columns=label_binarizer.classes_)
# boolean masks on the 0/1 indicators give the four confusion cells
tps = ((preds == 1) & (actuals == 1)).astype(int).sum(axis=0)
tps.name = 'TP'
fps = ((preds == 1) & (actuals == 0)).astype(int).sum(axis=0)
fps.name = 'FP'
fns = ((preds == 0) & (actuals == 1)).astype(int).sum(axis=0)
fns.name = 'FN'
tns = ((preds == 0) & (actuals == 0)).astype(int).sum(axis=0)
tns.name = 'TN'
# per-class support (number of true instances)
ns = actuals.sum(axis=0)
ns.name = 'n'
precisions = tps / (tps + fps)
precisions.name = 'precision'
recalls = tps / (tps + fns)
recalls.name = 'recall'
f1s = 2 * (precisions * recalls) / (precisions + recalls)
f1s.name = 'f1'
# cosmetic arrow column separating support from the confusion counts
formatting = pd.Series('-->', index=ns.index, name='')
# order from here: https://github.com/wikimedia/drafttopic/blob/master/model_info/enwiki.drafttopic.md
# allows for easier comparison with word-embeddings model
class_order = ['Geography.Oceania',
               'STEM.Mathematics',
               'STEM.Science',
               'STEM.Meteorology',
               'Culture.Sports',
               'Culture.Performing arts',
               'Culture.Entertainment',
               'Assistance.Article improvement and grading',
               'Culture.Language and literature',
               'Culture.Visual arts',
               'STEM.Biology',
               'History_And_Society.Business and economics',
               'Assistance.Files',
               'History_And_Society.History and society',
               'STEM.Medicine',
               'Culture.Crafts and hobbies',
               'STEM.Geosciences',
               'Culture.Food and drink',
               'History_And_Society.Transportation',
               'Geography.Cities',
               'Geography.Landforms',
               'Assistance.Maintenance',
               'STEM.Information science',
               'STEM.Time',
               'Geography.Europe',
               'STEM.Engineering',
               'Culture.Media',
               'STEM.Technology',
               'STEM.Space',
               'History_And_Society.Education',
               'Geography.Countries',
               'History_And_Society.Military and warfare',
               'Culture.Plastic arts',
               'STEM.Physics',
               'History_And_Society.Politics and government',
               'STEM.Chemistry',
               'Culture.Broadcasting',
               'Geography.Maps',
               'Culture.Arts',
               'Culture.Internet culture',
               'Geography.Bodies of water',
               'Assistance.Contents systems',
               'Culture.Philosophy and religion']
# assemble the per-class summary table; fillna(0) replaces NaN
# precision/recall/f1 for classes with zero predicted positives
statistics = pd.concat([ns, formatting, tps, fps, fns, tns, precisions, recalls, f1s], axis=1).fillna(0)
statistics = statistics.reindex(class_order)
print("Statistics:")
print("counts (n={0})".format(len(preds)))
# display() is the IPython rich-output helper (notebook environment)
display(statistics[['n','','TP','FP','FN','TN']])
Statistics:
counts (n=21075)
n TP FP FN TN
Geography.Oceania 1000 --> 501 50 499 20025
STEM.Mathematics 407 --> 159 11 248 20657
STEM.Science 538 --> 93 40 445 20497
STEM.Meteorology 493 --> 220 23 273 20559
Culture.Sports 1228 --> 976 56 252 19791
Culture.Performing arts 1094 --> 836 33 258 19948
Culture.Entertainment 1367 --> 829 56 538 19652
Assistance.Article improvement and grading 12 --> 2 0 10 21063
Culture.Language and literature 4851 --> 4402 166 449 16058
Culture.Visual arts 1115 --> 521 114 594 19846
STEM.Biology 817 --> 440 44 377 20214
History_And_Society.Business and economics 1476 --> 536 206 940 19393
Assistance.Files 78 --> 1 8 77 20989
History_And_Society.History and society 1796 --> 282 159 1514 19120
STEM.Medicine 472 --> 120 35 352 20568
Culture.Crafts and hobbies 506 --> 110 19 396 20550
STEM.Geosciences 475 --> 201 23 274 20577
Culture.Food and drink 567 --> 143 49 424 20459
History_And_Society.Transportation 914 --> 480 44 434 20117
Geography.Cities 201 --> 133 27 68 20847
Geography.Landforms 527 --> 466 30 61 20518
Assistance.Maintenance 1507 --> 834 200 673 19368
STEM.Information science 475 --> 170 22 305 20578
STEM.Time 608 --> 459 27 149 20440
Geography.Europe 3790 --> 1902 505 1888 16780
STEM.Engineering 493 --> 46 43 447 20539
Culture.Media 481 --> 94 39 387 20555
STEM.Technology 957 --> 152 121 805 19997
STEM.Space 561 --> 403 17 158 20497
History_And_Society.Education 632 --> 350 40 282 20403
Geography.Countries 5807 --> 3306 998 2501 14270
History_And_Society.Military and warfare 965 --> 404 77 561 20033
Culture.Plastic arts 902 --> 498 57 404 20116
STEM.Physics 594 --> 77 50 517 20431
History_And_Society.Politics and government 948 --> 228 94 720 20033
STEM.Chemistry 535 --> 292 33 243 20507
Culture.Broadcasting 642 --> 360 41 282 20392
Geography.Maps 393 --> 75 13 318 20669
Culture.Arts 473 --> 329 20 144 20582
Culture.Internet culture 457 --> 81 42 376 20576
Geography.Bodies of water 518 --> 433 11 85 20546
Assistance.Contents systems 472 --> 217 53 255 20550
Culture.Philosophy and religion 946 --> 191 121 755 20008
# Reorder classes (matches the word-embeddings model report for comparison).
class_order = ['History_And_Society.Education',
               'STEM.Geosciences',
               'Culture.Language and literature',
               'Assistance.Maintenance',
               'STEM.Technology',
               'Geography.Cities',
               'Culture.Sports',
               'STEM.Chemistry',
               'STEM.Physics',
               'Culture.Broadcasting',
               'Assistance.Contents systems',
               'Geography.Oceania',
               'Assistance.Files',
               'Geography.Maps',
               'Assistance.Article improvement and grading',
               'Geography.Landforms',
               'Culture.Visual arts',
               'STEM.Medicine',
               'Culture.Plastic arts',
               'Culture.Arts',
               'Culture.Food and drink',
               'STEM.Information science',
               'STEM.Engineering',
               'Culture.Philosophy and religion',
               'STEM.Science',
               'Culture.Crafts and hobbies',
               'History_And_Society.Business and economics',
               'Geography.Countries',
               'STEM.Time',
               'STEM.Biology',
               'History_And_Society.Transportation',
               'STEM.Meteorology',
               'History_And_Society.Politics and government',
               'Culture.Internet culture',
               'History_And_Society.Military and warfare',
               'Culture.Media',
               'STEM.Mathematics',
               'STEM.Space',
               'Culture.Performing arts',
               'Geography.Bodies of water',
               'Geography.Europe',
               'History_And_Society.History and society',
               'Culture.Entertainment']
statistics = statistics.reindex(class_order)

# Micro-average precision pools the confusion counts across classes:
# sum(TP) / (sum(TP) + sum(FP)). The previous revision computed an
# n-weighted mean of per-class precisions, which is the "weighted"
# average, not micro. Macro is the unweighted mean of per-class precision.
micro = statistics['TP'].sum() / (statistics['TP'].sum() + statistics['FP'].sum())
macro = np.average(statistics['precision'])
print("precision (micro={0:.3f}, macro={1:.3f})".format(micro, macro))
display(statistics['precision'])
precision (micro=0.823, macro=0.813)
History_And_Society.Education                  0.897436
STEM.Geosciences                               0.897321
Culture.Language and literature                0.963660
Assistance.Maintenance                         0.806576
STEM.Technology                                0.556777
Geography.Cities                               0.831250
Culture.Sports                                 0.945736
STEM.Chemistry                                 0.898462
STEM.Physics                                   0.606299
Culture.Broadcasting                           0.897756
Assistance.Contents systems                    0.803704
Geography.Oceania                              0.909256
Assistance.Files                               0.111111
Geography.Maps                                 0.852273
Assistance.Article improvement and grading     1.000000
Geography.Landforms                            0.939516
Culture.Visual arts                            0.820472
STEM.Medicine                                  0.774194
Culture.Plastic arts                           0.897297
Culture.Arts                                   0.942693
Culture.Food and drink                         0.744792
STEM.Information science                       0.885417
STEM.Engineering                               0.516854
Culture.Philosophy and religion                0.612179
STEM.Science                                   0.699248
Culture.Crafts and hobbies                     0.852713
History_And_Society.Business and economics     0.722372
Geography.Countries                            0.768123
STEM.Time                                      0.944444
STEM.Biology                                   0.909091
History_And_Society.Transportation             0.916031
STEM.Meteorology                               0.905350
History_And_Society.Politics and government    0.708075
Culture.Internet culture                       0.658537
History_And_Society.Military and warfare       0.839917
Culture.Media                                  0.706767
STEM.Mathematics                               0.935294
STEM.Space                                     0.959524
Culture.Performing arts                        0.962025
Geography.Bodies of water                      0.975225
Geography.Europe                               0.790195
History_And_Society.History and society        0.639456
Culture.Entertainment                          0.936723
Name: precision, dtype: float64
# Micro-average F1 pools confusion counts across classes:
# F1_micro = 2*TP / (2*TP + FP + FN). The previous revision computed an
# n-weighted mean of per-class f1s, which is the "weighted" average, not
# micro. Macro is the unweighted mean of per-class f1.
tp_total = statistics['TP'].sum()
micro = 2 * tp_total / (2 * tp_total + statistics['FP'].sum() + statistics['FN'].sum())
macro = np.average(statistics['f1'])
print("f1 (micro={0:.3f}, macro={1:.3f})".format(micro, macro))
display(statistics['f1'])
f1 (micro=0.618, macro=0.549)
History_And_Society.Education                  0.684932
STEM.Geosciences                               0.575107
Culture.Language and literature                0.934706
Assistance.Maintenance                         0.656434
STEM.Technology                                0.247154
Geography.Cities                               0.736842
Culture.Sports                                 0.863717
STEM.Chemistry                                 0.679070
STEM.Physics                                   0.213592
Culture.Broadcasting                           0.690316
Assistance.Contents systems                    0.584906
Geography.Oceania                              0.646035
Assistance.Files                               0.022989
Geography.Maps                                 0.311850
Assistance.Article improvement and grading     0.285714
Geography.Landforms                            0.911046
Culture.Visual arts                            0.595429
STEM.Medicine                                  0.382775
Culture.Plastic arts                           0.683596
Culture.Arts                                   0.800487
Culture.Food and drink                         0.376812
STEM.Information science                       0.509745
STEM.Engineering                               0.158076
Culture.Philosophy and religion                0.303657
STEM.Science                                   0.277198
Culture.Crafts and hobbies                     0.346457
History_And_Society.Business and economics     0.483318
Geography.Countries                            0.653941
STEM.Time                                      0.839122
STEM.Biology                                   0.676403
History_And_Society.Transportation             0.667594
STEM.Meteorology                               0.597826
History_And_Society.Politics and government    0.359055
Culture.Internet culture                       0.279310
History_And_Society.Military and warfare       0.558783
Culture.Media                                  0.306189
STEM.Mathematics                               0.551127
STEM.Space                                     0.821611
Culture.Performing arts                        0.851758
Geography.Bodies of water                      0.900208
Geography.Europe                               0.613845
History_And_Society.History and society        0.252123
Culture.Entertainment                          0.736234
Name: f1, dtype: float64

Dump model if desired

# Serialize the trained classifier for later reuse.
# NOTE: pickle files should only ever be loaded from trusted sources.
with open("./wd_draftopic_rfc.model", 'wb') as model_file:
    pickle.dump(rfc, model_file)