# Source code for simpleml.utils.training.create_persistable

'''
Module with helper classes to create new persistables
'''
from abc import ABCMeta, abstractmethod
from simpleml.persistables.meta_registry import SIMPLEML_REGISTRY
from simpleml.datasets.base_dataset import Dataset
from simpleml.pipelines.base_pipeline import Pipeline
from simpleml.models.base_model import Model
from simpleml.metrics.base_metric import Metric
from simpleml.utils.errors import TrainingError
import logging
from future.utils import with_metaclass

# Module-level logger, named after this module's import path
LOGGER = logging.getLogger(__name__)

# Module author metadata
__author__ = 'Elisha Yadgaran'


class PersistableCreator(with_metaclass(ABCMeta, object)):
    '''
    Abstract base for the stateless helpers that either look up an existing
    persistable or build a new one.

    Every method is invoked on the class itself; the concrete creators in this
    module implement ``determine_filters`` and ``create`` as classmethods.
    '''

    @classmethod
    def retrieve_or_create(cls, **kwargs):
        '''
        Wrapper method to first attempt to retrieve a matching persistable and
        then create a new one if it isn't found

        :param kwargs: passed to ``determine_filters`` and, when no match is
            found, passed again to ``create``
        :return: the matching persistable (after ``load()`` has been called on
            it) or the newly created one
        '''
        # persistable_cls is the database model to query, not this creator class
        persistable_cls, filters = cls.determine_filters(**kwargs)
        persistable = cls.retrieve(persistable_cls, filters)

        if persistable is not None:
            LOGGER.info('Using existing persistable: {}, {}, {}'.format(
                persistable_cls.__tablename__, persistable.name, persistable.version))
            persistable.load()
            return persistable

        LOGGER.info('Existing {} not found. Creating new one now'.format(
            persistable_cls.__tablename__))
        persistable = cls.create(**kwargs)
        LOGGER.info('Using new persistable: {}, {}, {}'.format(
            persistable_cls.__tablename__, persistable.name, persistable.version))
        return persistable

    @staticmethod
    def retrieve(cls, filters):
        '''
        Query database using the table model (cls) and filters for a matching
        persistable

        :param cls: table model to query (passed explicitly; this is a
            staticmethod, so it is not the creator class)
        :param filters: dictionary of keyword filters handed to ``cls.where``
        :return: first result when ordered by descending version
            (presumably None when nothing matches - callers check for None)
        '''
        return cls.where(**filters).order_by(cls.version.desc()).first()

    @staticmethod
    def retrieve_dependency(dependency_cls, **dependency_kwargs):
        '''
        Base method to query for dependency

        :param dependency_cls: creator class whose ``determine_filters`` and
            ``retrieve`` are used for the lookup
        :param dependency_kwargs: passed to ``dependency_cls.determine_filters``
        :return: the dependency, after ``load()`` has been called on it
        :raises TrainingError: if dependency does not exist
        '''
        # determine_filters returns (table model, filters), which is exactly
        # the positional signature of retrieve
        dependency = dependency_cls.retrieve(
            *dependency_cls.determine_filters(**dependency_kwargs))
        if dependency is None:
            raise TrainingError('Expected dependency is missing')
        dependency.load()
        return dependency

    @abstractmethod
    def determine_filters(cls, strict=False, **kwargs):
        '''
        method to determine which filters to apply when looking for
        existing persistable

        :param strict: whether to fit objects first before assuming they are identical
            In theory if all inputs and classes are the same, the outputs should
            deterministically be the same as well (up to random iter). So, you
            dont need to fit objects to be sure they are the same

        Default design iterates through 2 (or 3) options when retrieving persistables:
            1) By name and version (unique properties that define persistables)
            2) By name, registered_name, and computed hash
            2.5) Optionally, just use name and registered_name (assumes class
                definition is the same and would result in an identical persistable)

        Returns: database class, filter dictionary
        '''

    @abstractmethod
    def create(cls, **kwargs):
        '''
        method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable
        '''

    @staticmethod
    def retrieve_from_registry(registered_name):
        '''
        stateless method to query registry for class definitions. handles errors

        :param registered_name: key to look up in SIMPLEML_REGISTRY
        :return: the registered class
        :raises TrainingError: if nothing is registered under that name
        '''
        cls = SIMPLEML_REGISTRY.get(registered_name)
        if cls is None:
            raise TrainingError('Referenced class unregistered: {}'.format(registered_name))
        return cls
class DatasetCreator(PersistableCreator):
    '''
    Creator for Dataset persistables. The dataset pipeline dependency is
    optional: when neither a pipeline object nor pipeline_kwargs are supplied
    the dataset is built without one.
    '''

    @classmethod
    def determine_filters(cls, name='', version=None, strict=True, **kwargs):
        '''
        stateless method to determine which filters to apply when looking for
        existing persistable

        Returns: database class, filter dictionary

        :param name: persistable name to filter on
        :param version: when given, the lookup is by name and version only and
            the remaining arguments are not inspected
        :param registered_name: Class name registered in SimpleML
            (required in kwargs when version is None)
        :param pipeline: optional dataset pipeline object (in kwargs)
        :param pipeline_kwargs: optional parameters used to retrieve the
            dataset pipeline when the object itself is not passed (in kwargs)
        :param strict: whether to assume same class and name = same persistable,
            or, load the data and compare the hash
        '''
        if version is not None:
            filters = {
                'name': name,
                'version': version
            }

        else:
            registered_name = kwargs.pop('registered_name')
            # Check if dependency object was passed
            pipeline = kwargs.pop('pipeline', None)
            if pipeline is None:
                # Use dependency reference to retrieve object
                pipeline = cls.retrieve_pipeline(**kwargs.pop('pipeline_kwargs', {}))

            if strict:
                # Build dummy object to retrieve hash to look for
                new_dataset = cls.retrieve_from_registry(registered_name)(name=name, **kwargs)
                new_dataset.add_pipeline(pipeline)
                new_dataset.build_dataframe()

                filters = {
                    'name': name,
                    'registered_name': registered_name,
                    'hash_': new_dataset._hash()
                }

            else:
                # Assume combo of name, class, and pipeline will be unique
                filters = {
                    'name': name,
                    'registered_name': registered_name,
                    'pipeline_id': pipeline.id if pipeline is not None else None
                }

        return Dataset, filters

    @classmethod
    def create(cls, registered_name, pipeline=None, **kwargs):
        '''
        Stateless method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable

        :param registered_name: Class name registered in SimpleML
        :param pipeline: dataset pipeline object; when None it is retrieved
            using ``pipeline_kwargs`` popped from kwargs (and may stay None)
        :return: the new dataset, after its dataframe was built and it was saved
        '''
        if pipeline is None:
            # Use dependency reference to retrieve object
            pipeline = cls.retrieve_pipeline(**kwargs.pop('pipeline_kwargs', {}))

        new_dataset = cls.retrieve_from_registry(registered_name)(**kwargs)
        new_dataset.add_pipeline(pipeline)
        new_dataset.build_dataframe()
        new_dataset.save()

        return new_dataset

    @classmethod
    def retrieve_pipeline(cls, **pipeline_kwargs):
        '''
        Retrieve the dataset pipeline dependency described by pipeline_kwargs

        :param pipeline_kwargs: parameters handed to PipelineCreator for the lookup
        :return: the loaded pipeline, or None when no parameters were passed
        :raises TrainingError: if parameters were passed but no pipeline matches
        '''
        # Datasets do not require dataset pipelines so return None if it isn't passed
        if not pipeline_kwargs:
            # Adjacent literals instead of a backslash continuation inside the
            # string, which leaked the source indentation into the message
            LOGGER.warning(
                'Dataset Pipeline parameters not passed, skipping dependencies. '
                'Only use this if dataset is already in the right format!')
            return None
        return cls.retrieve_dependency(PipelineCreator, **pipeline_kwargs)
class PipelineCreator(PersistableCreator):
    '''
    Creator for Pipeline persistables, which depend on a dataset
    '''

    @classmethod
    def determine_filters(cls, name='', version=None, strict=False, **kwargs):
        '''
        Stateless lookup of the filters identifying an existing pipeline

        Returns: database class, filter dictionary

        :param name: persistable name to filter on
        :param version: when given, the lookup is by name and version only
        :param registered_name: Class name registered in SimpleML
            (required in kwargs when version is None)
        :param dataset: optional dataset object (in kwargs); when absent the
            dataset is retrieved using ``dataset_kwargs``
        :param strict: whether to fit the dummy pipeline before hashing it.
            In theory identical inputs and classes give identical outputs
            (up to random iter), so fitting is not required to compare
        '''
        # Name + version uniquely identify a persistable, nothing else needed
        if version is not None:
            return Pipeline, {'name': name, 'version': version}

        # Prefer a dependency object that was passed directly
        dataset = kwargs.pop('dataset', None)
        if dataset is None:
            # Fall back to the dependency reference
            dataset_kwargs = kwargs.pop('dataset_kwargs', {})
            dataset = cls.retrieve_dataset(**dataset_kwargs)

        # Throwaway object, only used to compute the hash to search for
        registered_name = kwargs.pop('registered_name')
        pipeline_cls = cls.retrieve_from_registry(registered_name)
        candidate = pipeline_cls(name=name, **kwargs)
        candidate.add_dataset(dataset)
        if strict:
            candidate.fit()

        filters = {
            'name': name,
            'registered_name': registered_name,
            'hash_': candidate._hash(),
        }
        return Pipeline, filters

    @classmethod
    def create(cls, registered_name, dataset=None, **kwargs):
        '''
        Stateless construction of a new pipeline; kwargs go straight to the
        persistable constructor

        :param registered_name: Class name registered in SimpleML
        :param dataset: dataset object; when None it is retrieved using
            ``dataset_kwargs`` popped from kwargs
        :return: the new pipeline, after it was fit and saved
        '''
        if dataset is None:
            # Resolve the dependency from its reference
            dataset_kwargs = kwargs.pop('dataset_kwargs', {})
            dataset = cls.retrieve_dataset(**dataset_kwargs)

        pipeline_cls = cls.retrieve_from_registry(registered_name)
        created = pipeline_cls(**kwargs)
        created.add_dataset(dataset)
        created.fit()
        created.save()
        return created

    @classmethod
    def retrieve_dataset(cls, **dataset_kwargs):
        '''
        Retrieve the dataset dependency described by dataset_kwargs

        :raises TrainingError: if no matching dataset exists
        '''
        return cls.retrieve_dependency(DatasetCreator, **dataset_kwargs)
class ModelCreator(PersistableCreator):
    '''
    Creator for Model persistables, which depend on a pipeline
    '''

    @classmethod
    def determine_filters(cls, name='', version=None, strict=False, **kwargs):
        '''
        Stateless lookup of the filters identifying an existing model

        Returns: database class, filter dictionary

        :param name: persistable name to filter on
        :param version: when given, the lookup is by name and version only
        :param registered_name: Class name registered in SimpleML
            (required in kwargs when version is None)
        :param pipeline: optional pipeline object (in kwargs); when absent the
            pipeline is retrieved using ``pipeline_kwargs``
        :param strict: whether to fit the dummy model before hashing it.
            In theory identical inputs and classes give identical outputs
            (up to random iter), so fitting is not required to compare
        '''
        # Name + version uniquely identify a persistable, nothing else needed
        if version is not None:
            return Model, {'name': name, 'version': version}

        # Prefer a dependency object that was passed directly
        pipeline = kwargs.pop('pipeline', None)
        if pipeline is None:
            # Fall back to the dependency reference
            pipeline_kwargs = kwargs.pop('pipeline_kwargs', {})
            pipeline = cls.retrieve_pipeline(**pipeline_kwargs)

        # Throwaway object, only used to compute the hash to search for
        registered_name = kwargs.pop('registered_name')
        model_cls = cls.retrieve_from_registry(registered_name)
        candidate = model_cls(name=name, **kwargs)
        candidate.add_pipeline(pipeline)
        if strict:
            candidate.fit()

        filters = {
            'name': name,
            'registered_name': registered_name,
            'hash_': candidate._hash(),
        }
        return Model, filters

    @classmethod
    def create(cls, registered_name, pipeline=None, **kwargs):
        '''
        Stateless construction of a new model; kwargs go straight to the
        persistable constructor

        :param registered_name: Class name registered in SimpleML
        :param pipeline: pipeline object; when None it is retrieved using
            ``pipeline_kwargs`` popped from kwargs
        :return: the new model, after it was fit and saved
        '''
        if pipeline is None:
            # Resolve the dependency from its reference
            pipeline_kwargs = kwargs.pop('pipeline_kwargs', {})
            pipeline = cls.retrieve_pipeline(**pipeline_kwargs)

        model_cls = cls.retrieve_from_registry(registered_name)
        created = model_cls(**kwargs)
        created.add_pipeline(pipeline)
        created.fit()
        created.save()
        return created

    @classmethod
    def retrieve_pipeline(cls, **pipeline_kwargs):
        '''
        Retrieve the pipeline dependency described by pipeline_kwargs

        :raises TrainingError: if no matching pipeline exists
        '''
        return cls.retrieve_dependency(PipelineCreator, **pipeline_kwargs)
class MetricCreator(PersistableCreator):
    '''
    Creator for Metric persistables, which depend on a model
    '''

    @classmethod
    def determine_filters(cls, name=None, model_id=None, strict=False, **kwargs):
        '''
        Stateless lookup of the filters identifying an existing metric

        Returns: database class, filter dictionary

        :param name: metric name; combined with a model reference it is enough
            to identify the metric without building anything
        :param model_id: id of the model the metric belongs to; only consulted
            when name is also given
        :param model: optional model object (in kwargs)
        :param registered_name: Class name registered in SimpleML
            (required in kwargs when the name + model shortcut does not apply)
        :param strict: whether to score the dummy metric before hashing it.
            In theory identical inputs and classes give identical outputs
            (up to random iter), so scoring is not required to compare
        '''
        # Check if dependency object was passed
        model = kwargs.pop('model', None)
        has_model_reference = model_id is not None or model is not None

        if name is not None and has_model_reference:
            # Can't use default name because metrics are hard coded to reflect
            # dataset split + class
            resolved_model_id = model.id if model_id is None else model_id
            return Metric, {'name': name, 'model_id': resolved_model_id}

        if model is None:
            # Fall back to the dependency reference
            model_kwargs = kwargs.pop('model_kwargs', {})
            model = cls.retrieve_model(**model_kwargs)

        # Throwaway object, only used to compute the hash to search for
        registered_name = kwargs.pop('registered_name')
        metric_cls = cls.retrieve_from_registry(registered_name)
        candidate = metric_cls(name=name, **kwargs)
        candidate.add_model(model)
        if strict:
            candidate.score()

        # Filter on the name of the constructed metric, not the argument
        filters = {
            'name': candidate.name,
            'registered_name': registered_name,
            'hash_': candidate._hash(),
        }
        return Metric, filters

    @classmethod
    def create(cls, registered_name, model=None, **kwargs):
        '''
        Stateless construction of a new metric; kwargs go straight to the
        persistable constructor

        :param registered_name: Class name registered in SimpleML
        :param model: model object; when None it is retrieved using
            ``model_kwargs`` popped from kwargs
        :return: the new metric, after it was scored and saved
        '''
        if model is None:
            # Resolve the dependency from its reference
            model_kwargs = kwargs.pop('model_kwargs', {})
            model = cls.retrieve_model(**model_kwargs)

        metric_cls = cls.retrieve_from_registry(registered_name)
        created = metric_cls(**kwargs)
        created.add_model(model)
        created.score()
        created.save()
        return created

    @classmethod
    def retrieve_model(cls, **model_kwargs):
        '''
        Retrieve the model dependency described by model_kwargs

        :raises TrainingError: if no matching model exists
        '''
        return cls.retrieve_dependency(ModelCreator, **model_kwargs)