日々の学びと煩悩

Dashで機械学習ができるWebアプリを作る [Step2]

Dashを使って、機械学習をさせるWebアプリケーションを作ろうStep2です!

流れはこちら

流れ(おさらい)

Step1: タイトルとアップロードの部分だけ作って、cssでいい感じに表示させる

wimper-1996.hatenablog.com

今回はここ↓
Step2: データを読み込み、単純な線形モデルを作り残渣プロットとスコアを表示させる

Step3: 複数のモデルを作り、ドロップダウンで選べるようにする

Step4: ファイルアップロード機能をつける (アップロードとデータの学習は別々)

Step5: アップロードしたデータを読み込んで学習させる、つまりアップロードと学習を連携させる


前回は、htmlとcssファイル無しでpythonを使ってレイアウトを記述し、Webページとして表示させました

今回はStep2「データを読み込んで線形回帰モデルで予測し、結果をグラフ化」していきます

データは、「[第二版] Python機械学習プログラミング 達人データサイエンティストによる理論と実践」で使われているこちらのデータを、csvに変換してさらにカラムを加えています。

予測値は、住宅価格のの中央値(MEDV)で、一番右の列です

まずは、このデータを読み込んで、線形回帰モデルを作ります

モデルの精度評価として、クロスバリデーションでR2スコア (モデルの説明力)とRMSE (最小二乗誤差)を計算します

import dash
import dash_core_components as dcc
import dash_html_components as html

import pandas as pd
import numpy as np
# グラフ描写モジュール、plotlyを使う
import plotly.graph_objs as go

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold

# データの読み込みから学習まで

# 全体に適用するスタイル
common_style = {'font-family': 'Comic Sans MS', 'textAlign': 'center', 'margin': '0 auto'}

# アプリの実態(インスタンス)を定義
app = dash.Dash(__name__)

# データをインポート
df = pd.read_csv('housing_data.csv')
X_train = df.iloc[:, :-1]
y_train = df.iloc[:, -1]

rmse_scores = []
r2_scores = []

# クロスバリデーションでモデルを評価する
kf = KFold(n_splits=4, shuffle=True, random_state=0)
for tr_idx, val_idx in kf.split(X_train):
    x_tr, x_val = X_train.iloc[tr_idx], X_train.iloc[val_idx]
    y_tr, y_val = y_train.iloc[tr_idx], y_train.iloc[val_idx]

    # 学習の実行
    lr = LinearRegression()
    lr.fit(x_tr, y_tr)

    y_val_pred = lr.predict(x_val)
    y_tr_pred = lr.predict(x_tr)

    rmse_score = np.sqrt(mean_squared_error(y_val, y_val_pred))
    rmse_scores.append(rmse_score)

    r2_score_ = r2_score(y_val, y_val_pred)
    r2_scores.append(r2_score_)

# 各foldのスコア平均
avg_rmse_score = np.mean(rmse_scores)
avg_r2_score = np.mean(r2_scores)

さて、最初にapp = dash.Dash(__name__)でアプリを定義しています。

あとは単純にデータを読み込み、全てのデータを使って住宅価格(MEDV)を予測する単純な線形回帰モデルを作りました。

これだけじゃ、何もWebに出力されず「え、終わり????」となってしまうため、

予測値と実測値の差である残渣をプロットして、またスコアも表示させましょう

# アプリの見た目の記述
app.layout = html.Div(
    html.Div([
        html.H1('Dash Machine Learning Application'),
        # 空白を加える
        html.Br(),

        # ファイルアップロードの部分
        dcc.Upload(
            children=html.Div([
                'Drag and Drop or ',
                html.A('Select Files')
            ]),
            style={
                'width': '60%',
                'height': '60px',
                'lineHeight': '60px',
                'borderWidth': '1px',
                'borderStyle': 'dashed',
                'borderRadius': '5px',
                'textAlign': 'center',
                'margin': '0 auto'
            }
        ),

        # スコアの表示
        html.H3(f'Average RMSE Score of Linear Regression model is {avg_rmse_score}'),
        html.H3(f'R2 score is {avg_r2_score}'),

        # グラフの記述
        dcc.Graph(
            figure={
                'data': [
                    go.Scatter(
                        x=y_tr_pred,
                        y=y_tr_pred - y_tr,
                        mode='markers',
                        opacity=0.7,
                        marker={
                            'size': 10,
                            'line': {'width': 0.5, 'color': 'white'}
                        },
                        name='train data'
                    ),

                    go.Scatter(
                        x=y_val_pred,
                        y=y_val_pred - y_val,
                        mode='markers',
                        opacity=0.7,
                        marker={
                            'size': 10,
                            'line': {'width': 0.5, 'color': 'white'}
                        },
                        name='test data'
                    )
                ],
                'layout': go.Layout(
                    title='Residual Plot of Median House Price',
                    xaxis={'title': 'Predicted Values'},
                    yaxis={'title': 'Residuals'},
                )
            },
            style={'margin': '0px 100px'}
        )
    ]),
    style=common_style
)

if __name__ == '__main__':
    app.run_server(debug=True)

結果

f:id:wimper_1996:20191029011043g:plain

plotlyモジュールで生成されたグラフ、綺麗…

matlplotlibやseabornで描写するよりもくっきりしているし、拡大できるし、さらにマウスホバーすれば値が表示される。

Data VisualizationがDashを使う醍醐味!

ちょっと今回は単純すぎるプロットだけど、複雑なデータだと可視化させたりするだけで楽しい(私は)。

RSMEとR2スコアは、html.H3()要素 (<h3>~~</h3>)の文の中に変数として格納します

グラフは、dash-core-componentsという、色んな可視化ツールが入ったモジュールの中で、グラフ描写に使われるGraphを読み込みます

全部中身かくと冗長でちょっとわかりにくいんですが、dcc.Graphの引数figureの中の構造を単純化するとこんな感じ

dcc.Graph(
            figure={
                # データ部分
                'data': [
                  # プロット1種類目(train data)
                    go.Scatter(),
           # プロット2種類目(test data)
                    go.Scatter()
                ],
                # グラフのレイアウト部分
                'layout': go.Layout()
            },

これはPlotlyの記法を踏襲していて、Plotyを使ったことがある人はすんなり使えるかも。

でも正直書き方が独特というか、辞書みたいに書くパターン、dict(x=)...みたいに書くパターン、、、いくつかあって、中々覚えられない笑

都度、Plotly公式ドキュメント確認しながら...って感じですね…

さて、結果ですが、残渣プロットは、各プロットが y=0軸付近で、正負均等に分散されていれば、上手く目的変数を予測したと言えます。

見てみると住宅価格が大きくなるほど右下にプロットがあり、予測と実測値がずれていますね

説明力は、72%か。
果たして、いいのか悪いのか…… ?

別のモデルを使ったらどうなるか、試したくなりますね!!!

Step3では、複数のモデルをドロップダウンで選択し、それをグラフに反映させるようにします!

肝である「コールバック(callback)関数」の登場です!!!!