Source code for simpleml.models.classifiers.sklearn.tree

"""
Wrapper module around `sklearn.tree`
"""

# Module authorship metadata.
__author__ = "Elisha Yadgaran"
import logging from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier from simpleml.models.classifiers.external_models import ClassificationExternalModelMixin from .base_sklearn_classifier import SklearnClassifier
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# ----------------------------------------------------------------------------
# Trees
# ----------------------------------------------------------------------------
class WrappedSklearnDecisionTreeClassifier(
    DecisionTreeClassifier, ClassificationExternalModelMixin
):
    """
    `sklearn.tree.DecisionTreeClassifier` combined with SimpleML's
    `ClassificationExternalModelMixin`, adding a `get_feature_metadata`
    implementation based on the tree's feature importances.
    """

    def get_feature_metadata(self, features, **kwargs):
        """
        Map feature names to the fitted tree's feature importances.

        :param features: sequence of feature names, or None. If None, or if
            it has fewer entries than there are importances, a numbered
            range (0..n-1) is used instead and a warning is logged.
        :param kwargs: accepted for interface compatibility; unused here.
        :return: dict of {feature name (or index): importance}
        """
        # ravel() instead of squeeze(): squeeze() collapses a single-feature
        # model's shape-(1,) importances into a 0-d array, on which len()
        # raises TypeError. ravel() always yields a 1-D array.
        feature_importances = self.feature_importances_.ravel()
        if features is None or len(features) < len(feature_importances):
            LOGGER.warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(feature_importances))
        # NOTE: zip() truncates, so names beyond the number of importances
        # are dropped without a warning.
        return dict(zip(features, feature_importances))
class SklearnDecisionTreeClassifier(SklearnClassifier):
    """
    SimpleML `SklearnClassifier` whose external model is a
    `WrappedSklearnDecisionTreeClassifier`.
    """

    def _create_external_model(self, **kwargs):
        """
        Instantiate the wrapped sklearn decision tree for this model.

        :param kwargs: passed unchanged to the estimator constructor
        :return: a new `WrappedSklearnDecisionTreeClassifier`
        """
        external_model = WrappedSklearnDecisionTreeClassifier(**kwargs)
        return external_model
class WrappedSklearnExtraTreeClassifier(
    ExtraTreeClassifier, ClassificationExternalModelMixin
):
    """
    `sklearn.tree.ExtraTreeClassifier` combined with SimpleML's
    `ClassificationExternalModelMixin`, adding a `get_feature_metadata`
    implementation based on the tree's feature importances.
    """

    def get_feature_metadata(self, features, **kwargs):
        """
        Map feature names to the fitted tree's feature importances.

        :param features: sequence of feature names, or None. If None, or if
            it has fewer entries than there are importances, a numbered
            range (0..n-1) is used instead and a warning is logged.
        :param kwargs: accepted for interface compatibility; unused here.
        :return: dict of {feature name (or index): importance}
        """
        # ravel() instead of squeeze(): squeeze() collapses a single-feature
        # model's shape-(1,) importances into a 0-d array, on which len()
        # raises TypeError. ravel() always yields a 1-D array.
        feature_importances = self.feature_importances_.ravel()
        if features is None or len(features) < len(feature_importances):
            LOGGER.warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(feature_importances))
        # NOTE: zip() truncates, so names beyond the number of importances
        # are dropped without a warning.
        return dict(zip(features, feature_importances))
class SklearnExtraTreeClassifier(SklearnClassifier):
    """
    SimpleML `SklearnClassifier` whose external model is a
    `WrappedSklearnExtraTreeClassifier`.
    """

    def _create_external_model(self, **kwargs):
        """
        Instantiate the wrapped sklearn extra-tree for this model.

        :param kwargs: passed unchanged to the estimator constructor
        :return: a new `WrappedSklearnExtraTreeClassifier`
        """
        external_model = WrappedSklearnExtraTreeClassifier(**kwargs)
        return external_model