The main purpose of this package is to serve as a library of decision tree classifiers and regressors, both those learned from supervised data and those learned to explain other classifiers.
What we aim to do here is create a general framework and a collection of ingredients that go into learning a decision tree. These ingredients can then be combined to create an object that performs the learning/explanation task.
The following outlines how we abstract the task in general terms.
A decision tree is a tree whose internal (non-leaf) nodes correspond to splits and whose leaves correspond to predictors.
A split is an object that, given a data matrix of instances, determines which branch each instance is routed to.
Each leaf node must be able to, given a data matrix of instances, produce a prediction for each instance.
For a data matrix of instances, the tree as a whole produces predictions by routing each instance through its splits to a leaf and applying that leaf's predictor.
Status: Not implemented
In some applications it may be of interest to query the decision path for a particular instance.
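The two abstractions above can be sketched as minimal Python protocols. This is an illustrative sketch only; the class names (`Split`, `LeafPredictor`, `ThresholdSplit`) are assumptions, not the package's actual API.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class Split(ABC):
    """Given a data matrix, assigns each instance to a branch."""

    n_branches: int

    @abstractmethod
    def pick_branches(self, data: Sequence[Sequence[float]]) -> Sequence[int]:
        """Return, for each row of `data`, the index of the branch it follows."""


class LeafPredictor(ABC):
    """Produces a prediction for each instance in a data matrix."""

    @abstractmethod
    def predict(self, data: Sequence[Sequence[float]]) -> Sequence:
        ...


# Example concrete split: threshold on a single feature (two branches).
class ThresholdSplit(Split):
    def __init__(self, feature: int, threshold: float):
        self.feature = feature
        self.threshold = threshold
        self.n_branches = 2

    def pick_branches(self, data):
        return [0 if row[self.feature] <= self.threshold else 1 for row in data]
```

With this decomposition, learning algorithms only need to produce objects satisfying these two interfaces; the tree itself stays agnostic to how they were fit.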
Abstractly, to learn a decision tree one needs to construct a tree object given data (or, in the case of an explanation, a model) to fit. This entails specifying a fitter:
- Supervised learning fitter: takes data and targets
- Type 1 explanation fitter: takes data, a predictor, and a recipe for generating unlabeled data based on given data
- Type 2 explanation fitter: takes a predictor and an unlabeled data generator. Status: Not implemented
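A type-1 explanation fitter can be reduced to supervised fitting by labeling generated data with the model being explained. The following is a minimal sketch under that assumption; all function names here are hypothetical, not the package's actual API.

```python
import random


def perturb_generator(data, n_samples, scale=0.1, rng=None):
    """Generate unlabeled samples by jittering randomly chosen rows of data."""
    rng = rng or random.Random(0)
    samples = []
    for _ in range(n_samples):
        row = rng.choice(data)
        samples.append([x + rng.gauss(0.0, scale) for x in row])
    return samples


def fit_type1_explanation(data, predictor, generator, supervised_fitter):
    # Generate unlabeled data in the neighborhood of the given data,
    # label it by querying the black-box predictor, then fit as usual.
    unlabeled = generator(data, n_samples=100)
    targets = [predictor(row) for row in unlabeled]
    return supervised_fitter(unlabeled, targets)
```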
The fitter must then call a builder that builds the tree given the provided information. The most common build strategy is a greedy building of the tree, followed (optionally) by pruning of the tree.
The greedy tree building process is as follows:
- Initialize a root node
- Push the node into a queue (or any sequential access data structure such as a stack or heap).
- While the queue is not empty and global stopping criteria aren't met:
- Pop a node from the queue
- Find the locally best split at this node
- If there is no good split or if local stopping criteria are met, make this node a leaf
- Else:
- Set the identified split to be the split at the node.
- Initialize children based on the split (one for each split branch)
- Push the children into the queue.
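The loop above can be sketched with the variable ingredients (queue type, split search, leaf construction, stopping criteria) passed in as plain callables. This is an illustrative sketch, not the package's actual implementation; all names are assumptions.

```python
from collections import deque


class Node:
    def __init__(self, data, targets):
        self.data, self.targets = data, targets
        self.split = None          # set if this becomes an internal node
        self.children = []
        self.leaf_model = None     # set if this becomes a leaf


def build_greedy(data, targets, find_best_split, make_leaf,
                 local_stop, global_stop):
    root = Node(data, targets)
    queue = deque([root])          # any sequential-access structure works here
    while queue and not global_stop():
        node = queue.popleft()
        split = find_best_split(node)
        if split is None or local_stop(node):
            node.leaf_model = make_leaf(node)  # no good split: make a leaf
        else:
            node.split = split
            # partition (data, targets) into one subset per branch
            for branch_data, branch_targets in split(node):
                child = Node(branch_data, branch_targets)
                node.children.append(child)
                queue.append(child)
    # nodes left unexpanded when stopping globally become leaves
    for node in queue:
        node.leaf_model = make_leaf(node)
    return root
```

Swapping the `deque` for a stack or priority queue changes the expansion order (breadth-first, depth-first, or best-first) without touching the rest of the loop.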
The elements that vary in the building process from algorithm to algorithm are:
- The sequential access data structure
- The algorithm for finding the locally best split
- The algorithm for building a leaf node
- The local stopping criteria
- The global stopping criteria
TODO
For convenience we provide several 'recipes' for standard learners, such as a classic decision tree learner and Trepan.
These are found in `generalizedtrees.recipes`, and the source code therein can be used as an example for composing learners.
Initially we hoped to make the built estimators pass scikit-learn's `sklearn.utils.estimator_checks.check_estimator` check, but this proved challenging with composed models.
Instead, we aim for a consistent API in terms of method names (e.g. `fit`, `predict`, `predict_proba`), which should suffice in many cases for working alongside scikit-learn code.
Future consideration: We will revisit the feasibility of inheriting from the `sklearn.base.BaseEstimator` base class and the `sklearn.base.ClassifierMixin` (`sklearn.base.RegressorMixin`) mixin.
For completeness we include notes on what one would need to do to pass sklearn checks (not currently the case for our estimators).
To be able to pass the scikit-learn checks, an estimator must satisfy the following:

- All arguments of `__init__` must have a default value.
- All arguments of `__init__` must correspond to class attributes; `__init__` cannot set attributes that do not correspond to its arguments.
- The second parameter of `fit` must be `y` or `Y`.
- See the following subsection for a description of what is considered to be part of the scikit-learn interface.
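The rules above can be illustrated with a minimal estimator skeleton. This is a hypothetical sketch (the class name and parameters are invented for illustration): every `__init__` argument has a default, is stored verbatim as an attribute, nothing else is set in `__init__`, and `fit`'s second parameter is named `y`.

```python
class SketchTreeClassifier:
    def __init__(self, max_depth=None, criterion="entropy"):
        # store arguments as-is; no derived attributes are set here
        self.max_depth = max_depth
        self.criterion = criterion

    def fit(self, data, y=None):
        # learned state conventionally gets a trailing underscore
        self.classes_ = sorted(set(y)) if y is not None else []
        return self  # the scikit-learn checks require fit to return self
```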
| Method | Implemented in |
|---|---|
| `predict(data)` | |
| `fit(data, [y])` | |
| `score(data, y)` | `BaseEstimator` |
| `get_params` and `set_params` | `BaseEstimator` |
Particularly, with `sklearn.utils.estimator_checks.check_estimator`.
Here we document some of the less-obvious checks it performs:
- Passing an object to the `y` argument of `fit` should raise an exception. Use `sklearn.utils.multiclass.check_classification_targets` to ensure it.
- `fit` returns `self`.
- Predict raises an error for infinite and NaN inputs. (I prefer a more robust implementation of decision trees that can handle infinite and missing values.)
- Classifiers must have a `classes_` attribute.
- Trying to predict without fitting must raise an error containing the word "fit".
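The last check can be demonstrated with a dependency-free stub. This sketch raises a plain `RuntimeError` rather than scikit-learn's `NotFittedError` to stay self-contained; the class name is invented for illustration.

```python
class FittableStub:
    def fit(self, data, y=None):
        self.is_fitted_ = True
        return self

    def predict(self, data):
        if not getattr(self, "is_fitted_", False):
            # the message must contain the word "fit" for the check to pass
            raise RuntimeError("This estimator is not fitted yet; call fit first.")
        return [0 for _ in data]
```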
Licensed under the Apache License, Version 2.0, Copyright 2020 Yuriy Sverchkov