I am Charmie

メモとログ

Keras: a callback function to save loss on each batch

Inspired by a Stack Overflow answer and keras.callbacks.CSVLogger, I implemented a callback that saves the loss (together with every other value Keras puts in the batch logs) to a CSV file after each batch.

[code lang='python'] class LossHistory(Callback):

def __init__(self, filename, separator=',', append=False):
    self.sep = separator
    self.filename = filename
    self.append = append
    self.writer = None
    self.keys = None
    self.append_header = True
    self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''

def on_train_begin(self, logs={}):
    if self.append:
        if os.path.exists(self.filename):
            with open(self.filename, 'r' + self.file_flags) as f:
                self.append_header = not bool(len(f.readline()))
        self.csv_file = open(self.filename, 'a' + self.file_flags)
    else:
        self.csv_file = open(self.filename, 'w' + self.file_flags)

def on_batch_end(self, batch, logs={}):
    logs = logs or {}

    def handle_value(k):
        is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
        if isinstance(k, six.string_types):
            return k
        elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
            return '"[%s]"' % (', '.join(map(str, k)))
        else:
            return k

    if self.keys is None:
        self.keys = sorted(logs.keys())

    if self.model.stop_training:
        # We set NA so that csv parsers do not fail for this last epoch.
        logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

    if not self.writer:
        class CustomDialect(csv.excel):
            delimiter = self.sep

        self.writer = csv.DictWriter(self.csv_file,
                                     fieldnames=['batch'] + self.keys, dialect=CustomDialect)
        if self.append_header:
            self.writer.writeheader()

    row_dict = OrderedDict({'batch': batch})
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
    self.writer.writerow(row_dict)
    self.csv_file.flush()

def on_train_end(self, logs=None):
    self.csv_file.close()
    self.writer = None

file_history = 'batch.log' cb_history = LossHistory(file_history)

model.fit(..., callbacks=[cb_history]) [/code]