Source code for simpleml.models.base_keras_model

"""
Base module for Keras models. Keras has a native persistence mechanism, so
the persistence-related methods need to be overwritten at the root
"""

__author__ = "Elisha Yadgaran"

import logging
from abc import abstractmethod

from simpleml.constants import TRAIN_SPLIT, VALIDATION_SPLIT
from simpleml.utils.signature_inspection import signature_kwargs_validator

from .base_model import LibraryModel
from .split_iterators import DatasetSequence, PythonIterator
LOGGER = logging.getLogger(__name__)
class KerasModel(LibraryModel):
    """
    Base Keras model class. Keras objects are incrementally structured until
    fit and also don't have separable params, so this class hijacks `params`
    to store fit params instead (enables full specification on init for
    reproducibility)
    """

    def __init__(
        self,
        use_training_generator=False,
        training_generator_params=None,
        use_validation_generator=False,
        validation_generator_params=None,
        use_sequence_object=False,
        **kwargs,
    ):
        """
        Pass default save method as Keras's persistence pattern

        :param use_training_generator: Whether to use a generator object when
            training -- does not allow for using a generator in production --
            only fit_generator
        :type use_training_generator: Bool
        :param use_validation_generator: Whether to ALSO use a generator for
            validation data while training. Does nothing if
            use_training_generator is False
        :type use_validation_generator: Bool
        :param training_generator_params: parameters to pass to the generator
            method for the train split - normal fit(_generator) params should
            be passed as params={}
        :param validation_generator_params: parameters to pass to the generator
            method for the validation split - normal fit(_generator) params
            should be passed as params={}
        :param use_sequence_object: Whether to wrap the dataset splits in a
            Keras Sequence object (DatasetSequence) instead of a plain Python
            iterator (PythonIterator)
        :type use_sequence_object: Bool
        """
        # Overwrite default model save pattern to keras specific (if not already passed)
        if "save_patterns" not in kwargs:
            LOGGER.info("Setting model save pattern to `disk_keras_hdf5`")
            kwargs["save_patterns"] = {"model": ["disk_keras_hdf5"]}
        elif "model" not in kwargs["save_patterns"]:
            LOGGER.info("Setting model save pattern to `disk_keras_hdf5`")
            kwargs["save_patterns"]["model"] = ["disk_keras_hdf5"]
        super(KerasModel, self).__init__(**kwargs)

        # Keras supports training and validation with generators.
        # Design choice to put this in config as opposed to state because while
        # it is true that a specific combination of generator params will yield
        # the same model artifact as the traditional fit, it is very unlikely and
        # therefore assumed to be different (hashes will not be equal because of
        # the differing param structure)
        if training_generator_params is None:
            training_generator_params = {}
        if validation_generator_params is None:
            validation_generator_params = {}
        generator_params = {
            "use_training_generator": use_training_generator,
            "use_sequence_object": use_sequence_object,
            "training_generator_params": training_generator_params,
            "use_validation_generator": use_validation_generator,
            "validation_generator_params": validation_generator_params,
        }
        self.config.update(generator_params)

    @abstractmethod
    def _create_external_model(self, **kwargs):
        """
        Abstract method for each subclass to implement; should return the
        desired model object. Must return external_file

        Keras pattern is:
            external_model = SomeWrappedKerasClass(**kwargs)
            return self.build_network(external_model)
        """
        external_model = None
        return self.build_network(external_model, **kwargs)
    def build_network(self, external_model, **kwargs):
        """
        Design choice to require a build network method instead of exposing
        raw Keras objects that can be modified later. Simplifies the saving
        and loading pattern because the initialized object should also be the
        final state (as long as a manual override doesn't happen)
        """
        return external_model
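    # --- Usage sketch (illustrative; not part of the original module) ---
    # A hypothetical subclass showing the intended _create_external_model /
    # build_network pattern from the docstrings above. Raw tf.keras classes
    # are used here for brevity as an assumption; the docstring's actual
    # pattern wraps them (SomeWrappedKerasClass):
    #
    #   from tensorflow.keras.layers import Dense
    #   from tensorflow.keras.models import Sequential
    #
    #   class MyKerasModel(KerasModel):
    #       def _create_external_model(self, **kwargs):
    #           external_model = Sequential()
    #           return self.build_network(external_model, **kwargs)
    #
    #       def build_network(self, external_model, **kwargs):
    #           external_model.add(Dense(64, activation="relu", input_shape=(10,)))
    #           external_model.add(Dense(1, activation="sigmoid"))
    #           external_model.compile(optimizer="adam", loss="binary_crossentropy")
    #           return external_model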
    def _fit(self):
        """
        Keras fit parameters (epochs, callbacks...) are stored as self.params
        so retrieve them automatically
        """
        # Keras supports fitting on generator objects, so expose an additional
        # internal method, if the flag is set
        if self.config["use_training_generator"]:
            self._fit_generator()
        else:
            # Explicitly fit only on default (train) split
            split = self.transform(X=None, dataset_split=TRAIN_SPLIT)
            # keras api uses lowercase x
            if "X" in split:
                split["x"] = split.pop("X")
            supported_fit_params = signature_kwargs_validator(
                self.external_model.fit, **split
            )
            self.external_model.fit(**supported_fit_params, **self.get_params())
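    # A conceptual sketch (assumption -- not the actual simpleml
    # implementation) of the kind of kwargs filtering signature_kwargs_validator
    # performs above: drop any split keys the fit callable does not accept,
    # using only the standard library:
    #
    #   import inspect
    #
    #   def filter_supported_kwargs(func, **kwargs):
    #       params = inspect.signature(func).parameters
    #       if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
    #           return kwargs  # func accepts **kwargs; everything is supported
    #       return {k: v for k, v in kwargs.items() if k in params}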
    def _fit_generator(self):
        """
        Keras fit parameters (epochs, callbacks...) are stored as self.params
        so retrieve them automatically
        """
        use_keras_sequence = self.config.get("use_sequence_object", False)
        if use_keras_sequence:
            iterator_cls = DatasetSequence
        else:
            iterator_cls = PythonIterator

        # Explicitly fit only on default (train) split
        transformed_training_data = self.transform(X=None, dataset_split=TRAIN_SPLIT)
        training_generator_params = self.config.get(
            "training_generator_params", {}
        ).copy()
        # force tuple return for compatibility
        training_generator_params["return_tuple"] = True
        training_generator = iterator_cls(
            transformed_training_data, **training_generator_params
        )

        if self.config["use_validation_generator"]:
            transformed_validation_data = self.transform(
                X=None, dataset_split=VALIDATION_SPLIT
            )
            validation_generator_params = self.config.get(
                "validation_generator_params", {}
            ).copy()
            # force tuple return for compatibility
            validation_generator_params["return_tuple"] = True
            validation_generator = iterator_cls(
                transformed_validation_data, **validation_generator_params
            )
        else:
            validation_generator = None

        self.external_model.fit_generator(
            training_generator,
            validation_data=validation_generator,
            **self.get_params(),
        )
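    # Usage sketch (illustrative): configuring generator-based training at
    # init. `MyKerasModel` is the hypothetical subclass from the sketch above,
    # and the batch_size generator param is an assumption about the iterator
    # classes' signatures, not confirmed API:
    #
    #   model = MyKerasModel(
    #       use_training_generator=True,
    #       use_validation_generator=True,
    #       use_sequence_object=True,  # DatasetSequence instead of PythonIterator
    #       training_generator_params={"batch_size": 32},
    #       validation_generator_params={"batch_size": 32},
    #   )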
    def set_params(self, **kwargs):
        """
        Keras networks don't have params beyond layers, which should be
        configured in `self.build_network`, so use this for fit params -
        self.fit will auto pull params and pass them to the fit method

        TODO: Figure out whether changing params should be allowed after fit.
        If it is, the model would need to be reinitialized, otherwise it would
        train for more epochs without forgetting the original training. If not,
        once fit, the model can be treated as static and no longer changeable

        For now going with option 2 - cannot refit models
        """
        if self.fitted:
            LOGGER.warning(
                "Cannot change fit params and retrain model, skipping operation"
            )
            LOGGER.warning("Initialize a new model for new fit params")
            return None

        self.params = kwargs
    def get_params(self, **kwargs):
        """
        Get fit params
        """
        # keras params are fit params which only exist if passed. cannot
        # inspect from the model
        if hasattr(self, "params"):
            return self.params
        return {}
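    # Usage sketch (illustrative): because params are hijacked for fit kwargs,
    # epochs/callbacks/etc. are set through set_params and forwarded to
    # external_model.fit by _fit. `MyKerasModel` is the hypothetical subclass
    # from the sketch above:
    #
    #   model = MyKerasModel()
    #   model.set_params(epochs=10, batch_size=32)
    #   model.get_params()  # -> {"epochs": 10, "batch_size": 32}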
@staticmethod
    def transfer_weights(new_model, old_model):
        """
        Copy weights from old_model into new_model for every layer that
        shares a name; layers unique to new_model keep their initialization
        """
        new_layers = {i.name: i for i in new_model.layers}
        old_layers = {i.name: i for i in old_model.layers}
        for name, layer in new_layers.items():
            if name in old_layers:
                layer.set_weights(old_layers[name].get_weights())
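# Usage sketch (illustrative): weights transfer by matching layer names, so
# shared layers are copied and renamed/new layers keep their fresh
# initialization. Assumes plain tf.keras models, with imports as in the
# subclass sketch above:
#
#   old = Sequential([Dense(64, name="shared"), Dense(1, name="head_v1")])
#   new = Sequential([Dense(64, name="shared"), Dense(2, name="head_v2")])
#   old.build((None, 10))
#   new.build((None, 10))
#   KerasModel.transfer_weights(new_model=new, old_model=old)
#   # "shared" weights now match; "head_v2" remains newly initialized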