Source code for simpleml.metrics.base_metric

__author__ = "Elisha Yadgaran"
import logging
import uuid
import weakref
from typing import Any, Optional, Union

from simpleml.datasets.base_dataset import Dataset
from simpleml.models.base_model import Model
from simpleml.persistables.base_persistable import Persistable
from simpleml.registries import MetricRegistry
from simpleml.utils.errors import MetricError
LOGGER = logging.getLogger(__name__)
class Metric(Persistable, metaclass=MetricRegistry):
    """
    Base class for all Metric objects
    """
    object_type: str = "METRIC"
    def __init__(
        self,
        dataset_id: Optional[Union[str, uuid.UUID]] = None,
        model_id: Optional[Union[str, uuid.UUID]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # initialize null references
        self.dataset_id = dataset_id
        self.model_id = model_id
    def add_dataset(self, dataset: Dataset) -> None:
        """
        Setter method for dataset used
        """
        if dataset is None:
            return
        self.dataset_id = dataset.id
        self.dataset = dataset
    @property
    def dataset(self):
        """
        Use a weakref to bind the linked dataset so it doesn't bloat memory usage.
        Returns the dataset if still available, or tries to fetch it otherwise.
        """
        # still referenced weakref
        if hasattr(self, "_dataset") and self._dataset() is not None:
            return self._dataset()

        # null return if no associated dataset (governed by dataset_id)
        elif self.dataset_id is None:
            return None

        # else regenerate weakref
        LOGGER.info("No referenced object available. Refreshing weakref")
        dataset = self._load_dataset()
        self._dataset = weakref.ref(dataset)
        return dataset
    @dataset.setter
    def dataset(self, dataset: Dataset) -> None:
        """
        Need to be careful not to set this as the ORM object.
        Cannot load if the wrong type because of recursive behavior
        (will propagate down the whole dependency chain)
        """
        self._dataset = weakref.ref(dataset)
    def _load_dataset(self):
        """
        Helper to fetch the dataset
        """
        return self.orm_cls.load_dataset(self.dataset_id)
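    # The getter/setter pair above implements a refresh-on-expiry weakref cache:
    # the Metric never holds a strong reference to its Dataset, so the dataset can
    # be garbage collected independently and is re-fetched on demand. A minimal
    # standalone sketch of the same idiom (names here are illustrative, not part
    # of the SimpleML API):
    #
    #     import weakref
    #
    #     class Holder:
    #         def get_linked(self):
    #             if hasattr(self, "_ref") and self._ref() is not None:
    #                 return self._ref()          # referent still alive
    #             obj = expensive_load()          # hypothetical loader
    #             self._ref = weakref.ref(obj)    # cache weakly; GC may reclaim obj
    #             return obj
    #
    # Callers must keep their own strong reference to the returned object,
    # otherwise it can be collected as soon as get_linked() returns.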
    def add_model(self, model: Model) -> None:
        """
        Setter method for model used
        """
        if model is None:
            return
        self.model_id = model.id
        self.model = model
    @property
    def model(self):
        """
        Use a weakref to bind the linked model so it doesn't bloat memory usage.
        Returns the model if still available, or tries to fetch it otherwise.
        """
        # still referenced weakref
        if hasattr(self, "_model") and self._model() is not None:
            return self._model()

        # null return if no associated model (governed by model_id)
        elif self.model_id is None:
            return None

        # else regenerate weakref
        LOGGER.info("No referenced object available. Refreshing weakref")
        model = self._load_model()
        self._model = weakref.ref(model)
        return model
    @model.setter
    def model(self, model: Model) -> None:
        """
        Need to be careful not to set this as the ORM object.
        Cannot load if the wrong type because of recursive behavior
        (will propagate down the whole dependency chain)
        """
        self._model = weakref.ref(model)
    def _load_model(self):
        """
        Helper to fetch the model
        """
        return self.orm_cls.load_model(self.model_id)
    def _hash(self) -> str:
        """
        Hash is the combination of the:
            1) Model
            2) Dataset (optional)
            3) Metric
            4) Config
        """
        model_hash = self.model.hash_ or self.model._hash()
        if self.dataset is not None:
            dataset_hash = self.dataset.hash_ or self.dataset._hash()
        else:
            dataset_hash = None
        metric = self.__class__.__name__
        config = self.config
        return self.custom_hasher((model_hash, dataset_hash, metric, config))
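    # Illustrative composition only (values below are made up): a metric scored on
    # a model with hash "a1b2c3", no explicit dataset, and config {"split": "TRAIN"}
    # would hash the tuple ("a1b2c3", None, "AccuracyMetric", {"split": "TRAIN"}).
    # Two metrics therefore share a hash only when model, dataset, metric class,
    # and config all match.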
    def _get_latest_version(self) -> int:
        """
        Versions should be autoincrementing for each object (constrained over
        friendly name and model). Executes a database lookup and increments.
        """
        return self.orm_cls.get_latest_version(name=self.name, model_id=self.model.id)
    def _get_pipeline_split(self, column: str, split: str, **kwargs) -> Any:
        """
        For the special case where the dataset is the same as the model's dataset,
        the dataset splits can refer to the pipeline-imposed splits, not the
        dataset's inherent splits. Use the pipeline split then.

        ex: RandomSplitPipeline on NoSplitDataset evaluating "in_sample" performance
        """
        return getattr(
            self.model.pipeline.get_dataset_split(split=split, **kwargs), column
        )
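    # Hypothetical call (the column and split names are examples, not fixed
    # API values):
    #     self._get_pipeline_split(column="y", split="TRAIN")
    # returns the "y" attribute of the model's pipeline split, e.g. the training
    # labels as partitioned by a RandomSplitPipeline.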
    def _get_dataset_split(self, **kwargs) -> Any:
        """
        Default accessor for dataset data. REFERS TO RAW DATASETS, not the
        pipeline-superimposed ones. That means datasets that do not define
        explicit splits will have no notion of downstream splits
        (e.g. RandomSplitPipeline)
        """
        return self.dataset.get(**kwargs)
    def save(self, **kwargs) -> None:
        """
        Extend parent function with a few additional save routines
        """
        if self.model is None:
            raise MetricError("Must set model before saving")
        if self.values is None:
            raise MetricError("Must score metric before saving")
        super().save(**kwargs)
    def score(self, **kwargs):
        """
        Abstract method for each metric to define

        Should set self.values
        """
        raise NotImplementedError
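
# Example subclass and lifecycle (a hedged sketch: AccuracyMetric, my_model, and
# my_dataset are hypothetical stand-ins; only Metric and the methods defined
# above are real). Illustrates the expected ordering -- attach a model, score,
# then save:
#
#     class AccuracyMetric(Metric):
#         def score(self, **kwargs):
#             # assumed accessors; column names "X"/"y" and the split name
#             # are illustrative, as is the {"agg": ...} values format
#             X = self._get_dataset_split(column="X", split="TRAIN")
#             y = self._get_dataset_split(column="y", split="TRAIN")
#             self.values = {"agg": float((self.model.predict(X) == y).mean())}
#
#     metric = AccuracyMetric()
#     metric.add_model(my_model)      # required: save() raises MetricError otherwise
#     metric.add_dataset(my_dataset)  # optional: dataset may be None (see _hash)
#     metric.score()                  # must populate self.values before save()
#     metric.save()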