Source code for simpleml.save_patterns.database

'''
Module for save patterns registered for database persistence
'''

[docs]__author__ = 'Elisha Yadgaran'
import pandas as pd from typing import Any, Dict from simpleml.persistables.base_sqlalchemy import DatasetStorageSqlalchemy from simpleml.save_patterns.decorators import SavePatternDecorators from simpleml.save_patterns.base import BaseSavePattern from simpleml.utils.binary_blob import BinaryBlob
[docs]@SavePatternDecorators.register_save_pattern class DatabaseTableSavePattern(BaseSavePattern): ''' Save pattern implementation to save dataframes to a database table '''
[docs] SAVE_PATTERN = 'database_table'
@classmethod
[docs] def save(cls, obj: pd.DataFrame, persistable_id: str, schema: str = DatasetStorageSqlalchemy.SCHEMA, **kwargs) -> Dict[str, str]: ''' Save method to save dataframe into a new table with name = GUID Updates filepath for the artifact with the schema and table ''' engine = DatasetStorageSqlalchemy.metadata.bind cls.df_to_sql(engine, df=obj, table=persistable_id, schema=schema) return {'schema': schema, 'table': persistable_id}
@classmethod
[docs] def load(cls, filepath_data: Dict[str, str], **kwargs) -> pd.DataFrame: ''' Load method to load dataframe from database ''' schema = filepath_data['schema'] table = filepath_data['table'] engine = DatasetStorageSqlalchemy.metadata.bind df = cls.load_sql( 'select * from "{}"."{}"'.format(schema, table), engine ) return df
[docs]@SavePatternDecorators.register_save_pattern class DatabasePickleSavePattern(BaseSavePattern): ''' Save pattern implementation to save binary objects to a database table '''
[docs] SAVE_PATTERN = 'database_pickled'
@classmethod
[docs] def save(cls, obj: Any, persistable_type: str, persistable_id: str, **kwargs) -> str: ''' Save method to save files into binary schema Hardcoded to only store pickled objects in database so overwrite to use other storage mechanism ''' pickled_stream = cls.pickle_object(obj, as_stream=True) pickled_record = BinaryBlob.create( object_type=persistable_type, object_id=persistable_id, binary_blob=pickled_stream) return str(pickled_record.id)
@classmethod
[docs] def load(cls, primary_key: str, **kwargs) -> Any: ''' Load method to load files from database Hardcoded to only pull from pickled so overwrite to use other storage mechanism ''' pickled_stream = BinaryBlob.find(primary_key).binary_blob return cls.load_pickled_object(pickled_stream, stream=True)