モデルバリデーション
====================
本項では、モデルバリデーションについて、SAMPO/FABとsklearn-fabにおける実現方法およびスクリプトの差異を示します。
モデルバリデーションとして、定量的な評価、門木可視化、予測式可視化を示します。

定量的な評価
------------
主な評価指標として、以下のように取得または算出します。

.. csv-table::
  :header-rows: 1

  "予測種別", "評価指標", "SAMPO/FAB", "sklearn-fab"
  "回帰", "RMSE", "ProcessResultLoaderクラス内の以下を取得

  * root_mean_squared_error
  * std_root_mean_squared_error（standardize_target: True の場合）
  ", "以下の関数により算出

  * \ `sklearn.metrics.mean_squared_error()`_\"
  "回帰", "MAE", "ProcessResultLoaderクラス内の以下を取得

  * mean_abs_error
  * std_mean_abs_error（standardize_target: True の場合）
  ", "以下の関数により算出

  * \ `sklearn.metrics.mean_absolute_error()`_\"
  "分類", "accuracy", "ProcessResultLoaderクラス内の以下を取得

  * accuracy
  ", "以下の関数により算出

  * \ `sklearn.metrics.accuracy_score()`_\"
  "分類", "precision", "ProcessResultLoaderクラス内の以下を取得

  * precision
  ", "以下の関数により算出

  * \ `sklearn.metrics.precision_score()`_\"
  "分類", "recall", "ProcessResultLoaderクラス内の以下を取得

  * recall
  ", "以下の関数により算出

  * \ `sklearn.metrics.recall_score()`_\"

ここでは、RMSEを出力する例を示します。

.. csv-table::
  :header: "SAMPO/FAB", "sklearn-fab"

  "
  .. code-block:: python

    predict_proc_name = 'predict'
    with process_store.open_process(pstore_url, predict_proc_name) as prl:
        prl = prl
    prl.load_comp_output_evaluation('rg')['std_root_mean_squared_error'][0]
  ", "
  .. code-block:: python

    from sklearn.metrics import mean_squared_error

    test_rmse = mean_squared_error(y_predict, predict_value, squared=False)
    test_rmse
  "
  "
  .. code-block:: python

    0.7621801464683062
  ", "
  .. code-block:: python

    0.6509250433352727
  "

門木可視化
----------
**SAMPO/FAB**

* SAMPOVIS API save_gate_tree関数を使用します。

**sklearn-fab**

* sklearn-fab API export_gate_tree_dot関数を使用します。
* sklearn-fabでは、ユーザがOSSライブラリ等を用いてデータの標準化を事前に実施するため、逆標準化された値は出力されません。

.. csv-table::
  :header: "SAMPO/FAB", "sklearn-fab"

  "
  .. code-block:: python

    from IPython.display import display, Image
    from sampovis.api.fabhme_vis import save_gate_tree

    save_gate_tree(pstore_url, train_proc_name, 'output')
    display(Image('output/{}_rg_fabhme_gate_tree.png'.format(train_proc_name)))
  ", "
  .. code-block:: python

    from IPython.display import display, Image
    from sklearn_fab.utils import export_gate_tree_dot

    dot = export_gate_tree_dot(estimator, X=X_train)
    display(Image(dot.create_png()))
  "
  "
  .. image:: ./image/gate_tree_sampo_fab.png
  ", "
  .. image:: ./image/gate_tree_sklearn_fab.png
  "

予測式可視化
------------
**SAMPO/FAB**

* SAMPO API process_storeとProcessResultLoaderクラスを利用して、予測式を参照し、可視化します。

**sklearn-fab**

* 学習済みのestimatorがもつcomps_アトリビュートを参照し、データ型を変換して可視化します。

.. csv-table::
  :header: "SAMPO/FAB", "sklearn-fab"

  "
  .. code-block:: python

    import matplotlib.pyplot as plt
    %matplotlib inline

    model_params = prl.load_model('rg')
    prediction_formulas = model_params['prediction_formulas']
    relevant_feature_indices = prediction_formulas.sum(axis=1) != 0
    prediction_formulas = prediction_formulas[relevant_feature_indices]
    prediction_formulas.plot(kind='barh', figsize=(8, 4), stacked=True)
    plt.show()
  ", "
  .. code-block:: python

    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline

    # 可視化のために、DataFrame形式で予測式を表現
    columns = np.append(estimator.feature_names_, ['bias', 'variance'])
    prediction_formulas = [np.append(comp.weights, [comp.bias, comp.variance])
                           for comp in estimator.comps_]
    pf_df = pd.DataFrame(prediction_formulas, columns=columns,
                         index=['component #' + str(i) for i in estimator.comp_ids_])

    # 係数が0ではない説明変数のみを抽出し、表示
    relevant_feature_indices = pf_df.sum(axis=0) != 0
    pf_df = pf_df.T[relevant_feature_indices]
    pf_df.plot(kind='barh', figsize=(8, 4), stacked=True)
    plt.show()
  "
  "
  .. image:: ./image/prediction_formula_sampo_fab.png
  ", "
  .. image:: ./image/prediction_formula_sklearn_fab.png
  "

.. _sklearn.metrics.mean_squared_error(): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html#sklearn-metrics-mean-squared-error
.. _sklearn.metrics.mean_absolute_error(): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html#sklearn-metrics-mean-absolute-error
.. _sklearn.metrics.accuracy_score(): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html#sklearn-metrics-accuracy-score
.. _sklearn.metrics.precision_score(): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn-metrics-precision-score
.. _sklearn.metrics.recall_score(): https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn-metrics-recall-score
