"""
ORM module for metric objects
"""
[docs]__author__ = "Elisha Yadgaran"
import logging
from sqlalchemy import Column, ForeignKey, Index, UniqueConstraint, func
from sqlalchemy.orm import relationship
from simpleml.orm.dataset import ORMDataset
from simpleml.orm.model import ORMModel
from simpleml.orm.persistable import ORMPersistable
from simpleml.orm.sqlalchemy_types import GUID, MutableJSON
[docs]LOGGER = logging.getLogger(__name__)
[docs]class ORMMetric(ORMPersistable):
"""
Abstract Base class for all Metric objects
-------
Schema
-------
name: the metric name
values: JSON object with key: value pairs for performance on test dataset
(ex: FPR: TPR to create ROC Curve)
Singular value metrics take the form - {'agg': value}
model_id: foreign key to the model that was used to generate predictions
dataset_id:
"""
[docs] __tablename__ = "metrics"
[docs] values = Column(MutableJSON, nullable=False)
# Dependencies are model and dataset
[docs] model_id = Column(GUID, ForeignKey("models.id", name="metrics_model_id_fkey"))
[docs] model = relationship("ORMModel", enable_typechecks=False)
[docs] dataset_id = Column(GUID, ForeignKey("datasets.id", name="metrics_dataset_id_fkey"))
[docs] dataset = relationship("ORMDataset", enable_typechecks=False)
[docs] __table_args__ = (
# Metrics don't have the notion of versions, values should be deterministic
# by class, model, and dataset - name should be the combination of class and dataset
# Still exists to stay consistent with the persistables style of unrestricted duplication
# (otherwise would be impossible to distinguish a duplicated metric -- name and model_id would be the same)
# Unique constraint for versioning
UniqueConstraint(
"name", "model_id", "version", name="metric_name_model_version_unique"
),
# Index for searching through friendly names
Index("metric_name_index", "name"),
)
@classmethod
[docs] def get_latest_version(cls, name: str, model_id: str) -> int:
"""
Versions should be autoincrementing for each object (constrained over
friendly name and model). Executes a database lookup and increments..
"""
last_version = (
cls.query_by(func.max(cls.version))
.filter(cls.name == name, cls.model_id == model_id)
.scalar()
)
if last_version is None:
last_version = 0
return last_version + 1
@classmethod
[docs] def load_dataset(cls, id: str):
return cls.load_reference(ORMDataset, id)
@classmethod
[docs] def load_model(cls, id: str):
return cls.load_reference(ORMModel, id)