SAMPO/FAB

データ準備

まず、学習用のデータを確認します。

本ノートブックでは、中古車の販売価格を示すデータを使用します。

データの詳細については、以下をご覧ください。

Automobile Data Set

[1]:
import pandas as pd

train_data = pd.read_csv('../data/train_data.csv')
train_data.insert(0, '_sid', list(range(train_data.shape[0])))
train_data.head()
[1]:
_sid price symboling normalized-losses num-of-doors wheel-base length width height
0 0 12964 3 145.0 two 95.9 173.2 66.3 50.2
1 1 10198 0 89.0 four 97.0 173.5 65.4 53.0
2 2 11245 0 115.0 four 98.8 177.8 66.5 55.5
3 3 14399 0 108.0 four 100.4 184.6 66.5 56.1
4 4 25552 -1 93.0 four 110.0 190.9 70.3 56.5

上記のデータについて、以下の前処理を実施します。

  • num-of-doorsはカテゴリ変数のため、二値展開

  • num-of-doorsを除く変数は数値データであるため、標準化

これらの前処理を、後ほどSPDで定義します。

次に、ASDを作成します。

[2]:
from sampotools.api import gen_asd_from_pandas_df

asd = gen_asd_from_pandas_df(train_data)
pd.DataFrame(asd).T[['scale', 'domain']]
[2]:
scale domain
_sid INTEGER NaN
price INTEGER NaN
symboling INTEGER NaN
normalized-losses REAL NaN
num-of-doors NOMINAL [four, two]
wheel-base REAL NaN
length REAL NaN
width REAL NaN
height REAL NaN

最後に、予測用データを読み込みます。

[3]:
predict_data = pd.read_csv('../data/predict_data.csv')
predict_data.insert(0, '_sid', list(range(predict_data.shape[0])))
predict_data.head()
[3]:
_sid price symboling normalized-losses num-of-doors wheel-base length width height
0 0 6338 1 87.0 two 95.7 158.7 63.6 54.5
1 1 18150 3 150.0 two 99.1 186.6 66.5 56.1
2 2 17669 2 134.0 two 98.4 176.2 65.6 53.0
3 3 11549 2 134.0 two 98.4 176.2 65.6 52.0
4 4 32250 0 145.0 four 113.0 199.6 69.6 52.8

分析手順の設計

目的変数priceは連続値であるため、回帰用のcomponentである「FABHMEBernGateLinearRg Component」を選択します。

学習を行う際に、以下のパラメーターを指定します。

seed

  • componentで使用される乱数シードです。

tree_depth

  • 門木の深さを設定するパラメーターです。

shrink_threshold

  • 門木の枝刈りの閾値を設定するパラメーターです。

seed, tree_depth, shrink_thresholdの詳細は、「SAMPO Reference V1.6.2」をご覧ください。

SPD定義

[4]:
from sampo.api import gen_spd

# seedの指定
seed = 0

# SPD の定義
spd_content = '''
dl -> std -> rg
   -> bexp -> rg

---
components:
    dl:
        component: DataLoader

    std:
        component: StandardizeFDComponent
        features: scale == 'real' or scale == 'integer'

    bexp:
        component: BinaryExpandFDComponent
        features: scale == 'nominal'

    rg:
        component: FABHMEBernGateLinearRgComponent
        features: name != 'price'
        target: name == 'price'
        standardize_target: True
        tree_depth: 3
        shrink_threshold: 2.0

global_settings:
    keep_attributes:
        - price
    feature_exclude:
        - price
'''

spd = gen_spd(template=spd_content)

SRC定義

[5]:
from sampo.api import gen_src

train_src_temp = '''
train:
    type: learn
    data_sources:
        dl:
            df: {{ data_df }}
            attr_schema: {{ asd }}
'''

predict_src_temp = '''
predict:
    type: predict
    data_sources:
        dl:
            df: {{ data_df }}
            attr_schema: {{ asd }}
    model_process: {{ model_process }}
'''

学習、予測

学習を実行します。

[6]:
from sampo.api import process_store, process_runner

train_src = gen_src(template=train_src_temp, params={'data_df': train_data, 'asd': asd})

pstore_url = 'pstore_rg'
!rm -rf $pstore_url
process_store.create(pstore_url)

process_runner.run(spd=spd, src=train_src, pstore_url=pstore_url, seed=seed)
process_store.list_process_metadata(pstore_url)
[6]:
process name version started at running time status
0 train 4a53449a-801c-415f-a9e4-a2e9b47c4979 2020-03-05 13:53:44.265040 00:00:01.261934 Succeeded

予測を実行します。

[7]:
train_proc_name = 'train'
predict_src = gen_src(template=predict_src_temp, params={'model_process': train_proc_name,
                                                         'data_df': predict_data, 'asd': asd})
process_runner.run(src=predict_src, pstore_url=pstore_url)
process_store.list_process_metadata(pstore_url)
[7]:
process name version started at running time status
0 predict b11be9f7-5ef6-4cd9-994d-6f661864f931 2020-03-05 13:53:45.625547 00:00:00.736196 Succeeded
1 train 4a53449a-801c-415f-a9e4-a2e9b47c4979 2020-03-05 13:53:44.265040 00:00:01.261934 Succeeded

モデルバリデーション

予測精度の確認

作成したモデルを評価するため、予測精度を確認します。

評価指標として、RMSEを選定します。

[8]:
predict_proc_name = 'predict'
with process_store.open_process(pstore_url, predict_proc_name) as prl:
    prl = prl
prl.load_comp_output_evaluation('rg')['std_root_mean_squared_error'][0]
[8]:
0.7621801464683062

門木の可視化

モデルがもつ門木を可視化します。

[9]:
from IPython.display import display, Image
from sampovis.api.fabhme_vis import save_gate_tree

save_gate_tree(pstore_url, train_proc_name, 'output')
display(Image('output/{}_rg_fabhme_gate_tree.png'.format(train_proc_name)))
../../_images/section_3_sampo_fab_sampo_fab_26_0.png

予測式の可視化

モデルがもつ予測式情報にアクセスし、可視化します。

[10]:
import matplotlib.pyplot as plt
%matplotlib inline

model_params = prl.load_model('rg')
prediction_formulas = model_params['prediction_formulas']
relevant_feature_indices = prediction_formulas.sum(axis=1) != 0
prediction_formulas = prediction_formulas[relevant_feature_indices]
prediction_formulas.plot(kind='barh', figsize=(8, 4), stacked=True)
plt.show()
../../_images/section_3_sampo_fab_sampo_fab_29_0.png