Source code for simpleml.utils.scoring.load_persistable

'''
Module to query registry and retrieve persistables from wherever they are
stored.
'''

__author__ = 'Elisha Yadgaran'
import logging

from simpleml.datasets.base_dataset import Dataset
from simpleml.pipelines.base_pipeline import Pipeline
from simpleml.models.base_model import Model
from simpleml.metrics.base_metric import Metric
from simpleml.utils.errors import SimpleMLError
from simpleml.utils.library_versions import INSTALLED_LIBRARIES
# Module-level logger, used by PersistableLoader to report dependency mismatches
LOGGER = logging.getLogger(__name__)
class PersistableLoader(object):
    '''
    Wrapper class to load various persistables

    Sqlalchemy-mixins active record style allows for keyword based filtering:
        `BaseClass.where(**filters).order_by(**ordering).first()`
    '''

    @classmethod
    def load_persistable(cls, persistable_class, filters):
        '''
        Load the highest-version persistable matching ``filters``.

        :param persistable_class: persistable class to query; it must expose a
            ``where(**filters)`` query method and a ``version`` column
        :param filters: dict of keyword filters passed to
            ``persistable_class.where``
        :returns: the matching persistable with the largest ``version``, after
            environment validation and ``load(load_externals=False)``
        :raises SimpleMLError: if no persistable matches ``filters``
        '''
        persistable = (
            persistable_class.where(**filters)
            .order_by(persistable_class.version.desc())
            .first()
        )
        if persistable is None:
            raise SimpleMLError(
                'No persistable found for specified filters: {}'.format(filters))

        # Dependency drift only produces warnings; loading proceeds regardless
        cls.validate_environment(persistable)
        persistable.load(load_externals=False)
        return persistable

    @classmethod
    def load_dataset(cls, name='default', **filters):
        '''
        Load the latest Dataset with the given ``name``.

        :param name: dataset name to filter on
        :param filters: additional keyword filters applied to the query
        '''
        filters['name'] = name
        return cls.load_persistable(Dataset, filters)

    @classmethod
    def load_pipeline(cls, name='default', **filters):
        '''
        Load the latest Pipeline with the given ``name``.

        :param name: pipeline name to filter on
        :param filters: additional keyword filters applied to the query
        '''
        filters['name'] = name
        return cls.load_persistable(Pipeline, filters)

    @classmethod
    def load_model(cls, name='default', **filters):
        '''
        Load the latest Model with the given ``name``.

        :param name: model name to filter on
        :param filters: additional keyword filters applied to the query
        '''
        filters['name'] = name
        return cls.load_persistable(Model, filters)

    @classmethod
    def load_metric(cls, name, model_id, **filters):
        '''
        Load the latest Metric with the given ``name`` for ``model_id``.

        :param name: metric name to filter on
        :param model_id: id of the model the metric belongs to
        :param filters: additional keyword filters applied to the query
        '''
        filters['name'] = name
        filters['model_id'] = model_id
        return cls.load_persistable(Metric, filters)

    @staticmethod
    def validate_environment(persistable):
        '''
        Warn when installed library versions differ from those recorded on
        the persistable at training time.

        Mismatches never raise; each one is logged as a warning via ``LOGGER``.
        A package recorded at training time but not installed is reported with
        the version string ``'None'``.

        :param persistable: object exposing ``library_versions``, a mapping of
            package name -> version; ``None``/empty means nothing to compare
        '''
        # Persistables saved without version metadata would otherwise crash
        # on ``.items()``
        training_env = persistable.library_versions or {}
        scoring_env = INSTALLED_LIBRARIES

        mismatches = []
        for package, version in training_env.items():
            if package not in scoring_env:
                mismatches.append({'package': package, 'expected_version': version, 'version': 'None'})
            elif version != scoring_env[package]:
                mismatches.append({'package': package, 'expected_version': version, 'version': scoring_env[package]})

        if mismatches:
            LOGGER.warning('Attempted to score with different dependencies than training, proceed at your own risk')
            for mismatch in mismatches:
                LOGGER.warning('Expected: {package}=={expected_version}, found: {version}'.format(**mismatch))