[TOC]
回调函数使用
回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks
关键字参数)到 Sequential
或 Model
类型的 .fit()
方法。在训练时,相应的回调函数的方法就会在各自的阶段被调用。
Callback
1 | keras.callbacks.Callback() |
用来组建新的回调函数的抽象基类。
BaseLogger
1 | keras.callbacks.BaseLogger() |
会积累训练轮平均评估的回调函数。
这个回调函数被自动应用到每一个 Keras 模型上面。
TerminateOnNaN
1 | keras.callbacks.TerminateOnNaN() |
当遇到 NaN 损失会停止训练的回调函数。
ProgbarLogger
1 | keras.callbacks.ProgbarLogger(count_mode='samples') |
会把评估以标准输出打印的回调函数。
参数
- count_mode: “steps” 或者 “samples”。 进度条是否应该计数看见的样本或步骤(批量)。
触发
- ValueError: 防止不正确的
count_mode
History
1 | keras.callbacks.History() |
把所有事件都记录到 History
对象的回调函数。
这个回调函数被自动启用到每一个 Keras 模型。History
对象会被模型的 fit
方法返回。
ModelCheckpoint
1 | keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1) |
在每个训练期之后保存模型。
filepath
可以包括命名格式选项,可以由 epoch
的值和 logs
的键(由 on_epoch_end
参数传递)来填充。
例如:如果 filepath
是 weights.{epoch:02d}-{val_loss:.2f}.hdf5
, 那么模型被保存的的文件名就会有训练轮数和验证损失。
EarlyStopping
1 | keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto') |
当被监测的数量不再提升,则停止训练。
RemoteMonitor
1 | keras.callbacks.RemoteMonitor(root='http://localhost:9000', path='/publish/epoch/end/', field='data', headers=None) |
将事件数据流到服务器的回调函数。
需要 requests
库。 事件被默认发送到 root + '/publish/epoch/end/'
。 采用 HTTP POST ,其中的 data
参数是以 JSON 编码的事件数据字典。
LearningRateScheduler
1 | keras.callbacks.LearningRateScheduler(schedule, verbose=0) |
学习速率定时器。
参数
- schedule: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数)。
- verbose: 整数。 0:安静,1:更新信息。
TensorBoard
1 | keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=False, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None) |
Tensorboard 基本可视化。
TensorBoard 是由 Tensorflow 提供的一个可视化工具。
这个回调函数为 Tensorboard 编写一个日志, 这样你可以可视化测试和训练的标准评估的动态图像, 也可以可视化模型中不同层的激活值直方图。
如果你已经使用 pip 安装了 Tensorflow,你应该可以从命令行启动 Tensorflow:
1 | tensorboard --logdir=/full_path_to_your_logs |
参数
- log_dir: 用来保存被 TensorBoard 分析的日志文件的文件名。
- histogram_freq: 对于模型中各个层计算激活值和模型权重直方图的频率(训练轮数中)。 如果设置成 0 ,直方图不会被计算。对于直方图可视化的验证数据(或分离数据)一定要明确的指出。
- write_graph: 是否在 TensorBoard 中可视化图像。 如果 write_graph 被设置为 True,日志文件会变得非常大。
- write_grads: 是否在 TensorBoard 中可视化梯度值直方图。
histogram_freq
必须要大于 0 。 - batch_size: 用以直方图计算的传入神经元网络输入批的大小。
- write_images: 是否在 TensorBoard 中将模型权重以图片可视化。
- embeddings_freq: 被选中的嵌入层会被保存的频率(在训练轮中)。
- embeddings_layer_names: 一个列表,会被监测层的名字。 如果是 None 或空列表,那么所有的嵌入层都会被监测。
- embeddings_metadata: 一个字典,对应层的名字到保存有这个嵌入层元数据文件的名字。 查看 详情 关于元数据的数据格式。 以防同样的元数据被用于所用的嵌入层,字符串可以被传入。
ReduceLROnPlateau
1 | keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0) |
当标准评估已经停止时,降低学习速率。
当学习停止时,模型总是会受益于降低 2-10 倍的学习速率。 这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步, 那么学习速率就会被降低。
例
1 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, |
CSVLogger
1 | keras.callbacks.CSVLogger(filename, separator=',', append=False) |
把训练轮结果数据流到 csv 文件的回调函数。
支持所有可以被作为字符串表示的值,包括 1D 可迭代数据,例如,np.ndarray。
例
1 | csv_logger = CSVLogger('training.log') |
参数
- filename: csv 文件的文件名,例如 ‘run/log.csv’。
- separator: 用来隔离 csv 文件中元素的字符串。
- append: True:如果文件存在则增加(可以被用于继续训练)。False:覆盖存在的文件。
LambdaCallback
1 | keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None) |
在训练进行中创建简单,自定义的回调函数。
这个回调函数和匿名函数在合适的时间被创建。 需要注意的是回调函数要求位置型参数,如下:
on_epoch_begin
和on_epoch_end
要求两个位置型的参数:epoch
,logs
on_batch_begin
和on_batch_end
要求两个位置型的参数:batch
,logs
on_train_begin
和on_train_end
要求一个位置型的参数:logs
参数
- on_epoch_begin: 在每轮开始时被调用。
- on_epoch_end: 在每轮结束时被调用。
- on_batch_begin: 在每批开始时被调用。
- on_batch_end: 在每批结束时被调用。
- on_train_begin: 在模型训练开始时被调用。
- on_train_end: 在模型训练结束时被调用。
1 |
|