Source code for simpleml.models.base_keras_model

'''
Base module for Keras models. Keras has a native persistence mechanism, so
the other persistence methods need to be overwritten at the root
'''

__author__ = 'Elisha Yadgaran'


import logging
from abc import abstractmethod

from .base_model import Model


LOGGER = logging.getLogger(__name__)


class KerasModel(Model):
    def __init__(self, save_method='disk_keras_hdf5', **kwargs):
        '''
        Pass default save method as Keras's persistence pattern
        '''
        super(KerasModel, self).__init__(save_method=save_method, **kwargs)

    @abstractmethod
    def _create_external_model(self, **kwargs):
        '''
        Abstract method for each subclass to implement;
        should return the desired model object

        Must return external_model

        Keras pattern is:
            external_model = SomeWrappedKerasClass(**kwargs)
            return self.build_network(external_model)
        '''
        external_model = None
        return self.build_network(external_model, **kwargs)
    def build_network(self, external_model, **kwargs):
        '''
        Design choice to require a build_network method instead of exposing raw
        Keras objects that can be modified later. Simplifies the saving and
        loading pattern because the initialized object should also be the final
        state (as long as no manual override happens)
        '''
        return external_model
    def _fit(self, X, y):
        '''
        Keras fit parameters (epochs, callbacks...) are stored as self.params
        so retrieve them automatically
        '''
        # Reduce dimensionality of y if it is only 1 column
        self.external_model.fit(X, y.squeeze(), **self.get_params())
    def set_params(self, **kwargs):
        '''
        Keras networks don't have params beyond layers, which should be
        configured in `self.build_network`, so use this for fit params -
        self.fit will auto pull params and pass them to the fit method.

        TODO: Figure out if changing params should be allowed after fit.
        If they are, would need to reinitialize the model, otherwise it would
        train more epochs and not forget the original training.
        If not, once fit, we can treat the model as static, and no longer
        able to be changed.

        For now going with option 2 - cannot refit models
        '''
        if self.fitted:
            LOGGER.warning('Cannot change fit params and retrain model, skipping operation')
            LOGGER.warning('Initialize a new model for new fit params')
            return None

        self.params = kwargs
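    # A minimal usage sketch of the fit-param flow (hypothetical call sites;
    # the kwargs are whatever the underlying keras fit call accepts):
    #
    #     model.set_params(epochs=10, batch_size=32)  # stored on self.params
    #     model.fit(...)                              # _fit forwards self.get_params()
    #                                                 # to external_model.fit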
    def get_params(self, **kwargs):
        '''
        Get fit params
        '''
        return self.params
    @staticmethod
    def transfer_weights(new_model, old_model):
        '''
        Copy weights into new_model from any old_model layers that share
        the same name (layers are matched by name, not position)
        '''
        new_layers = {i.name: i for i in new_model.layers}
        old_layers = {i.name: i for i in old_model.layers}

        for name, layer in new_layers.items():
            if name in old_layers:
                layer.set_weights(old_layers[name].get_weights())
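
# --------------------------------------------------------------------------
# A minimal subclass sketch for illustration only (not part of this module).
# It assumes a plain keras Sequential as the external model and hypothetical
# layer sizes; real subclasses may need a wrapped Keras class, per the
# _create_external_model docstring above.

class ExampleDenseKerasModel(KerasModel):
    def _create_external_model(self, **kwargs):
        from tensorflow.keras.models import Sequential
        external_model = Sequential()
        return self.build_network(external_model, **kwargs)

    def build_network(self, external_model, **kwargs):
        from tensorflow.keras.layers import Dense
        # Network is fully assembled here so the initialized object is also
        # the final state, matching the save/load pattern described above
        external_model.add(Dense(64, activation='relu', input_shape=(10,)))
        external_model.add(Dense(1, activation='sigmoid'))
        external_model.compile(optimizer='adam', loss='binary_crossentropy')
        return external_model

# Transferring weights between two such models (layers matched by name),
# passing the underlying keras objects rather than the simpleml wrappers:
#     KerasModel.transfer_weights(new.external_model, old.external_model)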