Dashで機械学習ができるWebアプリを作る [Step3]
Dashを使って、機械学習をさせるWebアプリケーション作ろうStep3!
これまでの流れはこちら
流れ(おさらい)
Step1: タイトルとデータアップロードの枠だけ作り、cssでいい感じに表示させる
Step2: 単純な線形モデルを作り、グラフとスコアを表示させる
今回はここ↓
Step3: 複数のモデルを作り、ドロップダウンで選べるようにする
Step4: ファイルアップロード機能をつける
Step5: アップロードしたデータを読み込んで学習させるようにする
Step3での完成イメージ
ドロップダウンで選択したモデルに応じて、結果が変わる!! (いや当たり前)
単純なアプリだけど、インプットに応じてアウトプットが変わるなんて、動的な、立派なWebアプリケーション!!…なはず。
いいんだよまずは簡単なもんで。やってみよう。
さて、ユーザーが選択したドロップダウンの中身(今回は学習モデル)に応じて、出力させる結果をインタラクティブに変えたい。どうやってやるのか?
Dashの重要な機能であるコールバック: Callbackの登場です。
今回は、最初から全部のコードを載せてしまう。↓
さあ、できるだけ噛み砕いて説明します頑張ります
Callbackって何
@app.callback
デコレータ*1により記述される、ユーザーのインプットとそれに応じたアウトプットを明示的に記すための記法。
Dash apps are made interactive through Dash Callbacks: Python functions that are automatically called whenever an input component's property changes. Callbacks can be chained, allowing one update in the UI to trigger several updates across the app.
インプットに応じて自動でアウトプットが変わる様子は、エクセルのマクロに例えるとわかりやすい?↓
It's sort of like programming with Microsoft Excel: whenever an input cell changes, all of the cells that depend on that cell will get updated automatically. This is called "Reactive Programming".
引用元: https://dash.plot.ly/getting-started-part-2
callbackに必要な引数はたった2つ、component_id
と、component_property
。
component_id
は、callbackする(される)要素を見分ける識別子みたいなもので、component_property
は、callbackする(される)中身の性質を表す*2。
つまり今回の場合、まずinputは、
- ID "
model-dropdown
" を持つドロップダウン要素(dcc.Dropdown)の "value
" プロパティ
(app.layout
の中のdcc.Dropdown
要素の中のプロパティvalue
部分。もう一つのプロパティoptions
の中のvalue
に値として引き渡される)
outputは3つ、
- ID "
rmse-sentense
"を持つ、html.H3要素の "children
" プロパティ - ID "
r2-sentence
" を持つ、html.H3要素の"children
" プロパティ - ID "
residual-plot
"を持つ、dcc.Graph要素の "figure
" プロパティ
図示するとこんな感じ
ちなみに、dcc.Dropdown
要素の中のvalue
に、すでに値Linear Regression
が入っていると思いますが、これはドロップダウンに表示される初期値。
今回は、初期値として線形モデルがvalueとしてcallback関数にinputされている。
さて、つらつらと書いてきましたが、callback関数でInputとOutputの中身を明示したら、
関数update_result(model_name):
を定義し、
引数としてInputの中身を受けとり、returnとしてoutputの内容を記述します。
前回Step2で線形モデルを作ったと思いますが、その部分をごっそりcallback関数以下のupdate_result
関数内に持っていき、
具体的にLinearRegression()と定義していたのを
# 学習の実行
lr = LinearRegression()
lr.fit(x_tr, y_tr)
変数model_name
で置き換えたりすればOKです。
# 学習の実行
model = models[model_name]
model.fit(x_tr, y_tr)
そしたら、完成イメージの出来上がり!
結果について
線形回帰だと説明力が72%だったのが、ランダムフォレスト回帰だと84%に、
RMSEは4.77から3.77になった!
誤差が小さくなったってことだから、残渣プロットも、y=0近傍により近づいて分布している(分散が小さくなった)ことが分かりますね。
きっとXGboostだともっと予測精度が上がるんだろうなぁとか、
ランダムフォレストを選択した場合は木の深さなどのパラメータも自分で動かせるようにもできるんだろうなぁとか、できることが膨らむ膨らむ。
最後に
Dashではよく、dcc.Graph
要素のfigure
プロパティや、html要素中のchildren
でアプリの見た目をアップデートさせます。
でもそれだけじゃなくて、style
やdcc.Dropdown
要素のoptions
など、(多分)ほぼ全てのプロパティが変更可能!
公式ドキュメントは英語ですが、本当に丁寧に記述されていてわかりやすく、例も豊富なのでぜひみてみてください。
callbackは、Dashアプリでインタラクティブな操作を実現するために(最も)重要な機能。
ので、しっかり使いこなせるようにしたいところ。
今回は、Step1で作っていたデータのアップロード部分の枠組みは一度置いといて、callbackを使ったドロップダウン機能を追加しました〜
次回は、このcallback関数を使って、アップロードしたデータをデータテーブルとして表示させます!