# Source code for simpleml.utils.scoring.load_persistable

"""
Module to query registry and retrieve persistables from wherever they are
stored.
"""

# Authorship metadata for this module.
__author__ = "Elisha Yadgaran"
import logging
from typing import Any, Dict, Type

from simpleml.datasets.base_dataset import Dataset
from simpleml.metrics.base_metric import Metric
from simpleml.models.base_model import Model
from simpleml.orm.dataset import ORMDataset
from simpleml.orm.metric import ORMMetric
from simpleml.orm.model import ORMModel
from simpleml.orm.persistable import ORMPersistable
from simpleml.orm.pipeline import ORMPipeline
from simpleml.persistables.base_persistable import Persistable
from simpleml.pipelines.base_pipeline import Pipeline
from simpleml.utils.errors import SimpleMLError
from simpleml.utils.library_versions import INSTALLED_LIBRARIES
# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)
class PersistableLoader:
    """
    Wrapper class to load various persistables

    Sqlalchemy-mixins active record style allows for keyword based filtering:
    `BaseClass.where(**filters).order_by(**ordering).first()`

    Every loader funnels through `load_persistable`, which also compares the
    library versions recorded on the persistable with the installed ones.
    """

    @classmethod
    def load_persistable(
        cls, persistable_class: Type[ORMPersistable], filters: Dict[str, Any]
    ) -> Persistable:
        """
        Load the highest-version persistable matching the given filters.

        :param persistable_class: ORM class (not an instance) to query
        :param filters: keyword filters forwarded to `persistable_class.where`
        :return: the persistable produced by the matching record's `load`
            (called with `load_externals=False`)
        :raises SimpleMLError: if no record matches the filters
        """
        # Newest version first, so `.first()` picks the latest match.
        record = (
            persistable_class.where(**filters)
            .order_by(persistable_class.version.desc())
            .first()
        )
        if record is None:
            raise SimpleMLError(
                "No persistable found for specified filters: {}".format(filters)
            )

        persistable = record.load(load_externals=False)
        # Only logs on dependency drift; the persistable is returned either way.
        cls.validate_environment(persistable)
        return persistable

    @classmethod
    def load_dataset(cls, **filters) -> Dataset:
        """Load the latest dataset matching the keyword filters."""
        return cls.load_persistable(ORMDataset, filters)

    @classmethod
    def load_pipeline(cls, **filters) -> Pipeline:
        """Load the latest pipeline matching the keyword filters."""
        return cls.load_persistable(ORMPipeline, filters)

    @classmethod
    def load_model(cls, **filters) -> Model:
        """Load the latest model matching the keyword filters."""
        return cls.load_persistable(ORMModel, filters)

    @classmethod
    def load_metric(cls, **filters) -> Metric:
        """Load the latest metric matching the keyword filters."""
        return cls.load_persistable(ORMMetric, filters)

    @staticmethod
    def validate_environment(persistable: Persistable) -> None:
        """
        Warn when installed libraries differ from those recorded at training.

        Each package in `persistable.library_versions` is compared with
        `INSTALLED_LIBRARIES`. A package that is absent, or present with a
        different version, is reported through `LOGGER.warning`. Mismatches
        are only logged; this method does not raise for them.

        :param persistable: persistable whose `library_versions` mapping
            (package name -> version) describes the training environment
        """
        training_env = persistable.library_versions
        scoring_env = INSTALLED_LIBRARIES

        # Collect (package, expected, found) for every dependency that drifted.
        mismatches = []
        for package, expected_version in training_env.items():
            if package not in scoring_env:
                # Literal "None" keeps the log line uniform for missing packages.
                mismatches.append((package, expected_version, "None"))
            elif expected_version != scoring_env[package]:
                mismatches.append(
                    (package, expected_version, scoring_env[package])
                )

        if not mismatches:
            return

        LOGGER.warning(
            "Attempted to score with different dependencies than training, "
            "proceed at your own risk"
        )
        for package, expected_version, found_version in mismatches:
            # Lazy %-style args: formatting is deferred to the logging framework.
            LOGGER.warning(
                "Expected: %s==%s, found: %s",
                package,
                expected_version,
                found_version,
            )