Source code for deepctr_torch.callbacks

import torch
from tensorflow.python.keras.callbacks import EarlyStopping
from tensorflow.python.keras.callbacks import ModelCheckpoint
from tensorflow.python.keras.callbacks import History

# Rebind the imported Keras callbacks to the same names at module level.
# These self-assignments do not change the objects; presumably they exist to
# mark the names as part of this module's public surface (re-exports).
EarlyStopping = EarlyStopping
History = History

class ModelCheckpoint(ModelCheckpoint):
    """Save the model after every epoch using ``torch.save``.

    `filepath` can contain named formatting options, which will be filled
    with the value of `epoch` and the keys in `logs` (passed in
    `on_epoch_end`).

    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and the
    validation loss in the filename.

    Arguments:
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, the latest best model
            according to the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on
            either the maximization or the minimization of the monitored
            quantity. For `val_acc`, this should be `max`, for `val_loss`
            this should be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`torch.save(model.state_dict(), filepath)`), else the
            full model is saved (`torch.save(model, filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def _save_torch_checkpoint(self, filepath):
        """Write ``self.model`` (or only its ``state_dict``) to `filepath`."""
        # Deliberately not named `_save_model`, to avoid shadowing a method
        # of the same name that newer Keras ModelCheckpoint classes define.
        if self.save_weights_only:
            torch.save(self.model.state_dict(), filepath)
        else:
            torch.save(self.model, filepath)

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint the model at the end of an epoch when one is due.

        Arguments:
            epoch: zero-based epoch index; `epoch + 1` is used in the
                formatted file name and in printed messages.
            logs: optional dict of metric values for this epoch. Its keys
                are available as formatting fields in `filepath`, and
                `logs[self.monitor]` is compared against `self.best` when
                `save_best_only` is set.

        Raises:
            KeyError: if `filepath` references a formatting field that is
                not `epoch` and not a key of `logs` (raised by
                `str.format`).
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            # Not enough epochs have elapsed since the last checkpoint.
            return

        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            # Unconditional checkpoint every `period` epochs.
            if self.verbose > 0:
                print('Epoch %05d: saving model to %s' %
                      (epoch + 1, filepath))
            self._save_torch_checkpoint(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            # The monitored metric was not reported; nothing to compare.
            print('Can save best model only with %s available, skipping.'
                  % self.monitor)
            return

        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s' % (epoch + 1, self.monitor,
                                               self.best, current, filepath))
            # Record the new best before writing, as the original did.
            self.best = current
            self._save_torch_checkpoint(filepath)
        elif self.verbose > 0:
            print('Epoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))