{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Improving Model Accuracy and Validation Methods\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "- [1. Introduction](#1.-はじめに)\n",
    "- [2. Data Preparation](#2.-データの準備)\n",
    "- [3. Building More Accurate Models](#3.-より精度の高いモデルの作成方法)\n",
    "    - [3.1. Defining SPDs and SRCs with Template Syntax](#3.1.-テンプレート構文を用いたSPDとSRCの定義)\n",
    "    - [3.2. Running Parameter Tuning](#3.2.-パラメーターチューニングの実行方法)\n",
    "    - [3.3. Random Restarts: Purpose, Design, and Execution](#3.3.-ランダムリスタートの目的と設計・実行方法)\n",
    "- [4. Model Validation Methods](#4.-モデルの検証方法)\n",
    "    - [4.1. Holdout Validation](#4.1.-ホールドアウト検証)\n",
    "    - [4.2. Cross-Validation](#4.2.-交差検証)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
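    "## 1. Introduction <a id=\"1.-はじめに\"></a>\n",
    "In this chapter, you will learn how to create large numbers of models with a wide variety of hyperparameters, quickly and easily, in order to obtain a highly accurate model.  \n",
    "You will also learn how to choose the most accurate model from among the models you have created.\n",
    "\n",
    "The specific learning goals are as follows:\n",
    "\n",
    "- **Create SPDs and SRCs using template syntax**\n",
    "- **Understand parameter tuning and random restarts, and implement the corresponding SPDs and SRCs using template syntax**\n",
    "- **Understand model validation methods, run model validation, and check the resulting prediction accuracy**\n"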
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
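    "## 2. Data Preparation <a id=\"2.-データの準備\"></a>\n",
    "\n",
    "This section walks through the data preparation steps listed below. They are the same as in [Running Learning and Prediction and Checking the Results, 2. Data Preparation](../simple/simple.ipynb#2.-データの準備), except that this chapter uses template syntax, so neither the CSV nor the ASD is written out to a file.\n",
    "\n",
    "1. Add `_sid` (sample ID) to the analysis data\n",
    "2. Create the ASD (attribute schema)\n",
    "3. Split the analysis data into learning and prediction sets\n",
    "\n",
    "\n",
    "Because this chapter works with the analysis data directly as a Pandas DataFrame, categorical attributes must have the dtype `object` or `bool` (`np.object` or `np.bool` in older NumPy versions).\n",
    "Attributes whose categories are encoded as numbers therefore have to be converted to `object`; here, `origin` is read as `object` when `read_csv()` is called.\n",
    "\n",
    "The code below loads the analysis data for predicting automobile fuel consumption. Missing values, which appear as `?` in the file, are read as NaN, and the rows containing them are dropped.\n",
    "\n",
    "The data is the Auto MPG Data Set from the UCI open data repository (https://archive.ics.uci.edu/ml/datasets/auto+mpg), with the `car_name` attribute removed.  "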
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>_sid</th>\n",
       "      <th>mpg</th>\n",
       "      <th>cylinders</th>\n",
       "      <th>displacement</th>\n",
       "      <th>horsepower</th>\n",
       "      <th>weight</th>\n",
       "      <th>acceleration</th>\n",
       "      <th>model_year</th>\n",
       "      <th>origin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>8</td>\n",
       "      <td>307.0</td>\n",
       "      <td>130</td>\n",
       "      <td>3504</td>\n",
       "      <td>12.0</td>\n",
       "      <td>70</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>15.0</td>\n",
       "      <td>8</td>\n",
       "      <td>350.0</td>\n",
       "      <td>165</td>\n",
       "      <td>3693</td>\n",
       "      <td>11.5</td>\n",
       "      <td>70</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>18.0</td>\n",
       "      <td>8</td>\n",
       "      <td>318.0</td>\n",
       "      <td>150</td>\n",
       "      <td>3436</td>\n",
       "      <td>11.0</td>\n",
       "      <td>70</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>16.0</td>\n",
       "      <td>8</td>\n",
       "      <td>304.0</td>\n",
       "      <td>150</td>\n",
       "      <td>3433</td>\n",
       "      <td>12.0</td>\n",
       "      <td>70</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>17.0</td>\n",
       "      <td>8</td>\n",
       "      <td>302.0</td>\n",
       "      <td>140</td>\n",
       "      <td>3449</td>\n",
       "      <td>10.5</td>\n",
       "      <td>70</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   _sid   mpg  cylinders  displacement  horsepower  weight  acceleration  \\\n",
       "0     0  18.0          8         307.0         130    3504          12.0   \n",
       "1     1  15.0          8         350.0         165    3693          11.5   \n",
       "2     2  18.0          8         318.0         150    3436          11.0   \n",
       "3     3  16.0          8         304.0         150    3433          12.0   \n",
       "4     4  17.0          8         302.0         140    3449          10.5   \n",
       "\n",
       "   model_year origin  \n",
       "0          70      1  \n",
       "1          70      1  \n",
       "2          70      1  \n",
       "3          70      1  \n",
       "4          70      1  "
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
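    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Read '?' as missing, and read origin as a categorical (object) column.\n",
    "# The built-in object type is used because np.object was removed in NumPy 1.24.\n",
    "input_data = pd.read_csv('./data/auto-mpg.csv', na_values='?', dtype={'origin': object})\n",
    "\n",
    "# Drop rows with missing values, then add the sample ID column _sid as the first column\n",
    "input_data.dropna(inplace=True)\n",
    "input_data.insert(0, '_sid', list(range(input_data.shape[0])))\n",
    "\n",
    "input_data.head()"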
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
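    "Next, create the ASD from the analysis data above.\n",
    "\n",
    "As shown in [Running Learning and Prediction and Checking the Results, 2. Data Preparation](../simple/simple.ipynb#2.-データの準備), the scale of `weight` in the ASD is corrected to the appropriate data type (`REAL`).\n",
    "`origin` is already recognized as a categorical attribute (`NOMINAL`) because it was read as `object` by `read_csv()`."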
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>scale</th>\n",
       "      <th>domain</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>_sid</th>\n",
       "      <td>INTEGER</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mpg</th>\n",
       "      <td>REAL</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cylinders</th>\n",
       "      <td>INTEGER</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>displacement</th>\n",
       "      <td>REAL</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>horsepower</th>\n",
       "      <td>INTEGER</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>weight</th>\n",
       "      <td>REAL</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>acceleration</th>\n",
       "      <td>REAL</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>model_year</th>\n",
       "      <td>INTEGER</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>origin</th>\n",
       "      <td>NOMINAL</td>\n",
       "      <td>[1, 3, 2]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                scale     domain\n",
       "_sid          INTEGER        NaN\n",
       "mpg              REAL        NaN\n",
       "cylinders     INTEGER        NaN\n",
       "displacement     REAL        NaN\n",
       "horsepower    INTEGER        NaN\n",
       "weight           REAL        NaN\n",
       "acceleration     REAL        NaN\n",
       "model_year    INTEGER        NaN\n",
       "origin        NOMINAL  [1, 3, 2]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sampotools.api import gen_asd_from_pandas_df\n",
    "import pandas as pd\n",
    "import yaml\n",
    "\n",
    "# Create the ASD\n",
    "asd = gen_asd_from_pandas_df(input_data)\n",
    "\n",
    "# Correct the scale of weight to REAL\n",
    "asd['weight'] = {'scale': 'REAL'}\n",
    "\n",
    "# Check the result\n",
    "pd.DataFrame(asd).T[['scale', 'domain']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
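    "Split the analysis data into a learning set and a prediction set.\n",
    "\n",
    "The code below puts the last 10% of the rows (rounded down) into `predict_data` and the remaining rows, about 90%, into `learn_data`, both as Pandas DataFrames. Rows are taken in file order, without shuffling."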
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
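    "n_all = len(input_data)\n",
    "n_predict = n_all // 10      # 10% of the rows, rounded down\n",
    "n_learn = n_all - n_predict  # the remaining rows\n",
    "\n",
    "# Split by position: the first n_learn rows for learning, the rest for prediction\n",
    "learn_data = input_data.iloc[0:n_learn, :]\n",
    "predict_data = input_data.iloc[n_learn:n_all, :]\n"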
   ]
  },
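  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split sizes can be checked by hand. The sketch below repeats the same integer arithmetic in plain Python, assuming 392 rows remain after `dropna()` (this matches the 353 learning rows and 39 prediction rows shown in the SRC printouts later in this notebook). `split_sizes` is a hypothetical helper introduced only for this illustration.\n",
    "\n",
    "```python\n",
    "# Reproduce the 90/10 split sizes used above (plain Python, no pandas).\n",
    "# Assumption: 392 rows remain after dropna().\n",
    "\n",
    "def split_sizes(n_all):\n",
    "    # Integer division rounds down, so the prediction set gets at most\n",
    "    # 10% of the rows and the learning set gets everything else.\n",
    "    n_predict = n_all // 10\n",
    "    n_learn = n_all - n_predict\n",
    "    return n_learn, n_predict\n",
    "\n",
    "n_learn, n_predict = split_sizes(392)\n",
    "print(n_learn, n_predict)  # 353 39\n",
    "\n",
    "# Positional slicing, like iloc[0:n_learn] and iloc[n_learn:n_all]\n",
    "sids = list(range(392))\n",
    "learn_sids = sids[0:n_learn]\n",
    "predict_sids = sids[n_learn:392]\n",
    "print(learn_sids[-1], predict_sids[0], predict_sids[-1])  # 352 353 391\n",
    "```\n",
    "\n",
    "Because the rows are not shuffled, the prediction set is simply the last rows of the file (`_sid` 353 to 391)."
   ]
  },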
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more details on this section, see [Running Learning and Prediction and Checking the Results, 2. Data Preparation](../simple/simple.ipynb#2.-データの準備)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
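    "## 3. Building More Accurate Models <a id=\"3.-より精度の高いモデルの作成方法\"></a>\n",
    "\n",
    "To build more accurate models, this section explains the following:\n",
    "  - Template syntax\n",
    "  - Parameter search (tuning)\n",
    "  - Random restarts"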
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
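    "### 3.1. Defining SPDs and SRCs with Template Syntax <a id=\"3.1.-テンプレート構文を用いたSPDとSRCの定義\"></a>\n",
    "\n",
    "This section shows how to define SPDs and SRCs using template syntax.\n",
    "\n",
    "**Template syntax** lets you write placeholders of the form `{{ <template variable name> }}` in an SPD or SRC. When the SPD or SRC is generated, each placeholder is replaced with the value of a variable from your Python code.\n",
    "\n",
    "Template syntax is used in all of the following sections.\n"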
   ]
  },
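  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make placeholder substitution concrete, the sketch below fills `{{ name }}` placeholders from a dictionary using a regular expression. It only illustrates the idea: `render_template` is a hypothetical helper, not part of SAMPO, and `gen_spd()` and `gen_src()` may implement templating differently (for example, they also accept Python objects such as DataFrames as values).\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Matches {{ name }}, allowing optional spaces around the variable name.\n",
    "PLACEHOLDER = re.compile(r'\\{\\{\\s*(\\w+)\\s*\\}\\}')\n",
    "\n",
    "def render_template(template, params):\n",
    "    # Hypothetical helper: replace each {{ name }} with str(params[name]).\n",
    "    # A missing key raises KeyError instead of leaving the placeholder behind.\n",
    "    return PLACEHOLDER.sub(lambda m: str(params[m.group(1)]), template)\n",
    "\n",
    "print(render_template('tree_depth: {{ tree_depth }}', {'tree_depth': 3}))  # tree_depth: 3\n",
    "print(render_template('fabhmerg_learn_{{ run_times }}', {'run_times': 0}))  # fabhmerg_learn_0\n",
    "```\n",
    "\n",
    "Raising an error for a missing key is a design choice of this sketch; it makes a misspelled template variable name fail immediately."
   ]
  },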
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is an example of an SPD that uses template syntax.\n",
    "In this example, template syntax is applied to the value of `tree_depth`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "spd_content_templ = '''\n",
    "dl -> std  -> rg\n",
    "   -> bexp -> rg\n",
    "\n",
    "---\n",
    "\n",
    "components:\n",
    "    dl:\n",
    "        component: DataLoader\n",
    "\n",
    "    std:\n",
    "        component: StandardizeFDComponent\n",
    "        features: scale == 'real' or scale == 'integer'\n",
    "\n",
    "    bexp:\n",
    "        component: BinaryExpandFDComponent\n",
    "        features: scale == 'nominal'\n",
    "\n",
    "    rg:\n",
    "        component: FABHMEBernGateLinearRgComponent\n",
    "        features: name != 'mpg'\n",
    "        target: name == 'mpg'\n",
    "        standardize_target: True\n",
    "        tree_depth: {{ tree_depth }}\n",
    "\n",
    "global_settings:\n",
    "    keep_attributes:\n",
    "        - mpg\n",
    "    feature_exclude:\n",
    "        - mpg\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The placeholders are filled in when the SPD is generated with the `gen_spd()` function. The values are passed as a dictionary whose keys are the template variable names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sampo.api import gen_spd\n",
    "\n",
    "spd_param = {'tree_depth': 3}\n",
    "spd = gen_spd(template=spd_content_templ, params=spd_param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
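    "Print the generated SPD to confirm that a value has been substituted for `{{ tree_depth }}`. The printed SPD is laid out slightly differently from the template (for example, each flow path is written on its own line and `True` is printed as `true`), but the settings are the same.  "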
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dl -> std -> rg\n",
      "dl -> bexp -> rg\n",
      "\n",
      "---\n",
      "\n",
      "components:\n",
      "    dl:\n",
      "        component: DataLoader\n",
      "        \n",
      "    bexp:\n",
      "        component: BinaryExpandFDComponent\n",
      "        features: scale == 'nominal'\n",
      "        \n",
      "    std:\n",
      "        component: StandardizeFDComponent\n",
      "        features: scale == 'real' or scale == 'integer'\n",
      "        \n",
      "    rg:\n",
      "        component: FABHMEBernGateLinearRgComponent\n",
      "        features: name != 'mpg'\n",
      "        standardize_target: true\n",
      "        target: name == 'mpg'\n",
      "        tree_depth: 3\n",
      "        \n",
      "global_settings:\n",
      "    keep_attributes:\n",
      "    - mpg\n",
      "\n",
      "    feature_exclude:\n",
      "    - mpg\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(spd)  # check the substituted value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
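    "The following is an example of SRCs that use template syntax.\n",
    "In both the learning SRC and the prediction SRC, template syntax is applied to the process name, the analysis data, and the ASD.\n",
    "\n",
    "With template syntax, you can also specify Python objects such as Pandas DataFrames in an SRC."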
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SRC for learning\n",
    "learn_src_templ = '''\n",
    "fabhmerg_learn_{{ run_times }}:\n",
    "    type: learn\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ learn_data }}\n",
    "            attr_schema: {{ asd }}\n",
    "'''\n",
    "\n",
    "# SRC for prediction\n",
    "predict_src_templ = '''\n",
    "fabhmerg_predict_{{ run_times }}:\n",
    "    type: predict\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ predict_data }}\n",
    "            attr_schema: {{ asd }}\n",
    "\n",
    "    model_process: fabhmerg_learn_{{ run_times }}\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with the SPD, the placeholders are filled in when each SRC is generated with the `gen_src()` function, by passing `src_param` as `params`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sampo.api import gen_src\n",
    "\n",
    "src_param = {'run_times': 0, 'learn_data': learn_data, 'predict_data': predict_data, 'asd': asd}\n",
    "learn_src = gen_src(template=learn_src_templ, params=src_param)\n",
    "predict_src = gen_src(template=predict_src_templ, params=src_param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print each SRC generated by `gen_src()` to check the substituted values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fabhmerg_learn_0:\n",
      "    type: learn\n",
      "\n",
      "    data_sources:\n",
      "        dl:\n",
      "            df:      _sid   mpg  cylinders   ...    acceleration  model_year  origin\n",
      "                0       0  18.0          8   ...            12.0          70       1\n",
      "                1       1  15.0          8   ...            11.5          70       1\n",
      "                2       2  18.0          8   ...            11.0          70       1\n",
      "                ..    ...   ...        ...   ...             ...         ...     ...\n",
      "                350   350  33.7          4   ...            14.4          81       3\n",
      "                351   351  32.4          4   ...            16.8          81       3\n",
      "                352   352  32.9          4   ...            14.8          81       3\n",
      "                \n",
      "                [353 rows x 9 columns]\n",
      "            attr_schema: \n",
      "                [('_sid', {'scale': 'INTEGER'}),\n",
      "                 ('mpg', {'scale': 'REAL'}),\n",
      "                 ('cylinders', {'scale': 'INTEGER'}),\n",
      "                 ('displacement', {'scale': 'REAL'}),\n",
      "                 ('horsepower', {'scale': 'INTEGER'}),\n",
      "                 ...]\n",
      "            \n",
      "                [Displaying 5 out of 9 attributes.]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Check\n",
    "print(learn_src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fabhmerg_predict_0:\n",
      "    type: predict\n",
      "\n",
      "    data_sources:\n",
      "        dl:\n",
      "            df:      _sid   mpg  cylinders   ...    acceleration  model_year  origin\n",
      "                353   353  31.6          4   ...            18.3          81       3\n",
      "                354   354  28.1          4   ...            20.4          81       2\n",
      "                355   355  30.7          6   ...            19.6          81       2\n",
      "                ..    ...   ...        ...   ...             ...         ...     ...\n",
      "                389   389  32.0          4   ...            11.6          82       1\n",
      "                390   390  28.0          4   ...            18.6          82       1\n",
      "                391   391  31.0          4   ...            19.4          82       1\n",
      "                \n",
      "                [39 rows x 9 columns]\n",
      "            attr_schema: \n",
      "                [('_sid', {'scale': 'INTEGER'}),\n",
      "                 ('mpg', {'scale': 'REAL'}),\n",
      "                 ('cylinders', {'scale': 'INTEGER'}),\n",
      "                 ('displacement', {'scale': 'REAL'}),\n",
      "                 ('horsepower', {'scale': 'INTEGER'}),\n",
      "                 ...]\n",
      "            \n",
      "                [Displaying 5 out of 9 attributes.]\n",
      "    model_process: fabhmerg_learn_0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(predict_src)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
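    "### 3.2. Running Parameter Tuning <a id=\"3.2.-パラメーターチューニングの実行方法\"></a>\n",
    "\n",
    "As an example of parameter tuning, this section runs a grid search.  \n",
    "A **grid search** trains a model for every combination of the hyperparameter values specified by the user and looks for the best combination.\n",
    "\n",
    "Below, template syntax is applied to `tree_depth` and `standardize_target` in the SPD, and a grid search is run over them.  \n",
    "To run a grid search with SAMPO/FAB, you create one SPD per hyperparameter combination."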
   ]
  },
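  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grid search idea does not depend on SAMPO. The sketch below enumerates every combination of two hyperparameter lists and keeps the combination with the lowest score. `toy_score` is a hypothetical stand-in for training a model with those hyperparameters and measuring its RMSE; in this notebook, that step is a SAMPO learning/prediction process pair.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "grid = {\n",
    "    'standardize_target': [True, False],\n",
    "    'tree_depth': [3, 4, 5],\n",
    "}\n",
    "\n",
    "def toy_score(standardize_target, tree_depth):\n",
    "    # Hypothetical stand-in for train + evaluate; lower is better, like RMSE.\n",
    "    return abs(tree_depth - 5) + (0 if standardize_target else 1)\n",
    "\n",
    "# One candidate per combination: 2 x 3 = 6, i.e. 6 SPDs in SAMPO/FAB.\n",
    "keys = list(grid)\n",
    "candidates = [dict(zip(keys, values)) for values in product(*grid.values())]\n",
    "print(len(candidates))  # 6\n",
    "\n",
    "best = min(candidates, key=lambda params: toy_score(**params))\n",
    "print(best)  # {'standardize_target': True, 'tree_depth': 5}\n",
    "```\n",
    "\n",
    "The number of runs is the product of the list lengths, so it grows quickly as more hyperparameters or values are added."
   ]
  },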
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "spd_content_para_tuning = '''\n",
    "dl -> std  -> rg\n",
    "   -> bexp -> rg\n",
    "\n",
    "---\n",
    "\n",
    "components:\n",
    "    dl:\n",
    "        component: DataLoader\n",
    "\n",
    "    std:\n",
    "        component: StandardizeFDComponent\n",
    "        features: scale == 'real' or scale == 'integer'\n",
    "\n",
    "    bexp:\n",
    "        component: BinaryExpandFDComponent\n",
    "        features: scale == 'nominal'\n",
    "\n",
    "    rg:\n",
    "        component: FABHMEBernGateLinearRgComponent\n",
    "        features: name != 'mpg'\n",
    "        target: name == 'mpg'\n",
    "        standardize_target: {{ standardize_target }}\n",
    "        tree_depth: {{ tree_depth }}\n",
    "\n",
    "global_settings:\n",
    "    keep_attributes:\n",
    "        - mpg\n",
    "    feature_exclude:\n",
    "        - mpg\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below creates `spd_params`, a list that holds every combination of the hyperparameter values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "spd_param_combination = {\n",
    "    'standardize_target': [True, False],\n",
    "    'tree_depth': [3, 4, 5]\n",
    "}\n",
    "\n",
    "spd_params = list(ParameterGrid(spd_param_combination))\n"
   ]
  },
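  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The position of each combination in `spd_params` matters, because the grid search loop uses it as `grid_id` in the process names. scikit-learn's `ParameterGrid` iterates over the parameter names in sorted order and takes their Cartesian product, with the last name varying fastest. The sketch below reproduces that ordering with `itertools.product` so you can see which combination each process number refers to; it is an illustration, not a replacement for `ParameterGrid`.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "spd_param_combination = {\n",
    "    'standardize_target': [True, False],\n",
    "    'tree_depth': [3, 4, 5],\n",
    "}\n",
    "\n",
    "# Sort by parameter name, then take the Cartesian product of the value lists.\n",
    "items = sorted(spd_param_combination.items())\n",
    "names = [name for name, _ in items]\n",
    "value_lists = [values for _, values in items]\n",
    "grid = [dict(zip(names, combo)) for combo in product(*value_lists)]\n",
    "\n",
    "# The list index becomes grid_id, the suffix of the process names.\n",
    "for grid_id, params in enumerate(grid):\n",
    "    print(f'fabhmerg_learn_{grid_id}', params)\n",
    "```\n",
    "\n",
    "With this ordering, index 0 is `{'standardize_target': True, 'tree_depth': 3}` and index 1 is `{'standardize_target': True, 'tree_depth': 4}`."
   ]
  },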
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
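    "The SRCs for the grid search are written as follows. `{{ parameter_pattern }}` gives each process a unique name per hyperparameter combination."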
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SRC for learning\n",
    "learn_src_para_tuning = '''\n",
    "fabhmerg_learn_{{ parameter_pattern }}:\n",
    "    type: learn\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ learn_data }}\n",
    "            attr_schema: {{ input_asd }}\n",
    "'''\n",
    "\n",
    "# SRC for prediction\n",
    "predict_src_para_tuning = '''\n",
    "fabhmerg_predict_{{ parameter_pattern }}:\n",
    "    type: predict\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ predict_data }}\n",
    "            attr_schema: {{ input_asd }}\n",
    "\n",
    "    model_process: fabhmerg_learn_{{ parameter_pattern }}\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
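    "Run the grid search using the SPD template, `spd_params`, and the SRCs created above. The learning and prediction processes are collected in `process_list` and run in parallel with `process_runner.session_run()`.  "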
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from sampo.api import sampo_logging\n",
    "from sampo.api import gen_spd, gen_src\n",
    "from sampotools.api import gen_asd_from_pandas_df\n",
    "\n",
    "sampo_logging.configure(logging.INFO, filename='./fabhmerg_para.log')\n",
    "\n",
    "process_list = []  # list of processes to run in parallel\n",
    "\n",
    "for grid_id, spd_param in enumerate(spd_params):\n",
    "    spd = gen_spd(template=spd_content_para_tuning, params=spd_param)\n",
    "    src_param = {'parameter_pattern': grid_id, 'learn_data': learn_data, 'predict_data': predict_data, 'input_asd': asd}\n",
    "    learn_src = gen_src(template=learn_src_para_tuning, params=src_param)\n",
    "    predict_src = gen_src(template=predict_src_para_tuning, params=src_param)\n",
    "    process_list.append((learn_src, spd))  # add the learning process\n",
    "    process_list.append((predict_src, None))  # add the prediction process (uses the learned model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('fabhmerg_learn_0.f495addd-ab93-4952-aff5-40a6327812d9', None),\n",
       " ('fabhmerg_learn_2.28b9d7d4-3310-4708-aa7c-76fe9f126979', None),\n",
       " ('fabhmerg_learn_1.5171862b-1406-4df7-a542-e8c8325f34c1', None),\n",
       " ('fabhmerg_learn_3.86f3c015-dc4d-4306-9bf3-7fa3dac69680', None),\n",
       " ('fabhmerg_learn_4.e79f0543-d244-4858-b22a-81c4b66826ed', None),\n",
       " ('fabhmerg_predict_0.a1db32c2-c9b0-4591-a5be-d33d7d90d7d7', None),\n",
       " ('fabhmerg_learn_5.f66267e2-5117-4adc-ad7d-2096ec4d5850', None),\n",
       " ('fabhmerg_predict_1.6decfd04-c14f-4982-bb02-e607dda16f6b', None),\n",
       " ('fabhmerg_predict_3.257d52fd-7069-4b35-a5f2-918bd618d269', None),\n",
       " ('fabhmerg_predict_2.dd03f131-4b9e-4aca-937a-888355b7fa0c', None),\n",
       " ('fabhmerg_predict_4.03063d1d-d030-41f3-8c1d-aa6d4d299895', None),\n",
       " ('fabhmerg_predict_5.362e7401-9426-4ceb-95ee-0601950cb419', None)]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sampo.api import process_runner, process_store\n",
    "\n",
    "pstore_url = './parallel_pstore_tuning'\n",
    "process_store.create(pstore_url)\n",
    "\n",
    "# Run the grid search\n",
    "process_runner.session_run(process_list, pstore_url=pstore_url, max_workers=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open the completed processes from the process store and display the computed prediction accuracy (RMSE), sorted from lowest to highest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>process_name</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>fabhmerg_predict_1</td>\n",
       "      <td>3.682181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fabhmerg_predict_3</td>\n",
       "      <td>3.804716</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>fabhmerg_predict_5</td>\n",
       "      <td>3.860132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fabhmerg_predict_0</td>\n",
       "      <td>3.916196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fabhmerg_predict_2</td>\n",
       "      <td>4.140845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>fabhmerg_predict_4</td>\n",
       "      <td>4.152079</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         process_name      rmse\n",
       "1  fabhmerg_predict_1  3.682181\n",
       "3  fabhmerg_predict_3  3.804716\n",
       "5  fabhmerg_predict_5  3.860132\n",
       "0  fabhmerg_predict_0  3.916196\n",
       "2  fabhmerg_predict_2  4.140845\n",
       "4  fabhmerg_predict_4  4.152079"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from sampo.api import process_store\n",
    "\n",
    "result = []\n",
    "predict_proc_names = [src.name for src, _ in process_list if re.match('fabhmerg_predict.*', src.name)]\n",
    "for predict_proc_name in predict_proc_names:\n",
    "    row = {}\n",
    "    with process_store.open_process(pstore_url, predict_proc_name) as prl:\n",
    "        evaluation = prl.load_comp_output_evaluation('rg')\n",
    "        row['process_name'] = predict_proc_name\n",
    "        row['rmse'] = evaluation['root_mean_squared_error'][0]\n",
    "        result.append(row)\n",
    "\n",
    "pd.DataFrame(result).sort_values(by='rmse')"
   ]
  },
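  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `rmse` column above is read from the `root_mean_squared_error` entry of the `rg` component's evaluation output. As a reminder of what the number means, the sketch below computes RMSE directly: the square root of the mean of the squared prediction errors. Lower is better, and the value is in the same units as the target (`mpg`, miles per gallon). This is a plain-Python illustration with made-up numbers, not SAMPO's implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def rmse(actual, predicted):\n",
    "    # Root mean squared error: sqrt(mean((y - y_hat) ** 2)).\n",
    "    if not actual or len(actual) != len(predicted):\n",
    "        raise ValueError('inputs must be non-empty and of equal length')\n",
    "    squared_errors = [(y - y_hat) ** 2 for y, y_hat in zip(actual, predicted)]\n",
    "    return math.sqrt(sum(squared_errors) / len(squared_errors))\n",
    "\n",
    "# Made-up values: the errors are 2 and 0, so RMSE = sqrt((4 + 0) / 2) = sqrt(2).\n",
    "print(rmse([20.0, 30.0], [18.0, 30.0]))  # 1.4142135623730951\n",
    "```"
   ]
  },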
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the table above, take the number in the name of the process with the lowest RMSE and look up the corresponding hyperparameter combination."
   ]
  },
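  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection step can be written so that it depends neither on the order of `result` nor on process numbers being a single digit. The sketch below uses the RMSE values from the table above and reads the number after the last underscore; `pick_best` is a hypothetical helper for illustration.\n",
    "\n",
    "```python\n",
    "# RMSE values copied from the grid search table above.\n",
    "result = [\n",
    "    {'process_name': 'fabhmerg_predict_0', 'rmse': 3.916196},\n",
    "    {'process_name': 'fabhmerg_predict_1', 'rmse': 3.682181},\n",
    "    {'process_name': 'fabhmerg_predict_2', 'rmse': 4.140845},\n",
    "    {'process_name': 'fabhmerg_predict_3', 'rmse': 3.804716},\n",
    "    {'process_name': 'fabhmerg_predict_4', 'rmse': 4.152079},\n",
    "    {'process_name': 'fabhmerg_predict_5', 'rmse': 3.860132},\n",
    "]\n",
    "\n",
    "def pick_best(rows):\n",
    "    # Hypothetical helper: choose the row with the lowest RMSE, then read\n",
    "    # the grid_id after the last underscore (also works for 10, 11, ...).\n",
    "    best = min(rows, key=lambda row: row['rmse'])\n",
    "    return best['process_name'], int(best['process_name'].rsplit('_', 1)[1])\n",
    "\n",
    "print(pick_best(result))  # ('fabhmerg_predict_1', 1)\n",
    "\n",
    "# Slicing only the last character breaks once ids reach two digits:\n",
    "print(int('fabhmerg_predict_12'[-1:]))  # 2, not 12\n",
    "```"
   ]
  },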
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
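       "{'standardize_target': True, 'tree_depth': 4}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the process with the lowest RMSE (result itself is not sorted)\n",
    "best_model = pd.DataFrame(result).sort_values(by='rmse').iloc[0]['process_name']\n",
    "# Take the grid_id after the last underscore so that ids of 10 or more also work\n",
    "num_best_model = int(best_model.rsplit('_', 1)[1])\n",
    "spd_params[num_best_model]"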
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
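    "### 3.3. Random Restarts: Purpose, Design, and Execution <a id=\"3.3.-ランダムリスタートの目的と設計・実行方法\"></a>\n",
    "\n",
    "Models created with the SAMPO/FAB heterogeneous mixture learning component (`FABHMEBernGateLinearRgComponent`) start learning from a randomly determined initial state.\n",
    "As a result, running learning repeatedly on the same analysis data with the same hyperparameters produces a different model each time.\n",
    "\n",
    "Running the analysis process several times for this reason, in order to select a more accurate model from the results, is called a **random restart**."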
   ]
  },
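  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of random initial states, and why restarting helps, can be seen on a toy problem unrelated to SAMPO. The sketch below runs plain gradient descent on a one-dimensional function with several local minima, starting each run from a random point, and keeps the run with the lowest final value. The function, step size, and number of restarts are arbitrary choices for illustration; FAB's actual learning algorithm is different and is not shown here.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def f(x):\n",
    "    # Non-convex toy objective with several local minima.\n",
    "    return math.sin(3 * x) + 0.1 * x * x\n",
    "\n",
    "def grad_f(x):\n",
    "    return 3 * math.cos(3 * x) + 0.2 * x\n",
    "\n",
    "def local_descent(x, step=0.01, iters=2000):\n",
    "    # Plain gradient descent: converges to a nearby local minimum.\n",
    "    for _ in range(iters):\n",
    "        x -= step * grad_f(x)\n",
    "    return x\n",
    "\n",
    "rng = random.Random(0)  # fixed seed so the sketch is reproducible\n",
    "runs = []\n",
    "for restart in range(5):\n",
    "    x0 = rng.uniform(-5.0, 5.0)  # random initial state\n",
    "    x_final = local_descent(x0)\n",
    "    runs.append((f(x_final), restart, x_final))\n",
    "    print(f'restart {restart}: start={x0:.2f} end={x_final:.3f} f={f(x_final):.4f}')\n",
    "\n",
    "# Keep the run with the lowest objective value.\n",
    "best_value, best_restart, best_x = min(runs)\n",
    "print('best restart:', best_restart, 'f =', round(best_value, 4))\n",
    "```\n",
    "\n",
    "Different starting points can end in different local minima, so the final values can differ between runs; keeping the best of several runs is the same idea as choosing the restart whose prediction process has the lowest RMSE."
   ]
  },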
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
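    "For a random restart, template syntax is applied to the process names in the SRCs so that each run has a unique name.\n",
    "\n",
    "The SPD and SRCs below are used in an example of running random restarts."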
   ]
  },
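  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section's SRC templates, `{{ run_times }}` gives every learning and prediction process its own name, and each prediction SRC points to its learning process through `model_process`; if several runs shared one name, that reference would be ambiguous. The sketch below builds the names these templates produce for five restarts and checks that they are unique and that every `model_process` points to an existing learning process. `build_names` and `check_pairs` are hypothetical helpers for illustration, not SAMPO functions.\n",
    "\n",
    "```python\n",
    "def build_names(num_restarts):\n",
    "    # Hypothetical helper mirroring the SRC templates in this section:\n",
    "    # fabhmerg_learn_{{ run_times }} and fabhmerg_predict_{{ run_times }}\n",
    "    pairs = []\n",
    "    for run_times in range(num_restarts):\n",
    "        learn_name = f'fabhmerg_learn_{run_times}'\n",
    "        predict_name = f'fabhmerg_predict_{run_times}'\n",
    "        pairs.append((learn_name, predict_name, learn_name))  # last item: model_process\n",
    "    return pairs\n",
    "\n",
    "def check_pairs(pairs):\n",
    "    # All process names must be distinct, and every model_process\n",
    "    # must refer to a learning process that actually exists.\n",
    "    learn_names = [learn for learn, _, _ in pairs]\n",
    "    all_names = learn_names + [predict for _, predict, _ in pairs]\n",
    "    if len(set(all_names)) != len(all_names):\n",
    "        raise ValueError('duplicate process names')\n",
    "    missing = [ref for _, _, ref in pairs if ref not in learn_names]\n",
    "    if missing:\n",
    "        raise ValueError(f'unknown model_process: {missing}')\n",
    "    return True\n",
    "\n",
    "print(check_pairs(build_names(5)))  # True\n",
    "\n",
    "# Without the {{ run_times }} placeholder, every restart would share one name:\n",
    "duplicated = [('fabhmerg_learn', 'fabhmerg_predict', 'fabhmerg_learn')] * 5\n",
    "try:\n",
    "    check_pairs(duplicated)\n",
    "except ValueError as err:\n",
    "    print(err)  # duplicate process names\n",
    "```"
   ]
  },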
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "spd_content_random = '''\n",
    "dl -> std  -> rg\n",
    "   -> bexp -> rg\n",
    "\n",
    "---\n",
    "\n",
    "components:\n",
    "    dl:\n",
    "        component: DataLoader\n",
    "\n",
    "    std:\n",
    "        component: StandardizeFDComponent\n",
    "        features: scale == 'real' or scale == 'integer'\n",
    "\n",
    "    bexp:\n",
    "        component: BinaryExpandFDComponent\n",
    "        features: scale == 'nominal'\n",
    "\n",
    "    rg:\n",
    "        component: FABHMEBernGateLinearRgComponent\n",
    "        features: name != 'mpg'\n",
    "        target: name == 'mpg'\n",
    "        standardize_target: True\n",
    "        tree_depth: 3\n",
    "\n",
    "global_settings:\n",
    "    keep_attributes:\n",
    "        - mpg\n",
    "    feature_exclude:\n",
    "        - mpg\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SRC for learning\n",
    "learn_src_random = '''\n",
    "fabhmerg_learn_{{ run_times }}:\n",
    "    type: learn\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ learn_data }}\n",
    "            attr_schema: {{ asd }}\n",
    "'''\n",
    "\n",
    "# SRC for prediction\n",
    "predict_src_random = '''\n",
    "fabhmerg_predict_{{ run_times }}:\n",
    "    type: predict\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ predict_data }}\n",
    "            attr_schema: {{ asd }}\n",
    "\n",
    "    model_process: fabhmerg_learn_{{ run_times }}\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
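    "Generate the SPD with the `gen_spd()` function and the learning and prediction SRCs with the `gen_src()` function so that SAMPO/FAB can read them. Only `run_times` changes between restarts; the SPD and the data are the same for every run."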
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from sampo.api import sampo_logging\n",
    "from sampo.api import gen_spd, gen_src\n",
    "from sampotools.api import gen_asd_from_pandas_df\n",
    "\n",
    "sampo_logging.configure(logging.INFO, filename='./fabhmerg_randomrestart.log')\n",
    "\n",
    "process_list = []  # list of processes to run in parallel\n",
    "num_random_restarts = 5\n",
    "\n",
    "for idx in range(num_random_restarts):\n",
    "    src_param = {'run_times': idx , 'learn_data': learn_data, 'predict_data': predict_data, 'asd': asd}\n",
    "    spd = gen_spd(template=spd_content_random)\n",
    "    learn_src = gen_src(template=learn_src_random, params=src_param)\n",
    "    predict_src = gen_src(template=predict_src_random, params=src_param)\n",
    "    process_list.append((learn_src, spd))  # add the learning process\n",
    "    process_list.append((predict_src, None))  # add the prediction process"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the random restarts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('fabhmerg_learn_1.0555f2e3-ced0-4ed6-a5f9-0fd87a0d977a', None),\n",
       " ('fabhmerg_learn_0.c061ee34-5bfe-4cf4-a020-4abf274ef9f1', None),\n",
       " ('fabhmerg_learn_2.13a26f3b-bc8a-4bd3-86f9-2eac0384c398', None),\n",
       " ('fabhmerg_learn_3.eb000f49-c8f4-4197-a200-cb40ad2da087', None),\n",
       " ('fabhmerg_learn_4.d219f3dc-c390-4e81-bead-36cd48ad5565', None),\n",
       " ('fabhmerg_predict_1.c2c6c0c2-9366-4db2-ba7e-c5072796b82f', None),\n",
       " ('fabhmerg_predict_0.ace6592c-6b99-434a-b029-e702d46ac270', None),\n",
       " ('fabhmerg_predict_3.ec6dcfc1-e402-4282-bcc2-d27ae10586a1', None),\n",
       " ('fabhmerg_predict_2.ce690cf7-0a84-4ee9-8e12-254887703da3', None),\n",
       " ('fabhmerg_predict_4.ffc42403-a93d-4254-8207-d293e1641f0a', None)]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sampo.api import process_runner, process_store\n",
    "\n",
    "pstore_url = './parallel_pstore'\n",
    "process_store.create(pstore_url)\n",
    "\n",
    "# Run the random restarts\n",
    "process_runner.session_run(process_list, pstore_url=pstore_url, max_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>process_name</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fabhmerg_predict_3</td>\n",
       "      <td>3.632642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>fabhmerg_predict_1</td>\n",
       "      <td>3.726962</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>fabhmerg_predict_4</td>\n",
       "      <td>3.798683</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fabhmerg_predict_2</td>\n",
       "      <td>3.818484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fabhmerg_predict_0</td>\n",
       "      <td>4.227888</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         process_name      rmse\n",
       "3  fabhmerg_predict_3  3.632642\n",
       "1  fabhmerg_predict_1  3.726962\n",
       "4  fabhmerg_predict_4  3.798683\n",
       "2  fabhmerg_predict_2  3.818484\n",
       "0  fabhmerg_predict_0  4.227888"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from sampo.api import process_store\n",
    "\n",
    "result = []\n",
    "predict_proc_names = [src.name for src, _ in process_list if re.match('fabhmerg_predict.*', src.name)]\n",
    "for predict_proc_name in predict_proc_names:\n",
    "    row = {}\n",
    "    with process_store.open_process(pstore_url, predict_proc_name) as prl:\n",
    "        evaluation = prl.load_comp_output_evaluation('rg')\n",
    "        row['process_name'] = predict_proc_name\n",
    "        row['rmse'] = evaluation['root_mean_squared_error'][0]\n",
    "        result.append(row)\n",
    "\n",
    "pd.DataFrame(result).sort_values(by='rmse')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
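    "From the results above, you can select the model with the smallest RMSE among the models created. Here that is fabhmerg_predict_3 (RMSE 3.632642), that is, the model learned by fabhmerg_learn_3."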
   ]
  },
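  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of this selection step, the RMSE values from the table above can be put in a DataFrame and the row with the smallest RMSE picked with `idxmin()`. In the notebook itself, `pd.DataFrame(result)` from the previous cell can be used instead of the copied values. The learning process name is derived from the naming convention of the templates in this section (fabhmerg_predict_N uses the model of fabhmerg_learn_N).\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# RMSE values copied from the table above\n",
    "result_df = pd.DataFrame({\n",
    "    'process_name': ['fabhmerg_predict_0', 'fabhmerg_predict_1', 'fabhmerg_predict_2',\n",
    "                     'fabhmerg_predict_3', 'fabhmerg_predict_4'],\n",
    "    'rmse': [4.227888, 3.726962, 3.818484, 3.632642, 3.798683],\n",
    "})\n",
    "\n",
    "# Row with the smallest RMSE\n",
    "best = result_df.loc[result_df['rmse'].idxmin()]\n",
    "# fabhmerg_predict_N uses the model learned by fabhmerg_learn_N (see model_process)\n",
    "best_learn_process = best['process_name'].replace('_predict_', '_learn_')\n",
    "print(best['process_name'], best['rmse'], best_learn_process)\n",
    "```"
   ]
  },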
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
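    "## 4. Model Validation Methods\n",
    "\n",
    "Of the many model validation methods, this section shows examples of the following two, which are commonly used with SAMPO/FAB:\n",
    "\n",
    "- Holdout validation\n",
    "- Cross-validation\n"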
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
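    "### 4.1. Holdout Validation\n",
    "\n",
    "**Holdout validation** splits the analysis data once into learning data and validation data, builds a model on the learning data, and evaluates that model on the validation data, which was not used for learning."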
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
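    "The SPD is reused from [Section 3.2, How to Run Parameter Tuning](#parameter_tuning), with `standardize_target` and `tree_depth` left as template parameters."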
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "spd_content_ho = '''\n",
    "dl -> std  -> rg\n",
    "   -> bexp -> rg\n",
    "\n",
    "---\n",
    "\n",
    "components:\n",
    "    dl:\n",
    "        component: DataLoader\n",
    "\n",
    "    std:\n",
    "        component: StandardizeFDComponent\n",
    "        features: scale == 'real' or scale == 'integer'\n",
    "\n",
    "    bexp:\n",
    "        component: BinaryExpandFDComponent\n",
    "        features: scale == 'nominal'\n",
    "\n",
    "    rg:\n",
    "        component: FABHMEBernGateLinearRgComponent\n",
    "        features: name != 'mpg'\n",
    "        target: name == 'mpg'\n",
    "        standardize_target: {{ standardize_target }}\n",
    "        tree_depth: {{ tree_depth }}\n",
    "\n",
    "global_settings:\n",
    "    keep_attributes:\n",
    "        - mpg\n",
    "    feature_exclude:\n",
    "        - mpg\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the SRCs, use the following learning and validation templates written with the template syntax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SRC for learning\n",
    "learn_src_ho = '''\n",
    "fabhmerg_learn_ho{{ run_times }}:\n",
    "    type: learn\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ input_df }}\n",
    "            attr_schema: {{ asd }}\n",
    "            filters:\n",
    "                 - k_split(3, 0, True)\n",
    "'''\n",
    "\n",
    "# SRC for validation\n",
    "predict_src_ho = '''\n",
    "fabhmerg_predict_ho{{ run_times }}:\n",
    "    type: predict\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ input_df }}\n",
    "            attr_schema: {{ asd }}\n",
    "            filters:\n",
    "                 - k_split(3, 0, False)\n",
    "\n",
    "    model_process: fabhmerg_learn_ho{{ run_times }}\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
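    "The k_split() function written in the SRCs splits the analysis data into learning data and validation data.\n",
    "\n",
    "k_split() takes the following arguments:\n",
    " - First argument: the number k of parts into which the analysis data is split\n",
    " - Second argument: the index of the part to select (0 to k-1)\n",
    " - Third argument: if True, returns all data except the part specified by the second argument; if False, returns only that part\n",
    "\n",
    "The SRCs above split the analysis data into three parts and use the first part (index 0) for validation and the remaining two parts for learning.\n",
    "\n",
    "For details on k_split(), see `SRC (SAMPO Run Configuration) Specification` in the `Analytics Reference`."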
   ]
  },
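  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below mimics the arguments of k_split() with a hypothetical function, `k_split_indices`, that returns row positions instead of data. It assumes the rows are divided into k contiguous blocks in their original order (via `numpy.array_split`); how SAMPO's k_split() actually assigns rows to parts may differ, so treat it only as an illustration of the roles of the three arguments.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for k_split(): assumes k contiguous blocks in row order.\n",
    "def k_split_indices(n_rows, k, position, exclude):\n",
    "    folds = np.array_split(np.arange(n_rows), k)\n",
    "    if exclude:\n",
    "        # Third argument True: every part except the one at `position` (learning data)\n",
    "        return np.concatenate([f for i, f in enumerate(folds) if i != position])\n",
    "    # Third argument False: only the part at `position` (validation data)\n",
    "    return folds[position]\n",
    "\n",
    "learn_idx = k_split_indices(10, 3, 0, True)    # corresponds to k_split(3, 0, True)\n",
    "valid_idx = k_split_indices(10, 3, 0, False)   # corresponds to k_split(3, 0, False)\n",
    "print(learn_idx)  # [4 5 6 7 8 9]\n",
    "print(valid_idx)  # [0 1 2 3]\n",
    "```\n",
    "\n",
    "With the same k and position, the two calls return disjoint sets of rows that together cover all the data, which is what keeps the validation data out of learning."
   ]
  },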
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from sampo.api import sampo_logging\n",
    "import pandas as pd\n",
    "from sampo.api import gen_spd, gen_src\n",
    "from sampotools.api import gen_asd_from_pandas_df\n",
    "\n",
    "sampo_logging.configure(logging.INFO, filename='./fabhmerg_ho.log')\n",
    "\n",
    "process_list = []  # processes to be run in parallel\n",
    "num_random_restarts = 5\n",
    "\n",
    "for n in range(num_random_restarts):\n",
    "    spd_param = {'tree_depth': 3, 'standardize_target': True}\n",
    "    src_param = {'run_times': n, 'input_df': input_data, 'asd': asd}\n",
    "    spd = gen_spd(template=spd_content_ho, params=spd_param)\n",
    "    learn_src = gen_src(template=learn_src_ho, params=src_param)\n",
    "    predict_src = gen_src(template=predict_src_ho, params=src_param)\n",
    "    process_list.append((learn_src, spd))  # add the learning process together with its SPD\n",
    "    process_list.append((predict_src, None))  # add the validation (prediction) process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('fabhmerg_learn_ho0.dfd413c1-8d98-45d8-b47e-7bbd1bae63d4', None),\n",
       " ('fabhmerg_learn_ho2.949da768-90da-4eab-b98b-8433f79617aa', None),\n",
       " ('fabhmerg_learn_ho1.841cf70d-95f6-45a9-b7f2-7744bfa0a325', None),\n",
       " ('fabhmerg_predict_ho0.0bd40baa-a5d0-40c6-8e1d-878df3b50512', None),\n",
       " ('fabhmerg_predict_ho2.86d00f75-1c2b-46af-8c44-83cbb2c583c1', None),\n",
       " ('fabhmerg_learn_ho4.6f4da795-8a43-4fdf-855a-984c12e5a6d3', None),\n",
       " ('fabhmerg_learn_ho3.a7c1c0d6-c3c0-470a-939b-64d8b91cd681', None),\n",
       " ('fabhmerg_predict_ho4.bdf53cf5-3b9c-4a85-8eac-730ad73af60a', None),\n",
       " ('fabhmerg_predict_ho1.3cfe4864-e9cf-4c66-a139-d3150c253eae', None),\n",
       " ('fabhmerg_predict_ho3.e7052729-b760-4aad-9c17-bd3eba68e2ca', None)]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sampo.api import process_runner, process_store\n",
    "\n",
    "pstore_url = './parallel_pstore_hold_out'\n",
    "process_store.create(pstore_url)\n",
    "\n",
    "# Run the holdout validation\n",
    "process_runner.session_run(process_list, pstore_url=pstore_url, max_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>process_name</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fabhmerg_predict_ho2</td>\n",
       "      <td>3.539923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fabhmerg_predict_ho3</td>\n",
       "      <td>4.184492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>fabhmerg_predict_ho1</td>\n",
       "      <td>4.382446</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>fabhmerg_predict_ho4</td>\n",
       "      <td>4.596399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fabhmerg_predict_ho0</td>\n",
       "      <td>4.791177</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           process_name      rmse\n",
       "2  fabhmerg_predict_ho2  3.539923\n",
       "3  fabhmerg_predict_ho3  4.184492\n",
       "1  fabhmerg_predict_ho1  4.382446\n",
       "4  fabhmerg_predict_ho4  4.596399\n",
       "0  fabhmerg_predict_ho0  4.791177"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from sampo.api import process_store\n",
    "\n",
    "result = []\n",
    "predict_proc_names = [src.name for src, _ in process_list if re.match('fabhmerg_predict.*', src.name)]\n",
    "for predict_proc_name in predict_proc_names:\n",
    "    row = {}\n",
    "    with process_store.open_process(pstore_url, predict_proc_name) as prl:\n",
    "        evaluation = prl.load_comp_output_evaluation('rg')\n",
    "        row['process_name'] = predict_proc_name\n",
    "        row['rmse'] = evaluation['root_mean_squared_error'][0]\n",
    "        result.append(row)\n",
    "\n",
    "pd.DataFrame(result).sort_values(by='rmse')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
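    "The table above shows the prediction accuracy (RMSE) obtained by holdout validation for each of the five random restarts. All five runs use the same split, so the spread of RMSE values (from 3.539923 to 4.791177) reflects the differences between the random restarts rather than the data split."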
   ]
  },
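  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `root_mean_squared_error` value read with `load_comp_output_evaluation('rg')` is the RMSE of the predictions on the validation data. Assuming the standard definition, RMSE is the square root of the mean squared difference between actual and predicted values, as in the following sketch with made-up numbers:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rmse(y_true, y_pred):\n",
    "    # Square root of the mean of the squared prediction errors\n",
    "    y_true = np.asarray(y_true, dtype=float)\n",
    "    y_pred = np.asarray(y_pred, dtype=float)\n",
    "    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))\n",
    "\n",
    "# Made-up mpg values for illustration: the errors are -2, 1 and 2\n",
    "print(rmse([18.0, 15.0, 36.0], [20.0, 14.0, 34.0]))  # sqrt(3) = 1.732...\n",
    "```\n",
    "\n",
    "RMSE is expressed in the same unit as the target (mpg here), and a smaller value means more accurate predictions."
   ]
  },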
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
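    "### 4.2. Cross-Validation\n",
    "\n",
    "**Cross-validation** splits the analysis data into several parts and repeats learning and validation while rotating which part is used for validation.\n",
    "Compared with holdout validation, it performs more learning and validation runs per validation and therefore takes longer, but because it averages the evaluation results over all splits, its estimate is more reliable.\n",
    "\n",
    "There are several kinds of cross-validation; the following shows an example of K-fold cross-validation.\n"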
   ]
  },
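  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To show the K-fold procedure independently of SAMPO, here is a minimal NumPy sketch. It uses synthetic data and an ordinary least-squares model as a stand-in for the FAB model, and it splits the rows into contiguous blocks with `numpy.array_split`, which is an assumption made only for this illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Synthetic regression data (for illustration only)\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(30, 2))\n",
    "y = X @ np.array([1.5, -2.0]) + 0.5 + rng.normal(scale=0.1, size=30)\n",
    "\n",
    "def fit_predict(X_learn, y_learn, X_valid):\n",
    "    # Least-squares fit with an intercept column, then predict the validation rows\n",
    "    A = np.column_stack([X_learn, np.ones(len(X_learn))])\n",
    "    coef, *_ = np.linalg.lstsq(A, y_learn, rcond=None)\n",
    "    return np.column_stack([X_valid, np.ones(len(X_valid))]) @ coef\n",
    "\n",
    "k = 3\n",
    "folds = np.array_split(np.arange(len(y)), k)\n",
    "scores = []\n",
    "for position in range(k):\n",
    "    valid_idx = folds[position]  # part used for validation in this round\n",
    "    learn_idx = np.concatenate([f for i, f in enumerate(folds) if i != position])\n",
    "    pred = fit_predict(X[learn_idx], y[learn_idx], X[valid_idx])\n",
    "    scores.append(float(np.sqrt(np.mean((y[valid_idx] - pred) ** 2))))\n",
    "\n",
    "print(scores, np.mean(scores))  # per-fold RMSE and their average\n",
    "```\n",
    "\n",
    "Every row is used for validation exactly once, and the average of the per-fold RMSEs is the cross-validation estimate."
   ]
  },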
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "spd_content_cv = '''\n",
    "dl -> std  -> rg\n",
    "   -> bexp -> rg\n",
    "\n",
    "---\n",
    "\n",
    "components:\n",
    "    dl:\n",
    "        component: DataLoader\n",
    "\n",
    "    std:\n",
    "        component: StandardizeFDComponent\n",
    "        features: scale == 'real' or scale == 'integer'\n",
    "\n",
    "    bexp:\n",
    "        component: BinaryExpandFDComponent\n",
    "        features: scale == 'nominal'\n",
    "\n",
    "    rg:\n",
    "        component: FABHMEBernGateLinearRgComponent\n",
    "        features: name != 'mpg'\n",
    "        target: name == 'mpg'\n",
    "        standardize_target: {{ standardize_target }}\n",
    "        tree_depth: {{ tree_depth }}\n",
    "\n",
    "global_settings:\n",
    "    keep_attributes:\n",
    "        - mpg\n",
    "    feature_exclude:\n",
    "        - mpg\n",
    "'''\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the SRCs below, `rr` in the process name indicates the random restart number and `split` indicates the index of the part used for validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SRC for learning\n",
    "learn_src_cv = '''\n",
    "fabhmerg_learn_rr{{ run_times }}_split{{ split_position }}:\n",
    "    type: learn\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ input_df }}\n",
    "            attr_schema: {{ asd }}\n",
    "            filters:\n",
    "                 - k_split({{ k_split }}, {{ split_position }}, True)\n",
    "'''\n",
    "\n",
    "# SRC for validation\n",
    "predict_src_cv = '''\n",
    "fabhmerg_predict_rr{{ run_times }}_split{{ split_position }}:\n",
    "    type: predict\n",
    "\n",
    "    data_sources:\n",
    "        dl:\n",
    "            df: {{ input_df }}\n",
    "            attr_schema: {{ asd }}\n",
    "            filters:\n",
    "                 - k_split({{ k_split }}, {{ split_position }}, False)\n",
    "\n",
    "    model_process: fabhmerg_learn_rr{{ run_times }}_split{{ split_position }}\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the SPD and SRCs above, run 3-fold cross-validation with 3 random restarts."
   ]
  },
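  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With 3 random restarts and 3 folds, the templates produce 9 learning processes and 9 prediction processes, 18 in total. The following sketch lists the names they generate, using the same naming pattern as the SRC templates above.\n",
    "\n",
    "```python\n",
    "# Enumerate the process names produced by the SRC templates\n",
    "# (one learning and one prediction process per random restart and fold)\n",
    "num_random_restarts = 3\n",
    "num_split = 3\n",
    "\n",
    "learn_names = [f'fabhmerg_learn_rr{rr}_split{k}'\n",
    "               for rr in range(num_random_restarts) for k in range(num_split)]\n",
    "predict_names = [f'fabhmerg_predict_rr{rr}_split{k}'\n",
    "                 for rr in range(num_random_restarts) for k in range(num_split)]\n",
    "\n",
    "print(len(learn_names) + len(predict_names))  # 18\n",
    "print(predict_names[:3])\n",
    "```"
   ]
  },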
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "from sampo.api import sampo_logging\n",
    "from sampo.api import gen_spd, gen_src\n",
    "from sampotools.api import gen_asd_from_pandas_df\n",
    "\n",
    "sampo_logging.configure(logging.INFO, filename='./fabhmerg_cv.log')\n",
    "\n",
    "process_list = []  # processes to be run in parallel\n",
    "num_random_restarts = 3\n",
    "num_split = 3  # number of parts to split the data into\n",
    "\n",
    "# Build the learning and validation processes for every combination of random restart and fold\n",
    "for rr in range(num_random_restarts):\n",
    "    for k in range(num_split):\n",
    "        spd_param = {'tree_depth': 3, 'standardize_target': True}\n",
    "        src_param = {'run_times': rr, 'input_df': input_data, 'k_split': num_split, 'split_position': k, 'asd': asd}\n",
    "        spd = gen_spd(template=spd_content_cv, params=spd_param)\n",
    "        learn_src = gen_src(template=learn_src_cv, params=src_param)\n",
    "        predict_src = gen_src(template=predict_src_cv, params=src_param)\n",
    "        process_list.append((learn_src, spd))  # add the learning process together with its SPD\n",
    "        process_list.append((predict_src, None))  # add the validation (prediction) process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('fabhmerg_learn_rr0_split2.a6854723-bb24-42c0-a297-5ebfc1a3c9cd', None),\n",
       " ('fabhmerg_learn_rr0_split0.24e3c4b7-8989-49c8-895c-6e35618b788b', None),\n",
       " ('fabhmerg_learn_rr0_split1.7422ae74-82cd-490e-ba56-a632335ed504', None),\n",
       " ('fabhmerg_learn_rr1_split0.92eb3f1a-e732-476f-86fb-62bad71a84c6', None),\n",
       " ('fabhmerg_learn_rr1_split1.16dc4ff2-7b96-4846-94c7-e073bad03486', None),\n",
       " ('fabhmerg_learn_rr1_split2.e5d37eba-5510-4a7b-9c06-7d36feced671', None),\n",
       " ('fabhmerg_learn_rr2_split1.852c4344-5a58-4d37-8aa7-cb1263f5131b', None),\n",
       " ('fabhmerg_learn_rr2_split0.ae895fcd-7b09-4d7f-b89f-1bd6c258c223', None),\n",
       " ('fabhmerg_learn_rr2_split2.e465bcd3-d1f5-4cc3-94bc-06c26e445fa4', None),\n",
       " ('fabhmerg_predict_rr0_split2.eaf65065-8c8c-48e7-b610-0688d492045d', None),\n",
       " ('fabhmerg_predict_rr0_split0.606e8a9e-62df-4356-bc53-0728a5b64ca7', None),\n",
       " ('fabhmerg_predict_rr1_split0.d38f5a79-bc1a-4d3a-a76f-610d836ce545', None),\n",
       " ('fabhmerg_predict_rr1_split1.0719da35-12a7-4767-921e-8297f37f7a6e', None),\n",
       " ('fabhmerg_predict_rr0_split1.6f254508-14a6-48f9-b295-1a189f61c063', None),\n",
       " ('fabhmerg_predict_rr2_split1.3814e0b8-38a7-4db9-8544-6d7a18a29435', None),\n",
       " ('fabhmerg_predict_rr1_split2.356964d1-a644-429f-b82e-e7edfb3da1c0', None),\n",
       " ('fabhmerg_predict_rr2_split0.fb026f46-8192-4321-b1be-f3b54b26bf57', None),\n",
       " ('fabhmerg_predict_rr2_split2.c30218aa-ebdb-4cfa-94d2-a6067938a9ec', None)]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sampo.api import process_runner, process_store\n",
    "\n",
    "pstore_url = './parallel_pstore_cv'\n",
    "process_store.create(pstore_url)\n",
    "\n",
    "process_runner.session_run(process_list, pstore_url=pstore_url, max_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>process_name</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fabhmerg_predict_rr0_split0</td>\n",
       "      <td>4.532768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>fabhmerg_predict_rr0_split1</td>\n",
       "      <td>4.638185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>fabhmerg_predict_rr0_split2</td>\n",
       "      <td>5.321502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fabhmerg_predict_rr1_split0</td>\n",
       "      <td>4.156640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>fabhmerg_predict_rr1_split1</td>\n",
       "      <td>2.355842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>fabhmerg_predict_rr1_split2</td>\n",
       "      <td>5.546513</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>fabhmerg_predict_rr2_split0</td>\n",
       "      <td>3.592213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>fabhmerg_predict_rr2_split1</td>\n",
       "      <td>2.406080</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>fabhmerg_predict_rr2_split2</td>\n",
       "      <td>5.186844</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  process_name      rmse\n",
       "0  fabhmerg_predict_rr0_split0  4.532768\n",
       "1  fabhmerg_predict_rr0_split1  4.638185\n",
       "2  fabhmerg_predict_rr0_split2  5.321502\n",
       "3  fabhmerg_predict_rr1_split0  4.156640\n",
       "4  fabhmerg_predict_rr1_split1  2.355842\n",
       "5  fabhmerg_predict_rr1_split2  5.546513\n",
       "6  fabhmerg_predict_rr2_split0  3.592213\n",
       "7  fabhmerg_predict_rr2_split1  2.406080\n",
       "8  fabhmerg_predict_rr2_split2  5.186844"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import re\n",
    "import pandas as pd\n",
    "from sampo.api import process_store\n",
    "\n",
    "result = []\n",
    "predict_proc_names = [src.name for src, _ in process_list if re.match('fabhmerg_predict.*', src.name)]\n",
    "for predict_proc_name in predict_proc_names:\n",
    "    row = {}\n",
    "    with process_store.open_process(pstore_url, predict_proc_name) as prl:\n",
    "        evaluation = prl.load_comp_output_evaluation('rg')\n",
    "        row['process_name'] = predict_proc_name\n",
    "        row['rmse'] = evaluation['root_mean_squared_error'][0]\n",
    "        result.append(row)\n",
    "\n",
    "pd.DataFrame(result).sort_values('process_name')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
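    "The table above shows the cross-validation results (RMSE on the validation part) for each combination of random restart (rr) and fold (split). The cross-validation score of a random restart is the average RMSE over its three folds."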
   ]
  },
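  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of that averaging step, the RMSE values from the table above can be grouped by random restart with pandas. In the notebook itself, `pd.DataFrame(result)` from the previous cell can be used instead of the copied values; the `restart` column is a name introduced here for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# RMSE values copied from the cross-validation table above\n",
    "names = [f'fabhmerg_predict_rr{rr}_split{k}' for rr in range(3) for k in range(3)]\n",
    "cv_result = pd.DataFrame({\n",
    "    'process_name': names,\n",
    "    'rmse': [4.532768, 4.638185, 5.321502,\n",
    "             4.156640, 2.355842, 5.546513,\n",
    "             3.592213, 2.406080, 5.186844],\n",
    "})\n",
    "\n",
    "# 'fabhmerg_predict_rr0_split0' -> 'rr0'\n",
    "cv_result['restart'] = cv_result['process_name'].str.split('_').str[2]\n",
    "# Average RMSE over the three folds of each random restart, smallest first\n",
    "mean_rmse = cv_result.groupby('restart')['rmse'].mean().sort_values()\n",
    "print(mean_rmse)\n",
    "```\n",
    "\n",
    "With these values, rr2 has the smallest average RMSE (about 3.73), followed by rr1 (about 4.02) and rr0 (about 4.83)."
   ]
  },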
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Back to top](#top)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
