scikit-learnで構築したモデルをシリアライズする
はじめに
11月となり、僕の修士研究*1も大詰めになってきました。
研究ではPythonの機械学習ライブラリであるscikit-learnを使っているのですが、先日、「(1時間かけて)訓練したモデルをシリアライズしておけば、評価用のプログラムと分割できて、色々と捗るのでは?」と思って調べたところ、かなり簡単にできたのでやり方をまとめました。
シリアライズとは
シリアライズ - Wikipediaでは、2つの意味が紹介されていて、今回は後者の方が該当します。以下引用です。
ある環境に存在しているオブジェクトをバイト列やXMLフォーマットに変換すること。この意味では直列化という訳語が用いられる。同義語にMarshallingがある。対義語は直列化復元ないしデシリアライズである。
今回の文脈で言えば、Pythonプログラムの実行中にメモリ上に存在する変数を、バイト列に変換する(したがって、ローカルに保存できる)って感じでしょうか。
ちなみにscikit-learnでは、sklearn.externals.joblib
を使うことで、予測モデルを簡単にシリアライズ・復元することができます。
コード例
3.6. Model persistence — scikit-learn 0.15.2 documentationの例を一部改変したものを掲載します。
一応内容を解説しておくと、おなじみのirisデータセットを読み込んで、サポートベクターマシンで予測モデルを構築し、あやめの品種を予測するというコードです。
from sklearn import svm from sklearn import datasets from sklearn.externals import joblib # SVMを訓練する clf = svm.SVC() iris = datasets.load_iris() X, y = iris.data, iris.target clf.fit(X, y) # 予測結果を出力 print(clf.predict(X)) # 予測モデルをシリアライズ joblib.dump(clf, 'clf.pkl')
このコードを実行すると、次のような結果が出力されます。
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
また、カレントディレクトリにclf.pkl
, clf.pkl_01.npy
, clf.pkl_02.npy
, ... というファイルが生成されたかと思います。
これをjoblib.load
で読み込むことで、予測モデルを復元することができます。
from sklearn import datasets from sklearn.externals import joblib # データセットを再読み込み iris = datasets.load_iris() # 予測モデルを復元 clf = joblib.load('clf.pkl') # 予測結果を出力 print(clf.predict(iris.data))
このコードを実行すると、先程と同じ予測結果が表示されます。
これで、無事復元されたことが確認できました。
おわりに
モデルの訓練に時間がかかるときに便利です。