scikit-learnで構築したモデルをシリアライズする

はじめに

11月となり、僕の修士研究*1も大詰めになってきました。
研究ではPython機械学習ライブラリであるscikit-learnを使っているのですが、先日、「(1時間かけて)訓練したモデルをシリアライズしておけば、評価用のプログラムと分割できて、色々と捗るのでは?」と思って調べたところ、かなり簡単にできたのでやり方をまとめました。

シリアライズとは

シリアライズ - Wikipediaでは、2つの意味が紹介されていて、今回は後者の方が該当します。以下引用です。

ある環境に存在しているオブジェクトをバイト列やXMLフォーマットに変換すること。この意味では直列化という訳語が用いられる。同義語にMarshallingがある。対義語は直列化復元ないしデシリアライズである。

今回の文脈で言えば、Pythonプログラムの実行中にメモリ上に存在する変数を、バイト列に変換する(したがって、ローカルに保存できる)って感じでしょうか。
ちなみにscikit-learnでは、sklearn.externals.joblibを使うことで、予測モデルを簡単にシリアライズ・復元することができます。

コード例

3.6. Model persistence — scikit-learn 0.15.2 documentationの例を一部改変したものを掲載します。
一応内容を解説しておくと、おなじみのirisデータセットを読み込んで、サポートベクターマシンで予測モデルを構築し、あやめの品種を予測するというコードです。

from sklearn import svm
from sklearn import datasets
from sklearn.externals import joblib

# SVMを訓練する
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)

# 予測結果を出力
print(clf.predict(X))

# 予測モデルをシリアライズ
joblib.dump(clf, 'clf.pkl') 

このコードを実行すると、次のような結果が出力されます。

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

また、カレントディレクトリにclf.pkl, clf.pkl_01.npy, clf.pkl_02.npy, ... というファイルが生成されたかと思います。
これをjoblib.loadで読み込むことで、予測モデルを復元することができます。

from sklearn import datasets
from sklearn.externals import joblib

# データセットを再読み込み
iris = datasets.load_iris()

# 予測モデルを復元
clf = joblib.load('clf.pkl') 

# 予測結果を出力
print(clf.predict(iris.data))

このコードを実行すると、先程と同じ予測結果が表示されます。
これで、無事復元されたことが確認できました。

おわりに

モデルの訓練に時間がかかるときに便利です。

参考