SAMPO/FAB¶
データ準備¶
まず、学習用のデータを確認します。
本ノートブックでは、中古車の販売価格を示すデータを使用します。
データの詳細については、以下をご覧ください。
[1]:
import pandas as pd
train_data = pd.read_csv('../data/train_data.csv')
train_data.insert(0, '_sid', list(range(train_data.shape[0])))
train_data.head()
[1]:
_sid | price | symboling | normalized-losses | num-of-doors | wheel-base | length | width | height | |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 12964 | 3 | 145.0 | two | 95.9 | 173.2 | 66.3 | 50.2 |
1 | 1 | 10198 | 0 | 89.0 | four | 97.0 | 173.5 | 65.4 | 53.0 |
2 | 2 | 11245 | 0 | 115.0 | four | 98.8 | 177.8 | 66.5 | 55.5 |
3 | 3 | 14399 | 0 | 108.0 | four | 100.4 | 184.6 | 66.5 | 56.1 |
4 | 4 | 25552 | -1 | 93.0 | four | 110.0 | 190.9 | 70.3 | 56.5 |
上記のデータについて、以下の前処理を実施します。
num-of-doorsはカテゴリ変数のため、二値展開
num-of-doorsを除く変数は数値データであるため、標準化
これらの前処理を、後ほどSPDで定義します。
次に、ASDを作成します。
[2]:
from sampotools.api import gen_asd_from_pandas_df
asd = gen_asd_from_pandas_df(train_data)
pd.DataFrame(asd).T[['scale', 'domain']]
[2]:
scale | domain | |
---|---|---|
_sid | INTEGER | NaN |
price | INTEGER | NaN |
symboling | INTEGER | NaN |
normalized-losses | REAL | NaN |
num-of-doors | NOMINAL | [four, two] |
wheel-base | REAL | NaN |
length | REAL | NaN |
width | REAL | NaN |
height | REAL | NaN |
最後に、予測用データを読み込みます。
[3]:
predict_data = pd.read_csv('../data/predict_data.csv')
predict_data.insert(0, '_sid', list(range(predict_data.shape[0])))
predict_data.head()
[3]:
_sid | price | symboling | normalized-losses | num-of-doors | wheel-base | length | width | height | |
---|---|---|---|---|---|---|---|---|---|
0 | 0 | 6338 | 1 | 87.0 | two | 95.7 | 158.7 | 63.6 | 54.5 |
1 | 1 | 18150 | 3 | 150.0 | two | 99.1 | 186.6 | 66.5 | 56.1 |
2 | 2 | 17669 | 2 | 134.0 | two | 98.4 | 176.2 | 65.6 | 53.0 |
3 | 3 | 11549 | 2 | 134.0 | two | 98.4 | 176.2 | 65.6 | 52.0 |
4 | 4 | 32250 | 0 | 145.0 | four | 113.0 | 199.6 | 69.6 | 52.8 |
分析手順の設計¶
目的変数priceは連続値であるため、回帰用のcomponentである「FABHMEBernGateLinearRg Component」を選択します。
学習を行う際に、以下のパラメーターを指定します。
seed
componentで使用される乱数シードです。
tree_depth
門木の深さを設定するパラメーターです。
shrink_threshold
門木の枝刈りの閾値を設定するパラメーターです。
seed, tree_depth, shrink_thresholdの詳細は、「SAMPO Reference V1.6.2」をご覧ください。
SPD定義¶
[4]:
from sampo.api import gen_spd
# seedの指定
seed = 0
# SPD の定義
spd_content = '''
dl -> std -> rg
-> bexp -> rg
---
components:
dl:
component: DataLoader
std:
component: StandardizeFDComponent
features: scale == 'real' or scale == 'integer'
bexp:
component: BinaryExpandFDComponent
features: scale == 'nominal'
rg:
component: FABHMEBernGateLinearRgComponent
features: name != 'price'
target: name == 'price'
standardize_target: True
tree_depth: 3
shrink_threshold: 2.0
global_settings:
keep_attributes:
- price
feature_exclude:
- price
'''
spd = gen_spd(template=spd_content)
SRC定義¶
[5]:
from sampo.api import gen_src
train_src_temp = '''
train:
type: learn
data_sources:
dl:
df: {{ data_df }}
attr_schema: {{ asd }}
'''
predict_src_temp = '''
predict:
type: predict
data_sources:
dl:
df: {{ data_df }}
attr_schema: {{ asd }}
model_process: {{ model_process }}
'''
学習、予測¶
学習を実行します。
[6]:
from sampo.api import process_store, process_runner
train_src = gen_src(template=train_src_temp, params={'data_df': train_data, 'asd': asd})
pstore_url = 'pstore_rg'
!rm -rf $pstore_url
process_store.create(pstore_url)
process_runner.run(spd=spd, src=train_src, pstore_url=pstore_url, seed=seed)
process_store.list_process_metadata(pstore_url)
[6]:
process name | version | started at | running time | status | |
---|---|---|---|---|---|
0 | train | 4a53449a-801c-415f-a9e4-a2e9b47c4979 | 2020-03-05 13:53:44.265040 | 00:00:01.261934 | Succeeded |
予測を実行します。
[7]:
train_proc_name = 'train'
predict_src = gen_src(template=predict_src_temp, params={'model_process': train_proc_name,
'data_df': predict_data, 'asd': asd})
process_runner.run(src=predict_src, pstore_url=pstore_url)
process_store.list_process_metadata(pstore_url)
[7]:
process name | version | started at | running time | status | |
---|---|---|---|---|---|
0 | predict | b11be9f7-5ef6-4cd9-994d-6f661864f931 | 2020-03-05 13:53:45.625547 | 00:00:00.736196 | Succeeded |
1 | train | 4a53449a-801c-415f-a9e4-a2e9b47c4979 | 2020-03-05 13:53:44.265040 | 00:00:01.261934 | Succeeded |
モデルバリデーション¶
予測精度の確認¶
作成したモデルを評価するため、予測精度を確認します。
評価指標として、RMSEを選定します。
[8]:
predict_proc_name = 'predict'
with process_store.open_process(pstore_url, predict_proc_name) as prl:
prl = prl
prl.load_comp_output_evaluation('rg')['std_root_mean_squared_error'][0]
[8]:
0.7621801464683062
門木の可視化¶
モデルがもつ門木を可視化します。
[9]:
from IPython.display import display, Image
from sampovis.api.fabhme_vis import save_gate_tree
save_gate_tree(pstore_url, train_proc_name, 'output')
display(Image('output/{}_rg_fabhme_gate_tree.png'.format(train_proc_name)))

予測式の可視化¶
モデルがもつ予測式情報にアクセスし、可視化します。
[10]:
import matplotlib.pyplot as plt
%matplotlib inline
model_params = prl.load_model('rg')
prediction_formulas = model_params['prediction_formulas']
relevant_feature_indices = prediction_formulas.sum(axis=1) != 0
prediction_formulas = prediction_formulas[relevant_feature_indices]
prediction_formulas.plot(kind='barh', figsize=(8, 4), stacked=True)
plt.show()
