日々の学びと煩悩

Dashで機械学習ができるWebアプリを作る [Step3]

Dashを使って、機械学習をさせるWebアプリケーション作ろうStep3!

これまでの流れはこちら

流れ(おさらい)

Step1: タイトルとデータアップロードの枠だけ作り、cssでいい感じに表示させる

wimper-1996.hatenablog.com

Step2: 単純な線形モデルを作り、グラフとスコアを表示させる

wimper-1996.hatenablog.com

今回はここ↓

Step3: 複数のモデルを作り、ドロップダウンで選べるようにする

Step4: ファイルアップロード機能をつける

Step5: アップロードしたデータを読み込んで学習させるようにする


Step3での完成イメージ

f:id:wimper_1996:20191105004211g:plain

ドロップダウンで選択したモデルに応じて、結果が変わる!! (いや当たり前)

単純なアプリだけど、インプットに応じてアウトプットが変わるなんて、動的な、立派なWebアプリケーション!!…なはず。

いいんだよまずは簡単なもんで。やってみよう。

さて、ユーザーが選択したドロップダウンの中身(今回は学習モデル)に応じて、出力させる結果をインタラクティブに変えたい。どうやってやるのか?

Dashの重要な機能であるコールバック: Callbackの登場です。

今回は、最初から全部のコードを載せてしまう。↓


さあ、できるだけ噛み砕いて説明します頑張ります



Callbackって何

@app.callbackデコレータ*1により記述される、ユーザーのインプットとそれに応じたアウトプットを明示的に記すための記法。

Dash apps are made interactive through Dash Callbacks: Python functions that are automatically called whenever an input component's property changes. Callbacks can be chained, allowing one update in the UI to trigger several updates across the app.

引用元:https://dash.plot.ly/

インプットに応じて自動でアウトプットが変わる様子は、エクセルのマクロに例えるとわかりやすい?↓

It's sort of like programming with Microsoft Excel: whenever an input cell changes, all of the cells that depend on that cell will get updated automatically. This is called "Reactive Programming".

引用元: https://dash.plot.ly/getting-started-part-2

callbackに必要な引数はたった2つ、component_idと、component_property

component_idは、callbackする(される)要素を見分ける識別子みたいなもので、component_propertyは、callbackする(される)中身の性質を表す*2

つまり今回の場合、まずinputは、

  • ID "model-dropdown" を持つドロップダウン要素(dcc.Dropdown)の "value" プロパティ

(app.layoutの中のdcc.Dropdown要素の中のプロパティvalue部分。もう一つのプロパティoptionsの中のvalueに値として引き渡される)

outputは3つ、

  • ID "rmse-sentense"を持つ、html.H3要素の "children" プロパティ
  • ID "r2-sentence" を持つ、html.H3要素の"children" プロパティ
  • ID "residual-plot"を持つ、dcc.Graph要素の "figure" プロパティ

図示するとこんな感じ

f:id:wimper_1996:20191104234848p:plain

ちなみに、dcc.Dropdown要素の中のvalueに、すでに値Linear Regressionが入っていると思いますが、これはドロップダウンに表示される初期値。

今回は、初期値として線形モデルがvalueとしてcallback関数にinputされている。

さて、つらつらと書いてきましたが、callback関数でInputとOutputの中身を明示したら、

関数update_result(model_name):を定義し、

引数としてInputの中身を受けとり、returnとしてoutputの内容を記述します。

前回Step2で線形モデルを作ったと思いますが、その部分をごっそりcallback関数以下のupdate_result関数内に持っていき、

具体的にLinearRegression()と定義していたのを

# 学習の実行
    lr = LinearRegression()
    lr.fit(x_tr, y_tr)

変数model_nameで置き換えたりすればOKです。

# 学習の実行
        model = models[model_name]
        model.fit(x_tr, y_tr)

そしたら、完成イメージの出来上がり!

結果について

線形回帰だと説明力が72%だったのが、ランダムフォレスト回帰だと84%に、
RMSEは4.77から3.77になった!

誤差が小さくなったってことだから、残渣プロットも、y=0近傍により近づいて分布している(分散が小さくなった)ことが分かりますね。

きっとXGboostだともっと予測精度が上がるんだろうなぁとか、

ランダムフォレストを選択した場合は木の深さなどのパラメータも自分で動かせるようにもできるんだろうなぁとか、できることが膨らむ膨らむ。

最後に

Dashではよく、dcc.Graph要素のfigureプロパティや、html要素中のchildrenでアプリの見た目をアップデートさせます。

でもそれだけじゃなくて、styledcc.Dropdown要素のoptionsなど、(多分)ほぼ全てのプロパティが変更可能!

公式ドキュメントは英語ですが、本当に丁寧に記述されていてわかりやすく、例も豊富なのでぜひみてみてください。

callbackは、Dashアプリでインタラクティブな操作を実現するために(最も)重要な機能。

ので、しっかり使いこなせるようにしたいところ。

今回は、Step1で作っていたデータのアップロード部分の枠組みは一度置いといて、callbackを使ったドロップダウン機能を追加しました〜

次回は、このcallback関数を使って、アップロードしたデータをデータテーブルとして表示させます!

*1:デコレータなんてDashでしか使ったことないし、実態は分かっていません。強そう。

*2:慣れてきたら、id→propertyの順番で引数は省略して書くのが一般的