sklearn_fab.SklearnFABTree
- class sklearn_fab.SklearnFABTree(model, feature_names)
A binary tree representation of the latent variable prior (the gate tree) of a model created by SklearnFABEstimator.
- Parameters:
- model : fab.hme.model.HMESupervisedModel
Model object from which the gate tree information is extracted.
- Attributes:
- node_count : int
Number of nodes in the SklearnFABTree.
- num_gates : int
Number of gate nodes in the SklearnFABTree.
- num_comps : int
Number of component (comp) nodes in the SklearnFABTree.
- max_depth : int
Maximum tree depth, i.e. the maximum depth of its leaves.
- root_node : int
node_id of the root node of the SklearnFABTree.
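These attributes are related when every gate has exactly two children (left_id and right_id in to_dict()) and every comp is a leaf. In that case node_count equals num_gates + num_comps, and num_comps equals num_gates + 1. The following minimal sketch recomputes the attributes from the branch_info and root_id entries of the to_dict() example output. tree_stats is a hypothetical helper, not part of sklearn_fab. It assumes the root sits at depth 0, which this page does not state explicitly.

```python
from collections import deque

# Hypothetical data: 'branch_info' and 'root_id' copied from the to_dict()
# example output of SklearnFABTree (gate node_id -> its two children).
BRANCH_INFO = {0: {'left_id': 1, 'right_id': 4},
               1: {'left_id': 2, 'right_id': 3}}
ROOT_ID = 0


def tree_stats(branch_info, root_id):
    """Recompute node/gate/comp counts and max depth from branch_info.

    Assumes every gate has two children, every non-gate node is a comp
    (leaf), and the root has depth 0.
    """
    gates = set(branch_info)
    nodes = {root_id}
    for children in branch_info.values():
        nodes.update((children['left_id'], children['right_id']))
    comps = nodes - gates

    # Breadth-first walk from the root; record the depth of every leaf.
    max_depth = 0
    queue = deque([(root_id, 0)])
    while queue:
        node_id, depth = queue.popleft()
        if node_id in branch_info:
            children = branch_info[node_id]
            queue.append((children['left_id'], depth + 1))
            queue.append((children['right_id'], depth + 1))
        else:
            max_depth = max(max_depth, depth)

    return {'node_count': len(nodes), 'num_gates': len(gates),
            'num_comps': len(comps), 'max_depth': max_depth,
            'root_node': root_id}


stats = tree_stats(BRANCH_INFO, ROOT_ID)
print(stats)
# A full binary tree has exactly one more leaf than internal nodes.
assert stats['num_comps'] == stats['num_gates'] + 1
```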
- get_decision_path(comp_id)
Returns the list of node_ids traversed from the root to the specified component.
- Parameters:
- comp_id : int
The ID of the component whose traversal path from the root is returned.
- Returns:
- decision_path : list of int
The node_ids of the gate nodes that a sample passes through to reach the component, ordered from the root to the component's direct parent gate node.
- Raises:
- TypeError
comp_id is not an int.
- ValueError
comp_id is not one of the component IDs in the tree.
Examples
>>> estimator.tree_.get_decision_path(12)
[0, 9, 10]
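The following minimal sketch shows what the decision path contains. It recomputes the path from the node_info and branch_info entries documented under to_dict(), using the small tree from the to_dict() example. For that tree the expected path for comp 12 is [0, 1], not the [0, 9, 10] shown above for a different, larger tree. The decision_path helper is hypothetical and is not the library's implementation. It only mirrors the documented return value and exceptions.

```python
# Hypothetical data: a subset of the to_dict() example output of SklearnFABTree.
NODE_INFO = {
    0: {'type': 'gate', 'node_id': 0},
    1: {'type': 'gate', 'node_id': 1},
    2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
    3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
    4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
}
BRANCH_INFO = {0: {'left_id': 1, 'right_id': 4},
               1: {'left_id': 2, 'right_id': 3}}


def decision_path(comp_id, node_info, branch_info):
    """Return gate node_ids from the root to comp_id's direct parent gate.

    Hypothetical re-implementation for illustration only; it raises the same
    exception types that get_decision_path() documents.
    """
    if not isinstance(comp_id, int):
        raise TypeError(f'comp_id must be int, got {type(comp_id).__name__}')
    comp_nodes = {info['comp_id']: node_id
                  for node_id, info in node_info.items()
                  if info['type'] == 'comp'}
    if comp_id not in comp_nodes:
        raise ValueError(f'{comp_id} is not a valid comp_id')

    # Invert branch_info into a child -> parent map.
    parent = {}
    for gate_id, children in branch_info.items():
        parent[children['left_id']] = gate_id
        parent[children['right_id']] = gate_id

    # Walk upward from the comp node, collecting each ancestor gate.
    path = []
    node_id = comp_nodes[comp_id]
    while node_id in parent:
        node_id = parent[node_id]
        path.append(node_id)
    return path[::-1]  # root first


print(decision_path(12, NODE_INFO, BRANCH_INFO))  # [0, 1]
print(decision_path(24, NODE_INFO, BRANCH_INFO))  # [0]
```

Building the parent map once and walking upward avoids searching every root-to-leaf path, so each lookup costs at most the depth of the tree.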
- get_gate_functions()
Extracts gate function information from each gate node in the latent variable prior of the model.
- Returns:
- gate_func : dict
A dictionary that maps the node_id of each gate node to that node's GateFunction object. Refer to FAB Reference > API Documents for more information about the GateFunction object.
Examples
Display the information of each gate function:
>>> for node_id, gate_func in estimator.tree_.get_gate_functions().items():
...     print(f'nodeID:{node_id}, '
...           f'feature:{estimator.feature_names_[gate_func.feature_id]}, '
...           f'threshold:{gate_func.threshold}, prob_left:{gate_func.prob_left}')
nodeID:0, feature:'X[2]', threshold:0.329, prob_left:1.0
nodeID:1, feature:'X[0]', threshold:-1.962, prob_left:0.0
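The example above uses only the feature_id, threshold and prob_left attributes of GateFunction. The minimal sketch below substitutes a hypothetical GateFunctionStub dataclass with just those attributes. It also uses a hypothetical feature-name list standing in for estimator.feature_names_. The sketch groups each gate's threshold by the feature it uses, which shows at a glance which features the latent variable prior splits on.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class GateFunctionStub:
    """Hypothetical stand-in for GateFunction; only the attributes used here."""
    feature_id: int
    threshold: float
    prob_left: float


def thresholds_by_feature(gate_funcs, feature_names):
    """Group (node_id, threshold) pairs by the feature name each gate uses."""
    grouped = defaultdict(list)
    for node_id, gate_func in sorted(gate_funcs.items()):
        name = feature_names[gate_func.feature_id]
        grouped[name].append((node_id, gate_func.threshold))
    return dict(grouped)


# Values copied from the get_gate_functions() example output.
gate_funcs = {0: GateFunctionStub(feature_id=2, threshold=0.329, prob_left=1.0),
              1: GateFunctionStub(feature_id=0, threshold=-1.962, prob_left=0.0)}
feature_names = ['X[0]', 'X[1]', 'X[2]']  # hypothetical feature_names_

print(thresholds_by_feature(gate_funcs, feature_names))
# {'X[2]': [(0, 0.329)], 'X[0]': [(1, -1.962)]}
```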
- to_dict()
Returns a dictionary summarizing the latent variable prior (the gate tree) of the model.
- Returns:
- tree_dict : dict
Node and branching information of the latent variable prior derived from the model.
The output dictionary has the following keys and values:
- node_info : dict
Dictionary describing each node in the tree. Each key is the node's ID (node_id), and each value is a dictionary describing that node.
If the node is a ‘gate’, the dictionary contains the following keys:
- node_id : int
ID number assigned to the node.
- type : str
Node type, which is 'gate'.
- gate_index : int
Index number assigned to the gate.
- gate_func : dict
Dictionary describing the gating function: the feature used for the decision and the related coefficients and constants (in the example below: feature, feature_id, threshold and prob_left).
If the node is a ‘comp’, the dictionary contains the following keys:
- node_id : int
ID number assigned to the node.
- type : str
Node type, which is 'comp'.
- comp_id : int
Component ID.
- branch_info : dict
Dictionary describing the children of each gate node. Each key is the node_id of a gate node, and each value is a dictionary describing that gate's branching.
The dictionary contains the following keys:
- left_id : int
Left child's node_id as described in the node_info dict.
- right_id : int
Right child's node_id as described in the node_info dict.
- root_id : int
node_id of the root node.
Examples
>>> estimator.tree_.to_dict()
{'node_info': {
    0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
        'gate_func': {'feature': 'X[2]', 'feature_id': 2, 'threshold': 0.329, 'prob_left': 1.0}},
    1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
        'gate_func': {'feature': 'X[0]', 'feature_id': 0, 'threshold': -1.962, 'prob_left': 0.0}},
    2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
    3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
    4: {'type': 'comp', 'node_id': 4, 'comp_id': 24}},
 'branch_info': {
    0: {'left_id': 1, 'right_id': 4},
    1: {'left_id': 2, 'right_id': 3}},
 'root_id': 0}
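to_dict() returns plain Python data, so its output can be processed without the library. The following minimal sketch hard-codes the example output above and walks branch_info from root_id, printing an indented, pre-order text view of the tree. The render helper and its output format are hypothetical and for illustration only. The labels "left" and "right" simply name left_id and right_id; they make no claim about which samples the library routes each way.

```python
# Hypothetical data: the to_dict() example output, hard-coded.
TREE = {
    'node_info': {
        0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
            'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                          'threshold': 0.329, 'prob_left': 1.0}},
        1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
            'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                          'threshold': -1.962, 'prob_left': 0.0}},
        2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
        3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
        4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
    },
    'branch_info': {0: {'left_id': 1, 'right_id': 4},
                    1: {'left_id': 2, 'right_id': 3}},
    'root_id': 0,
}


def render(tree):
    """Return an indented, pre-order text view of a to_dict()-style tree."""
    node_info, branch_info = tree['node_info'], tree['branch_info']
    lines = []

    def visit(node_id, depth, label):
        info = node_info[node_id]
        indent = '    ' * depth
        if info['type'] == 'gate':
            gf = info['gate_func']
            lines.append(f"{indent}{label}gate {info['gate_index']}: "
                         f"{gf['feature']} threshold={gf['threshold']}")
            children = branch_info[node_id]
            visit(children['left_id'], depth + 1, 'left -> ')
            visit(children['right_id'], depth + 1, 'right -> ')
        else:
            lines.append(f"{indent}{label}comp {info['comp_id']}")

    visit(tree['root_id'], 0, '')
    return '\n'.join(lines)


print(render(TREE))
# gate 0: X[2] threshold=0.329
#     left -> gate 1: X[0] threshold=-1.962
#         left -> comp 12
#         right -> comp 20
#     right -> comp 24
```

Recursion keeps the sketch short. For a very deep tree, an explicit stack would avoid Python's recursion limit.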