"""
Base Module for Pipelines
"""
__author__ = "Elisha Yadgaran"
import logging
import uuid
import weakref
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import pandas as pd
from simpleml.constants import TRAIN_SPLIT
from simpleml.datasets.dataset_splits import Split, SplitContainer
from simpleml.persistables.base_persistable import Persistable
from simpleml.registries import PipelineRegistry
from simpleml.save_patterns.decorators import ExternalArtifactDecorators
from simpleml.utils.errors import PipelineError
from simpleml.utils.signature_inspection import signature_kwargs_validator
from .projected_splits import IdentityProjectedDatasetSplit, ProjectedDatasetSplit
if TYPE_CHECKING:
# Cyclical import hack for type hints
from simpleml.datasets.base_dataset import Dataset
LOGGER = logging.getLogger(__name__)
@ExternalArtifactDecorators.register_artifact(
artifact_name="pipeline",
save_attribute="external_pipeline",
restore_attribute="_external_file",
)
class Pipeline(Persistable, metaclass=PipelineRegistry):
"""
Abstract base class for all Pipeline objects.
Relies on mixin classes to define the split_dataset method; will raise
an error on use otherwise.
-------
Schema
-------
params: pipeline parameter metadata for easy insight into hyperparameters across trainings
"""
object_type: str = "PIPELINE"
def __init__(
self,
has_external_files: bool = True,
transformers: Optional[List[Any]] = None,
fitted: bool = False,
dataset_id: Optional[Union[str, uuid.UUID]] = None,
**kwargs,
):
# If no save patterns are set, specify a default for disk_pickled
if "save_patterns" not in kwargs:
kwargs["save_patterns"] = {"pipeline": ["disk_pickled"]}
super(Pipeline, self).__init__(has_external_files=has_external_files, **kwargs)
# prep instantiation of pipeline - lazy build
if transformers is None:
transformers: List[Any] = []
self._transformers = transformers
self._external_pipeline_init_kwargs = kwargs
# Initialize fit state -- pass as true to skip fitting transformers
self.fitted = fitted
# initialize null dataset reference
self.dataset_id = dataset_id
"""
Persistable Management
"""
@property
def fitted(self) -> bool:
return self.state.get("fitted")
@fitted.setter
def fitted(self, value: bool) -> None:
self.state["fitted"] = value
@property
def external_pipeline(self) -> Any:
"""
All pipeline objects require some file-based persisted object.
Wrapper around whatever underlying class is desired
(e.g. sklearn or native)
"""
self.load_if_unloaded("pipeline")
# lazy build
if not hasattr(self, "_external_file"):
self._external_file = self._create_external_pipeline(
self._transformers, **self._external_pipeline_init_kwargs
)
# clear temp vars
del self._transformers
del self._external_pipeline_init_kwargs
return self._external_file
def _create_external_pipeline(self, *args, **kwargs):
"""
each subclass should instantiate the respective pipeline library
"""
raise NotImplementedError("Must use a specific pipeline implementation")
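# A minimal override sketch, assuming an sklearn-backed subclass (the mapping of
# ``transformers`` onto sklearn steps is an assumption for illustration):
#
#     from sklearn.pipeline import Pipeline as SklearnPipeline
#
#     def _create_external_pipeline(self, transformers, **kwargs):
#         # transformers: list of (name, transformer) tuples passed at init
#         return SklearnPipeline(steps=transformers)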
def add_dataset(self, dataset: "Dataset") -> None:
"""
Setter method for dataset used
"""
if dataset is None:
return
self.dataset_id = dataset.id
self.dataset = dataset
@property
def dataset(self):
"""
Use a weakref to bind the linked dataset so it doesn't bloat memory usage
Returns the dataset if still available, otherwise tries to fetch it
"""
# still referenced weakref
if hasattr(self, "_dataset") and self._dataset() is not None:
return self._dataset()
# null return if no associated dataset (governed by dataset_id)
elif self.dataset_id is None:
return None
# else regenerate weakref
LOGGER.info("No referenced object available. Refreshing weakref")
dataset = self._load_dataset()
self._dataset = weakref.ref(dataset)
return dataset
@dataset.setter
def dataset(self, dataset: "Dataset") -> None:
"""
Need to be careful not to set this to the ORM object.
Cannot load if the wrong type is passed because of the recursive behavior (it would
propagate down the whole dependency chain)
"""
self._dataset = weakref.ref(dataset)
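# Usage sketch for the dataset linkage (grounded in the accessors above):
#
#     pipeline.add_dataset(dataset)   # stores dataset.id and a weakref to the object
#     pipeline.dataset                # returns the dataset; if the weakref has expired,
#                                     # it is re-fetched via dataset_id and re-bound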
def _load_dataset(self):
"""
Helper to fetch the dataset
"""
return self.orm_cls.load_dataset(self.dataset_id)
def assert_dataset(self, msg: str = "") -> None:
"""
Helper method to raise an error if dataset isn't present
"""
if self.dataset is None:
raise PipelineError(msg)
def assert_fitted(self, msg: str = "") -> None:
"""
Helper method to raise an error if pipeline isn't fit
"""
if not self.fitted:
raise PipelineError(msg)
def _hash(self) -> str:
"""
Hash is the combination of the:
1) Dataset
2) Transformers
3) Transformer Params
4) Pipeline Config
"""
dataset_hash = self.dataset.hash_ or self.dataset._hash()
transformers = self.get_transformers()
transformer_params = self.get_params(params_only=True)
pipeline_config = self.config
return self.custom_hasher(
(dataset_hash, transformers, transformer_params, pipeline_config)
)
def save(self, **kwargs) -> None:
"""
Extend parent function with a few additional save routines
1) save params
2) save transformer metadata
3) save feature names
"""
self.assert_dataset("Must set dataset before saving")
self.assert_fitted("Must fit pipeline before saving")
# log only attributes - can be refreshed on each save (does not take effect on reloading)
self.params = self.get_params(params_only=True, **kwargs)
self.metadata_["transformers"] = self.get_transformers()
self.metadata_["feature_names"] = self.get_feature_names()
# Skip file-based persistence if there are no transformers
if not self.get_transformers():
self.has_external_files = False
super().save(**kwargs)
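# Typical persistence flow (illustrative; assumes a concrete subclass and an
# already-saved dataset):
#
#     pipeline.add_dataset(dataset)
#     pipeline.fit()    # fits the external pipeline on the train split
#     pipeline.save()   # records params, transformer metadata, and feature names
#                       # before delegating to the parent save routine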
def __post_restore__(self) -> None:
"""
Extend main load routine to load relationship class
"""
super().__post_restore__()
# Create dummy pipeline if one wasn't saved
if not self.has_external_files:
self._external_file = self._create_external_pipeline([], **self.params)
"""
Data Accessors
"""
def split_dataset(self) -> None:
"""
Method to create a cached reference to the projected data (can't use the dataset
directly in case of mutation concerns)
Non-split (identity) behavior: returns the full dataset for any split name
"""
default_split = IdentityProjectedDatasetSplit(dataset=self.dataset, split=None)
# use a single reference to avoid duplicating on different key searches
self._dataset_splits = SplitContainer(default_factory=lambda: default_split)
def get_dataset_split(self, split: Optional[str] = None) -> ProjectedDatasetSplit:
"""
Get specific dataset split
Assumes a ProjectedDatasetSplit object (`simpleml.pipelines.projected_splits.ProjectedDatasetSplit`)
is returned. Inherit or implement similar expected attributes to replace it.
Uses the internal `self._dataset_splits` as the split container - assumes
dictionary-like item access
"""
if not hasattr(self, "_dataset_splits") or self._dataset_splits is None:
self.split_dataset()
return self._dataset_splits[split]
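# With the default (non-split) behavior from ``split_dataset`` above, every split
# name resolves to the same full-dataset projection (illustrative):
#
#     pipeline.get_dataset_split("TRAIN")       # identity projection of the dataset
#     pipeline.get_dataset_split("VALIDATION")  # same default projection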
def get_split_names(self) -> List[str]:
if not hasattr(self, "_dataset_splits") or self._dataset_splits is None:
self.split_dataset()
return list(self._dataset_splits.keys())
def X(self, split: Optional[str] = None) -> Any:
"""
Get X for specific dataset split
"""
return self.get_dataset_split(split=split).X
def y(self, split: Optional[str] = None) -> Any:
"""
Get labels for specific dataset split
"""
return self.get_dataset_split(split=split).y
def _filter_fit_params(self, split: ProjectedDatasetSplit) -> Dict[str, Any]:
"""
Helper to filter unsupported fit params from dataset splits
"""
return signature_kwargs_validator(self.external_pipeline.fit, **split)
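# Illustrative behavior sketch: only the keys accepted by the wrapped ``fit``
# signature survive the filter (the exact keys shown are assumptions):
#
#     # external_pipeline.fit(X, y) + a split containing {"X": ..., "y": ..., "other": ...}
#     # -> only {"X": ..., "y": ...} is passed through to fit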
def fit(self):
"""
Pass through method to external pipeline
"""
self.assert_dataset("Must set dataset before fitting")
if self.fitted:
LOGGER.warning("Cannot refit pipeline, skipping operation")
return self
# Only use default (train) fold to fit
# No constraint on split -- can be a dataframe, ndarray, or generator
# but must be encased in a Split object
split = self.get_dataset_split(split=TRAIN_SPLIT)
supported_fit_params = self._filter_fit_params(split)
self.external_pipeline.fit(**supported_fit_params)
self.fitted = True
return self
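# Refit guard sketch (grounded in the check above):
#
#     pipeline.fit()   # fits on the TRAIN_SPLIT projection and marks fitted=True
#     pipeline.fit()   # logs "Cannot refit pipeline, skipping operation" and returns self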
"""
Pass-through methods to external pipeline
"""
def get_params(self, **kwargs):
"""
Pass through method to external pipeline
"""
return self.external_pipeline.get_params(**kwargs)
def set_params(self, **params):
"""
Pass through method to external pipeline
"""
return self.external_pipeline.set_params(**params)
def get_feature_names(self) -> List[str]:
"""
Pass through method to external pipeline
Should return a list of the final features generated by this pipeline
"""
initial_features = self.dataset.get_feature_names()
return self.external_pipeline.get_feature_names(feature_names=initial_features)