"""
Module with helper classes to create new persistables
"""
# Module author metadata (the stray "[docs]" documentation-link marker that
# preceded this assignment was not valid Python and has been removed).
__author__ = "Elisha Yadgaran"
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type

from simpleml.datasets.base_dataset import Dataset
from simpleml.metrics.base_metric import Metric
from simpleml.models.base_model import Model
from simpleml.orm.dataset import ORMDataset
from simpleml.orm.metric import ORMMetric
from simpleml.orm.model import ORMModel
from simpleml.orm.persistable import ORMPersistable
from simpleml.orm.pipeline import ORMPipeline
from simpleml.persistables.base_persistable import Persistable
from simpleml.pipelines.base_pipeline import Pipeline
from simpleml.registries import SIMPLEML_REGISTRY
from simpleml.utils.errors import TrainingError
# Module-level logger shared by every creator class in this file (the stray
# "[docs]" documentation-link marker has been removed).
LOGGER = logging.getLogger(__name__)
class PersistableCreator(object, metaclass=ABCMeta):
    """
    Abstract base class for the persistable creators in this module.

    All behaviour is exposed through class and static methods. Subclasses
    supply `determine_filters` (how to look for an existing record) and
    `create` (how to build a new persistable); `retrieve_or_create` ties the
    two together.
    """

    @classmethod
    def retrieve_or_create(cls, **kwargs) -> Persistable:
        """
        Wrapper method to first attempt to retrieve a matching persistable and
        then create a new one if it isn't found

        :param kwargs: forwarded unchanged to `determine_filters` and, when no
            match is found, to `create`
        Returns: the loaded existing persistable, or the newly created one
        """
        # `orm_cls` is the database table model to query, not this creator.
        orm_cls, filters = cls.determine_filters(**kwargs)
        orm_persistable = cls.retrieve(orm_cls, filters)

        if orm_persistable is not None:
            LOGGER.info(
                "Using existing persistable: {}, {}, {}".format(
                    orm_cls.__tablename__,
                    orm_persistable.name,
                    orm_persistable.version,
                )
            )
            return orm_persistable.load()

        LOGGER.info(
            "Existing {} not found. Creating new one now".format(
                orm_cls.__tablename__
            )
        )
        persistable = cls.create(**kwargs)
        LOGGER.info(
            "Using new persistable: {}, {}, {}".format(
                orm_cls.__tablename__, persistable.name, persistable.version
            )
        )
        return persistable

    @staticmethod
    def retrieve(cls, filters: Dict[str, Any]) -> Optional[ORMPersistable]:
        """
        Query database using the table model (cls) and filters for a matching
        persistable

        :param cls: table model to query (this is a static method, so `cls`
            is an ordinary argument and not the creator class)
        :param filters: keyword filters passed to `cls.where`
        Returns: the first match when ordered by descending version; callers
            in this module treat `None` as "no match"
        """
        return cls.where(**filters).order_by(cls.version.desc()).first()

    @staticmethod
    def retrieve_dependency(
        dependency_cls: "PersistableCreator", **dependency_kwargs
    ) -> Persistable:
        """
        Base method to query for dependency

        :param dependency_cls: creator class whose `determine_filters` and
            `retrieve` are used for the lookup
        :param dependency_kwargs: lookup values handed to
            `dependency_cls.determine_filters`
        Returns: the loaded dependency
        Raises TrainingError if no kwargs are passed or the dependency does
            not exist
        """
        if not dependency_kwargs:
            raise TrainingError(
                "Must pass at least one key:value to look up in database"
            )

        # determine_filters returns (table model, filters), which is exactly
        # the positional signature of `retrieve`.
        dependency = dependency_cls.retrieve(
            *dependency_cls.determine_filters(**dependency_kwargs)
        )
        if dependency is None:
            raise TrainingError("Expected dependency is missing")
        return dependency.load()

    @classmethod
    def retrieve_dataset(
        cls,
        dataset: Optional[Dataset] = None,
        dataset_id: Optional[str] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Optional[Dataset]:
        """
        Resolve the dataset dependency from whichever reference was supplied.

        Precedence: an explicit `dataset` object, then `dataset_id`, then
        `dataset_kwargs`.

        :param dataset: already-loaded dataset; returned as is
        :param dataset_id: id to look up through `DatasetCreator`
        :param dataset_kwargs: lookup kwargs forwarded to `DatasetCreator`
        :param kwargs: ignored; accepted so callers can pass their full kwargs
        Returns: the dataset, or `None` when no reference was supplied
        """
        if dataset is not None:
            return dataset
        if dataset_id is not None:
            return cls.retrieve_dependency(DatasetCreator, id=dataset_id)
        if dataset_kwargs is not None:
            # Use dependency reference to retrieve object
            return cls.retrieve_dependency(DatasetCreator, **dataset_kwargs)
        return None

    @classmethod
    def retrieve_pipeline(
        cls,
        pipeline: Optional[Pipeline] = None,
        pipeline_id: Optional[str] = None,
        pipeline_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Optional[Pipeline]:
        """
        Resolve the pipeline dependency from whichever reference was supplied.

        Precedence: an explicit `pipeline` object, then `pipeline_id`, then
        `pipeline_kwargs`.

        :param pipeline: already-loaded pipeline; returned as is
        :param pipeline_id: id to look up through `PipelineCreator`
        :param pipeline_kwargs: lookup kwargs forwarded to `PipelineCreator`
        :param kwargs: ignored; accepted so callers can pass their full kwargs
        Returns: the pipeline, or `None` when no reference was supplied
        """
        if pipeline is not None:
            return pipeline
        if pipeline_id is not None:
            return cls.retrieve_dependency(PipelineCreator, id=pipeline_id)
        if pipeline_kwargs is not None:
            # Use dependency reference to retrieve object
            return cls.retrieve_dependency(PipelineCreator, **pipeline_kwargs)
        return None

    @classmethod
    def retrieve_model(
        cls,
        model: Optional[Model] = None,
        model_id: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Optional[Model]:
        """
        Resolve the model dependency from whichever reference was supplied.

        Precedence: an explicit `model` object, then `model_id`, then
        `model_kwargs`.

        :param model: already-loaded model; returned as is
        :param model_id: id to look up through `ModelCreator`
        :param model_kwargs: lookup kwargs forwarded to `ModelCreator`
        :param kwargs: ignored; accepted so callers can pass their full kwargs
        Returns: the model, or `None` when no reference was supplied
        """
        if model is not None:
            return model
        if model_id is not None:
            return cls.retrieve_dependency(ModelCreator, id=model_id)
        if model_kwargs is not None:
            # Use dependency reference to retrieve object
            return cls.retrieve_dependency(ModelCreator, **model_kwargs)
        return None

    @classmethod
    @abstractmethod
    def determine_filters(
        cls, strict: bool = False, **kwargs
    ) -> Tuple[Type[ORMPersistable], Dict[str, Any]]:
        """
        method to determine which filters to apply when looking for
        existing persistable

        :param strict: whether to fit objects first before assuming they are identical
            In theory if all inputs and classes are the same, the outputs should deterministically
            be the same as well (up to random iter). So, you don't need to fit objects
            to be sure they are the same

        Default design iterates through 2 (or 3) options when retrieving persistables:
            1) By name and version (unique properties that define persistables)
            2) By name, registered_name, and computed hash
            2.5) Optionally, just use name and registered_name (assumes class
                definition is the same and would result in an identical persistable)

        Returns: database class, filter dictionary
        """

    @classmethod
    @abstractmethod
    def create(cls, **kwargs) -> Persistable:
        """
        method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable
        """

    @staticmethod
    def retrieve_from_registry(registered_name: str) -> Type[Persistable]:
        """
        stateless method to query registry for class definitions. handles errors

        :param registered_name: key to look up in `SIMPLEML_REGISTRY`
        Returns: the registered class (callers in this module instantiate it)
        Raises TrainingError if nothing is registered under that name
        """
        registered_cls = SIMPLEML_REGISTRY.get(registered_name)
        if registered_cls is None:
            raise TrainingError(
                "Referenced class unregistered: {}".format(registered_name)
            )
        return registered_cls
class DatasetCreator(PersistableCreator):
    """Creator for dataset persistables, looked up through `ORMDataset`."""

    @classmethod
    def determine_filters(
        cls, strict: bool = True, **kwargs
    ) -> Tuple[Type[ORMDataset], Dict[str, Any]]:
        """
        stateless method to determine which filters to apply when looking for
        existing persistable

        Lookup precedence: `id`, then `name` + `version`, then
        `registered_name` (with an optional pipeline reference).

        :param registered_name: Class name registered in SimpleML
        :param strict: whether to assume same class and name = same persistable,
            or, load the data and compare the hash
        :param kwargs: lookup values; in the strict `registered_name` case they
            are also passed to the dataset constructor
        Returns: database class, filter dictionary
        Raises TrainingError if none of the supported key combinations is passed
        """
        has_id = "id" in kwargs
        has_name_and_version = "name" in kwargs and "version" in kwargs
        has_registered_name = "registered_name" in kwargs

        if not (has_id or has_name_and_version or has_registered_name):
            raise TrainingError(
                "Need to pass at least one of: `id`, `name, version`, `registered_name` to compare against existing persistables"
            )

        filters: Dict[str, Any]
        if has_id:
            filters = {"id": kwargs["id"]}

        elif has_name_and_version:
            filters = {
                "name": kwargs["name"],
                "version": kwargs["version"],
            }

        else:
            registered_name = kwargs["registered_name"]
            # Check if dependency object was passed
            pipeline = cls.retrieve_pipeline(**kwargs)

            if strict:
                # Build dummy object to retrieve hash to look for
                new_dataset = cls.retrieve_from_registry(registered_name)(**kwargs)
                new_dataset.add_pipeline(pipeline)
                new_dataset.build_dataframe()
                filters = {
                    "name": new_dataset.name,
                    "registered_name": new_dataset.registered_name,
                    "hash_": new_dataset._hash(),
                }

            else:
                # Assume combo of name, class, and pipeline will be unique
                filters = {
                    "registered_name": registered_name,
                    "pipeline_id": pipeline.id if pipeline is not None else None,
                }
                # Name is optional here; only narrow by it when supplied.
                if "name" in kwargs:
                    filters["name"] = kwargs["name"]

        return ORMDataset, filters

    @classmethod
    def create(cls, registered_name: str, **kwargs) -> Dataset:
        """
        Stateless method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable

        :param registered_name: Class name registered in SimpleML
        :param kwargs: constructor kwargs; `pipeline`, `pipeline_id` or
            `pipeline_kwargs` (if present) are also used to resolve the
            pipeline dependency
        Returns: the new dataset after its dataframe is built and it is saved
        """
        pipeline = cls.retrieve_pipeline(**kwargs)

        dataset_cls = cls.retrieve_from_registry(registered_name)
        new_dataset = dataset_cls(**kwargs)
        # `pipeline` may be None when no pipeline reference was supplied.
        new_dataset.add_pipeline(pipeline)
        new_dataset.build_dataframe()
        new_dataset.save()

        return new_dataset
class PipelineCreator(PersistableCreator):
    """Creator for pipeline persistables, looked up through `ORMPipeline`."""

    @classmethod
    def determine_filters(
        cls, strict: bool = False, **kwargs
    ) -> Tuple[Type[ORMPipeline], Dict[str, Any]]:
        """
        stateless method to determine which filters to apply when looking for
        existing persistable

        Lookup precedence: `id`, then `name` + `version`, then
        `registered_name` together with a dataset reference.

        :param registered_name: Class name registered in SimpleML
        :param strict: whether to fit objects first before assuming they are identical
            In theory if all inputs and classes are the same, the outputs should deterministically
            be the same as well (up to random iter). So, you don't need to fit objects
            to be sure they are the same
        :param kwargs: lookup values; in the `registered_name` case they are
            also passed to the pipeline constructor
        Returns: database class, filter dictionary
        Raises TrainingError if none of the supported key combinations is passed
        """
        has_id = "id" in kwargs
        has_name_and_version = "name" in kwargs and "version" in kwargs
        has_dataset_reference = "dataset" in kwargs or "dataset_kwargs" in kwargs
        has_registered_dependency = (
            "registered_name" in kwargs and has_dataset_reference
        )

        if not (has_id or has_name_and_version or has_registered_dependency):
            raise TrainingError(
                "Need to pass at least one of: `id`, `name, version`, `registered_name, dataset`, `registered_name, dataset_kwargs` to compare against existing persistables"
            )

        filters: Dict[str, Any]
        if has_id:
            filters = {"id": kwargs["id"]}

        elif has_name_and_version:
            filters = {
                "name": kwargs["name"],
                "version": kwargs["version"],
            }

        else:
            dataset = cls.retrieve_dataset(**kwargs)

            # Build dummy object to retrieve hash to look for
            registered_name = kwargs["registered_name"]
            new_pipeline = cls.retrieve_from_registry(registered_name)(**kwargs)
            new_pipeline.add_dataset(dataset)
            if strict:
                # Fit before hashing so the hash is taken on the fitted object.
                new_pipeline.fit()

            filters = {
                "name": new_pipeline.name,
                "registered_name": new_pipeline.registered_name,
                "hash_": new_pipeline._hash(),
            }

        return ORMPipeline, filters

    @classmethod
    def create(cls, registered_name: str, **kwargs) -> Pipeline:
        """
        Stateless method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable

        :param registered_name: Class name registered in SimpleML
        :param kwargs: constructor kwargs; `dataset`, `dataset_id` or
            `dataset_kwargs` (if present) are also used to resolve the
            dataset dependency
        Returns: the new pipeline after it is fit and saved
        """
        dataset = cls.retrieve_dataset(**kwargs)

        pipeline_cls = cls.retrieve_from_registry(registered_name)
        new_pipeline = pipeline_cls(**kwargs)
        new_pipeline.add_dataset(dataset)
        new_pipeline.fit()
        new_pipeline.save()

        return new_pipeline
class ModelCreator(PersistableCreator):
    """Creator for model persistables, looked up through `ORMModel`."""

    @classmethod
    def determine_filters(
        cls, strict: bool = False, **kwargs
    ) -> Tuple[Type[ORMModel], Dict[str, Any]]:
        """
        stateless method to determine which filters to apply when looking for
        existing persistable

        Lookup precedence: `id`, then `name` + `version`, then
        `registered_name` together with a pipeline reference.

        :param registered_name: Class name registered in SimpleML
        :param strict: whether to fit objects first before assuming they are identical
            In theory if all inputs and classes are the same, the outputs should deterministically
            be the same as well (up to random iter). So, you don't need to fit objects
            to be sure they are the same
        :param kwargs: lookup values; in the `registered_name` case they are
            also passed to the model constructor
        Returns: database class, filter dictionary
        Raises TrainingError if none of the supported key combinations is passed
        """
        has_id = "id" in kwargs
        has_name_and_version = "name" in kwargs and "version" in kwargs
        has_pipeline_reference = "pipeline" in kwargs or "pipeline_kwargs" in kwargs
        has_registered_dependency = (
            "registered_name" in kwargs and has_pipeline_reference
        )

        if not (has_id or has_name_and_version or has_registered_dependency):
            raise TrainingError(
                "Need to pass at least one of: `id`, `name, version`, `registered_name, pipeline`, `registered_name, pipeline_kwargs` to compare against existing persistables"
            )

        filters: Dict[str, Any]
        if has_id:
            filters = {"id": kwargs["id"]}

        elif has_name_and_version:
            filters = {
                "name": kwargs["name"],
                "version": kwargs["version"],
            }

        else:
            pipeline = cls.retrieve_pipeline(**kwargs)

            # Build dummy object to retrieve hash to look for
            registered_name = kwargs["registered_name"]
            new_model = cls.retrieve_from_registry(registered_name)(**kwargs)
            new_model.add_pipeline(pipeline)
            if strict:
                # Fit before hashing so the hash is taken on the fitted object.
                new_model.fit()

            filters = {
                "name": new_model.name,
                "registered_name": new_model.registered_name,
                "hash_": new_model._hash(),
            }

        return ORMModel, filters

    @classmethod
    def create(cls, registered_name: str, **kwargs) -> Model:
        """
        Stateless method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable

        :param registered_name: Class name registered in SimpleML
        :param kwargs: constructor kwargs; `pipeline`, `pipeline_id` or
            `pipeline_kwargs` (if present) are also used to resolve the
            pipeline dependency
        Returns: the new model after it is fit and saved
        """
        pipeline = cls.retrieve_pipeline(**kwargs)

        model_cls = cls.retrieve_from_registry(registered_name)
        new_model = model_cls(**kwargs)
        new_model.add_pipeline(pipeline)
        new_model.fit()
        new_model.save()

        return new_model
class MetricCreator(PersistableCreator):
    """Creator for metric persistables, looked up through `ORMMetric`."""

    @classmethod
    def determine_filters(
        cls, strict: bool = False, **kwargs
    ) -> Tuple[Type[ORMMetric], Dict[str, Any]]:
        """
        stateless method to determine which filters to apply when looking for
        existing persistable

        Lookup precedence: `id`, then `name` + `version` + a resolved model,
        then `registered_name`. The model and dataset references are resolved
        before the precedence check, so a dangling `model_id`/`dataset_id`
        raises even when `id` is supplied.

        :param registered_name: Class name registered in SimpleML
        :param strict: whether to fit objects first before assuming they are identical
            In theory if all inputs and classes are the same, the outputs should deterministically
            be the same as well (up to random iter). So, you don't need to fit objects
            to be sure they are the same
        :param kwargs: lookup values; in the `registered_name` case they are
            also passed to the metric constructor
        Returns: database class, filter dictionary
        Raises TrainingError if none of the supported key combinations is
            passed, or if the name/version lookup is impossible (model resolved
            to None) and no `registered_name` is available to fall back on
        """
        missing_inputs_message = "Need to pass at least one of: `id`, `name, version, model`, `name, version, model_id`, `registered_name` to compare against existing persistables"

        has_id = "id" in kwargs
        has_name_and_version = "name" in kwargs and "version" in kwargs
        has_model_reference = "model_id" in kwargs or "model" in kwargs
        has_registered_name = "registered_name" in kwargs

        if not (
            has_id
            or (has_name_and_version and has_model_reference)
            or has_registered_name
        ):
            raise TrainingError(missing_inputs_message)

        model = cls.retrieve_model(**kwargs)
        dataset = cls.retrieve_dataset(**kwargs)

        filters: Dict[str, Any]
        if has_id:
            filters = {"id": kwargs["id"]}

        elif has_name_and_version and model is not None:
            filters = {
                "name": kwargs["name"],
                "version": kwargs["version"],
                "model_id": model.id,
            }

        else:
            # Reachable without `registered_name` when the `model` key was
            # passed explicitly as None; fail with the documented error
            # instead of a bare KeyError.
            if not has_registered_name:
                raise TrainingError(missing_inputs_message)

            # Build dummy object to retrieve hash to look for
            registered_name = kwargs["registered_name"]
            new_metric = cls.retrieve_from_registry(registered_name)(**kwargs)
            new_metric.add_model(model)
            new_metric.add_dataset(dataset)
            if strict:
                # Score before hashing so the hash is taken on the scored object.
                new_metric.score()

            filters = {
                "name": new_metric.name,
                "registered_name": registered_name,
                "hash_": new_metric._hash(),
            }

        return ORMMetric, filters

    @classmethod
    def create(cls, registered_name: str, **kwargs) -> Metric:
        """
        Stateless method to create a new persistable with the desired parameters
        kwargs are passed directly to persistable

        :param registered_name: Class name registered in SimpleML
        :param kwargs: constructor kwargs; `model`/`model_id`/`model_kwargs`
            and `dataset`/`dataset_id`/`dataset_kwargs` (if present) are also
            used to resolve the model and dataset dependencies
        Returns: the new metric after it is scored and saved
        """
        model = cls.retrieve_model(**kwargs)
        dataset = cls.retrieve_dataset(**kwargs)

        metric_cls = cls.retrieve_from_registry(registered_name)
        new_metric = metric_cls(**kwargs)
        new_metric.add_model(model)
        # `dataset` may be None when no dataset reference was supplied.
        new_metric.add_dataset(dataset)
        new_metric.score()
        new_metric.save()

        return new_metric