Source code for simpleml.models.classifiers.sklearn.naive_bayes

"""
Wrappers around the `sklearn.naive_bayes` estimators (BernoulliNB, GaussianNB,
MultinomialNB) for use as SimpleML classifiers.
"""

# Module authorship metadata.
__author__ = 'Elisha Yadgaran'
import logging

from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB

from simpleml.models.classifiers.external_models import (
    ClassificationExternalModelMixin,
)

from .base_sklearn_classifier import SklearnClassifier
# Module-scoped logger, named after this module.
LOGGER = logging.getLogger(name=__name__)
# --------------------------------------------------------------------------
# Bernoulli
# --------------------------------------------------------------------------


class WrappedSklearnBernoulliNB(BernoulliNB, ClassificationExternalModelMixin):
    """``BernoulliNB`` combined with ``ClassificationExternalModelMixin``."""

    def get_feature_metadata(self, features, **kwargs):
        """Map each feature to its log probability for the first class.

        The fitted ``feature_log_prob_`` array holds one row per class; only
        row 0 (the first class in sklearn's ``classes_`` ordering) is reported.

        :param features: optional sequence of feature names. If ``None`` or
            shorter than the number of fitted features, integer positions
            ``0..n_features-1`` are used as keys instead (and a warning is
            logged).
        :param kwargs: accepted for interface compatibility; unused.
        :return: dict mapping feature name (or index) to log probability.
        """
        # Index the class axis directly. The previous ``squeeze()[0]`` collapsed
        # a single-feature (n_classes, 1) or single-class (1, n_features) array
        # to 1-D, so ``[0]`` returned a scalar and ``len()`` below raised.
        log_probs = self.feature_log_prob_[0]
        if features is None or len(features) < len(log_probs):
            LOGGER.warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(log_probs))
        # zip() truncates if more names than features were supplied.
        return dict(zip(features, log_probs))
class SklearnBernoulliNB(SklearnClassifier):
    """Classifier whose external model is a ``WrappedSklearnBernoulliNB``."""

    def _create_external_model(self, **kwargs):
        """Instantiate the wrapped estimator, forwarding every keyword argument."""
        estimator = WrappedSklearnBernoulliNB(**kwargs)
        return estimator
# --------------------------------------------------------------------------
# Gaussian
# --------------------------------------------------------------------------


class WrappedSklearnGaussianNB(GaussianNB, ClassificationExternalModelMixin):
    """``GaussianNB`` combined with ``ClassificationExternalModelMixin``."""

    def get_feature_metadata(self, features, **kwargs):
        """Map each feature to its fitted mean (``theta_``) for the first class.

        ``theta_`` holds one row of per-feature means per class (these are
        means, not probabilities); only row 0 (the first class in sklearn's
        ``classes_`` ordering) is reported.

        :param features: optional sequence of feature names. If ``None`` or
            shorter than the number of fitted features, integer positions
            ``0..n_features-1`` are used as keys instead (and a warning is
            logged).
        :param kwargs: accepted for interface compatibility; unused.
        :return: dict mapping feature name (or index) to the class-0 mean.
        """
        # Index the class axis directly. The previous ``squeeze()[0]`` collapsed
        # a single-feature or single-class ``theta_`` to 1-D, so ``[0]``
        # returned a scalar and ``len()`` below raised.
        thetas = self.theta_[0]
        if features is None or len(features) < len(thetas):
            LOGGER.warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(thetas))
        # zip() truncates if more names than features were supplied.
        return dict(zip(features, thetas))
class SklearnGaussianNB(SklearnClassifier):
    """Classifier whose external model is a ``WrappedSklearnGaussianNB``."""

    def _create_external_model(self, **kwargs):
        """Instantiate the wrapped estimator, forwarding every keyword argument."""
        estimator = WrappedSklearnGaussianNB(**kwargs)
        return estimator
# --------------------------------------------------------------------------
# Multinomial
# --------------------------------------------------------------------------


class WrappedSklearnMultinomialNB(MultinomialNB, ClassificationExternalModelMixin):
    """``MultinomialNB`` combined with ``ClassificationExternalModelMixin``."""

    def get_feature_metadata(self, features, **kwargs):
        """Map each feature to its log probability for the first class.

        The fitted ``feature_log_prob_`` array holds one row per class; only
        row 0 (the first class in sklearn's ``classes_`` ordering) is reported.

        :param features: optional sequence of feature names. If ``None`` or
            shorter than the number of fitted features, integer positions
            ``0..n_features-1`` are used as keys instead (and a warning is
            logged).
        :param kwargs: accepted for interface compatibility; unused.
        :return: dict mapping feature name (or index) to log probability.
        """
        # Index the class axis directly. The previous ``squeeze()[0]`` collapsed
        # a single-feature (n_classes, 1) or single-class (1, n_features) array
        # to 1-D, so ``[0]`` returned a scalar and ``len()`` below raised.
        log_probs = self.feature_log_prob_[0]
        if features is None or len(features) < len(log_probs):
            LOGGER.warning(
                "Fewer feature names than features passed, defaulting to numbered list"
            )
            features = range(len(log_probs))
        # zip() truncates if more names than features were supplied.
        return dict(zip(features, log_probs))
class SklearnMultinomialNB(SklearnClassifier):
    """Classifier whose external model is a ``WrappedSklearnMultinomialNB``."""

    def _create_external_model(self, **kwargs):
        """Instantiate the wrapped estimator, forwarding every keyword argument."""
        estimator = WrappedSklearnMultinomialNB(**kwargs)
        return estimator