# Source code for simpleml.models.classifiers.sklearn.tree
"""
Wrapper module around `sklearn.tree`
"""
import logging
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from simpleml.models.classifiers.external_models import ClassificationExternalModelMixin
from .base_sklearn_classifier import SklearnClassifier
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

"""
Trees
"""
class WrappedSklearnDecisionTreeClassifier(
    DecisionTreeClassifier, ClassificationExternalModelMixin
):
    """
    `sklearn.tree.DecisionTreeClassifier` combined with
    `ClassificationExternalModelMixin`.
    """

    def get_feature_metadata(self, features, **kwargs):
        """
        Map feature names to the fitted model's feature importances.

        :param features: sequence of feature names, or None. When None or
            shorter than the number of importances, integer positions
            ``0..n-1`` are used as keys instead (with a warning).
        :param kwargs: accepted for interface compatibility; unused here.
        :return: dict of feature name (or position) -> importance. Extra
            names beyond the number of importances are dropped by ``zip``.
        """
        # ravel() rather than squeeze(): a model fit on a single feature has
        # importances of shape (1,), which squeeze() collapses to a 0-d array
        # that len() cannot size.
        feature_importances = self.feature_importances_.ravel()
        if features is None or len(features) < len(feature_importances):
            # Resolve the module logger here so this method does not depend on
            # a module-level LOGGER name being defined.
            logging.getLogger(__name__).warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(feature_importances))
        return dict(zip(features, feature_importances))
class SklearnDecisionTreeClassifier(SklearnClassifier):
    """
    SimpleML classifier whose external model is a
    `WrappedSklearnDecisionTreeClassifier`.
    """

    def _create_external_model(self, **kwargs):
        """
        Instantiate the wrapped decision tree, forwarding all keyword
        arguments unchanged to its constructor.
        """
        external_cls = WrappedSklearnDecisionTreeClassifier
        return external_cls(**kwargs)
class WrappedSklearnExtraTreeClassifier(
    ExtraTreeClassifier, ClassificationExternalModelMixin
):
    """
    `sklearn.tree.ExtraTreeClassifier` combined with
    `ClassificationExternalModelMixin`.
    """

    def get_feature_metadata(self, features, **kwargs):
        """
        Map feature names to the fitted model's feature importances.

        :param features: sequence of feature names, or None. When None or
            shorter than the number of importances, integer positions
            ``0..n-1`` are used as keys instead (with a warning).
        :param kwargs: accepted for interface compatibility; unused here.
        :return: dict of feature name (or position) -> importance. Extra
            names beyond the number of importances are dropped by ``zip``.
        """
        # ravel() rather than squeeze(): a model fit on a single feature has
        # importances of shape (1,), which squeeze() collapses to a 0-d array
        # that len() cannot size.
        feature_importances = self.feature_importances_.ravel()
        if features is None or len(features) < len(feature_importances):
            # Resolve the module logger here so this method does not depend on
            # a module-level LOGGER name being defined.
            logging.getLogger(__name__).warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(feature_importances))
        return dict(zip(features, feature_importances))