sklearn_fab.SklearnFABTree

class sklearn_fab.SklearnFABTree(model, feature_names)

A binary tree representation of the latent variable prior (the gate tree) of a model created by SklearnFABEstimator.

Parameters:
model : fab.hme.model.HMESupervisedModel

Model object from which the gate tree information is extracted.

Attributes:
node_count : int

Number of nodes in the SklearnFABTree.

num_gates : int

Number of gate nodes in the SklearnFABTree.

num_comps : int

Number of comp nodes in the SklearnFABTree.

max_depth : int

Maximum tree depth, i.e. the maximum depth of its leaves.

root_node : int

node_id of the root node of the SklearnFABTree.
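
The attributes above can be cross-checked against the output of to_dict(). The following is a minimal sketch, not part of sklearn_fab: tree_summary and TREE_DICT are hypothetical names, TREE_DICT is copied from the to_dict() example below, and the sketch assumes the root node has depth 0 (the scikit-learn convention for max_depth).

```python
from collections import deque
from typing import Any

# Hypothetical constant: output of estimator.tree_.to_dict() from the to_dict() example.
TREE_DICT: dict[str, Any] = {
    'node_info': {
        0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
            'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                          'threshold': 0.329, 'prob_left': 1.0}},
        1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
            'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                          'threshold': -1.962, 'prob_left': 0.0}},
        2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
        3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
        4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
    },
    'branch_info': {
        0: {'left_id': 1, 'right_id': 4},
        1: {'left_id': 2, 'right_id': 3},
    },
    'root_id': 0,
}


def tree_summary(tree_dict: dict[str, Any]) -> dict[str, int]:
    """Hypothetical helper: recompute the tree attributes from to_dict() output.

    Assumes the root node has depth 0, as in scikit-learn's tree_.max_depth.
    """
    node_info = tree_dict['node_info']
    branch_info = tree_dict['branch_info']
    root = tree_dict['root_id']

    # Breadth-first walk from the root, recording the depth of every node.
    depth = {root: 0}
    queue = deque([root])
    while queue:
        node_id = queue.popleft()
        children = branch_info.get(node_id)
        if children is None:  # comp nodes (leaves) have no branch_info entry
            continue
        for child in (children['left_id'], children['right_id']):
            depth[child] = depth[node_id] + 1
            queue.append(child)

    types = [info['type'] for info in node_info.values()]
    leaf_depths = [depth[n] for n, info in node_info.items()
                   if info['type'] == 'comp']
    return {
        'node_count': len(node_info),
        'num_gates': types.count('gate'),
        'num_comps': types.count('comp'),
        'max_depth': max(leaf_depths),
        'root_node': root,
    }


print(tree_summary(TREE_DICT))
# {'node_count': 5, 'num_gates': 2, 'num_comps': 3, 'max_depth': 2, 'root_node': 0}
```

Here node_count is the number of entries in node_info, and max_depth comes from the deepest comp nodes (nodes 2 and 3, two levels below the root).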

get_decision_path(comp_id)

Returns the list of node_ids on the traversal path from the root to the specified component.

Parameters:
comp_id : int

ID of the component whose traversal path from the root is returned.

Returns:
decision_path : list of int

node_ids traversed by samples assigned to the component, ordered from the root to the component’s direct parent gate node. The component node itself is not included.

Raises:
TypeError

If comp_id is not an int.

ValueError

If comp_id is not one of the component IDs in the tree.

Examples

>>> estimator.tree_.get_decision_path(12)
[0, 9, 10]
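
To make the path semantics concrete, here is a minimal sketch that rebuilds the same kind of path from to_dict() output by inverting branch_info into a child-to-parent map. It is not the library's implementation: decision_path and TREE_DICT are hypothetical names, and TREE_DICT is copied from the to_dict() example, which is a smaller tree than the one behind the [0, 9, 10] example above, so the paths differ. The TypeError and ValueError checks mirror the Raises section, but the messages are illustrative.

```python
from typing import Any

# Hypothetical constant: output of estimator.tree_.to_dict() from the to_dict() example.
TREE_DICT: dict[str, Any] = {
    'node_info': {
        0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
            'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                          'threshold': 0.329, 'prob_left': 1.0}},
        1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
            'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                          'threshold': -1.962, 'prob_left': 0.0}},
        2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
        3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
        4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
    },
    'branch_info': {
        0: {'left_id': 1, 'right_id': 4},
        1: {'left_id': 2, 'right_id': 3},
    },
    'root_id': 0,
}


def decision_path(tree_dict: dict[str, Any], comp_id: int) -> list[int]:
    """Hypothetical sketch of get_decision_path, working on to_dict() output."""
    if not isinstance(comp_id, int):
        raise TypeError(f'comp_id must be int, not {type(comp_id).__name__}')

    # Find the node_id of the comp node that holds the requested component.
    target = next((node_id for node_id, info in tree_dict['node_info'].items()
                   if info['type'] == 'comp' and info['comp_id'] == comp_id),
                  None)
    if target is None:
        raise ValueError(f'{comp_id} is not a comp_id of this tree')

    # Invert branch_info so that each child maps to its parent gate node.
    parent = {}
    for gate_id, children in tree_dict['branch_info'].items():
        parent[children['left_id']] = gate_id
        parent[children['right_id']] = gate_id

    # Climb from the component to the root, then reverse so the path is
    # root-first; the comp node itself is not included.
    path = []
    node_id = target
    while node_id in parent:
        node_id = parent[node_id]
        path.append(node_id)
    return path[::-1]


print(decision_path(TREE_DICT, 12))  # [0, 1]
print(decision_path(TREE_DICT, 24))  # [0]
```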

get_gate_functions()

Extracts the gate function of each gate node in the latent variable prior of the model.

Returns:
gate_func : dict

Dictionary mapping each gate node’s node_id to that node’s GateFunction object. For details of the GateFunction object, refer to FAB Reference > API Documents.

Examples

Display information about each gate function:

>>> for node_id, gate_func in estimator.tree_.get_gate_functions().items():
...     print(f'nodeID:{node_id}, '
...           f'feature:{estimator.feature_names_[gate_func.feature_id]}, '
...           f'threshold:{gate_func.threshold}, prob_left:{gate_func.prob_left}')
nodeID:0, feature:X[2], threshold:0.329, prob_left:1.0
nodeID:1, feature:X[0], threshold:-1.962, prob_left:0.0
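
The same gate information can also be read from to_dict(), whose gate_func entries carry the feature, threshold, and prob_left values shown above. The following minimal sketch assumes that correspondence holds in general; gate_lines and TREE_DICT are hypothetical names, with TREE_DICT copied from the to_dict() example.

```python
from typing import Any

# Hypothetical constant: output of estimator.tree_.to_dict() from the to_dict() example.
TREE_DICT: dict[str, Any] = {
    'node_info': {
        0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
            'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                          'threshold': 0.329, 'prob_left': 1.0}},
        1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
            'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                          'threshold': -1.962, 'prob_left': 0.0}},
        2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
        3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
        4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
    },
    'branch_info': {
        0: {'left_id': 1, 'right_id': 4},
        1: {'left_id': 2, 'right_id': 3},
    },
    'root_id': 0,
}


def gate_lines(tree_dict: dict[str, Any]) -> list[str]:
    """Hypothetical helper: format each gate's gate_func entry from to_dict()."""
    lines = []
    # Visit nodes in node_id order; comp nodes carry no gate_func and are skipped.
    for node_id, info in sorted(tree_dict['node_info'].items()):
        if info['type'] != 'gate':
            continue
        func = info['gate_func']
        lines.append(f"nodeID:{node_id}, feature:{func['feature']}, "
                     f"threshold:{func['threshold']}, prob_left:{func['prob_left']}")
    return lines


for line in gate_lines(TREE_DICT):
    print(line)
```

This prints one line per gate node in the same layout as the loop over get_gate_functions() above.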

to_dict()

Returns a dictionary summarizing the latent variable prior of the model.

Returns:
tree_dict : dict

Node and branching information of the latent variable prior derived from the model.

The output dictionary has the following keys and values:

  • node_info : dict

    Dictionary describing each node in the tree. Each key is the ID assigned to a node, called its node_id, and each value is a dictionary describing that node.

    If the node is a ‘gate’, the dictionary contains the following keys:

    • node_id : int

      ID number assigned to the node.

    • type : str

      Type of the node, which is ‘gate’.

    • gate_index : int

      Index number assigned to each gate in the tree.

    • gate_func : dict

      Dictionary describing the gating function: the feature used for the decision and the related coefficients and constants. In the example below, it holds the keys feature, feature_id, threshold, and prob_left.

    If the node is a ‘comp’, the dictionary contains the following keys:

    • node_id : int

      ID number assigned to the node.

    • type : str

      Type of the node, which is ‘comp’.

    • comp_id : int

      ID of the component at this node.

  • branch_info : dict

    Dictionary describing the children of each gate node. Each key is a gate node’s node_id, and each value is a dictionary describing that gate’s branching.

    The dictionary contains the following keys:

    • left_id : int

      node_id of the left child, as used in node_info.

    • right_id : int

      node_id of the right child, as used in node_info.

  • root_id : int

    node_id of the root node.

Examples

>>> estimator.tree_.to_dict()
{'node_info': {
  0: {'type': 'gate',
   'node_id': 0,
   'gate_index': 0,
   'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                 'threshold': 0.329, 'prob_left': 1.0}},
  1: {'type': 'gate',
   'node_id': 1,
   'gate_index': 1,
   'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                 'threshold': -1.962, 'prob_left': 0.0}},
  2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
  3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
  4: {'type': 'comp', 'node_id': 4, 'comp_id': 24}},
 'branch_info': {
  0: {'left_id': 1, 'right_id': 4},
  1: {'left_id': 2, 'right_id': 3}},
 'root_id': 0}
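
Because node_info and branch_info are flat, the nesting of the tree is only implicit in this output. A recursive walk from root_id that follows left_id before right_id makes it visible. This is a minimal sketch: render_tree and TREE_DICT are hypothetical names, TREE_DICT is a copy of the output above, and the indented layout is purely a display choice.

```python
from typing import Any

# Hypothetical constant: a copy of the to_dict() output shown above.
TREE_DICT: dict[str, Any] = {
    'node_info': {
        0: {'type': 'gate', 'node_id': 0, 'gate_index': 0,
            'gate_func': {'feature': 'X[2]', 'feature_id': 2,
                          'threshold': 0.329, 'prob_left': 1.0}},
        1: {'type': 'gate', 'node_id': 1, 'gate_index': 1,
            'gate_func': {'feature': 'X[0]', 'feature_id': 0,
                          'threshold': -1.962, 'prob_left': 0.0}},
        2: {'type': 'comp', 'node_id': 2, 'comp_id': 12},
        3: {'type': 'comp', 'node_id': 3, 'comp_id': 20},
        4: {'type': 'comp', 'node_id': 4, 'comp_id': 24},
    },
    'branch_info': {
        0: {'left_id': 1, 'right_id': 4},
        1: {'left_id': 2, 'right_id': 3},
    },
    'root_id': 0,
}


def render_tree(tree_dict: dict[str, Any]) -> list[str]:
    """Hypothetical helper: pre-order, indented view of to_dict() output."""
    node_info = tree_dict['node_info']
    branch_info = tree_dict['branch_info']
    lines: list[str] = []

    def walk(node_id: int, depth: int) -> None:
        info = node_info[node_id]
        indent = '  ' * depth
        if info['type'] == 'comp':
            lines.append(f"{indent}comp node {node_id}: comp_id={info['comp_id']}")
            return
        func = info['gate_func']
        lines.append(f"{indent}gate node {node_id}: {func['feature']} "
                     f"(threshold={func['threshold']})")
        # Recurse into the left child first, then the right child.
        children = branch_info[node_id]
        walk(children['left_id'], depth + 1)
        walk(children['right_id'], depth + 1)

    walk(tree_dict['root_id'], 0)
    return lines


print('\n'.join(render_tree(TREE_DICT)))
# gate node 0: X[2] (threshold=0.329)
#   gate node 1: X[0] (threshold=-1.962)
#     comp node 2: comp_id=12
#     comp node 3: comp_id=20
#   comp node 4: comp_id=24
```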