训练器
tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)
用于处理任何 TensorFlow 图训练的通用类。它需要使用 TrainOp
来指定所有优化参数。
参数
- train_ops:
TrainOp
列表。用于执行优化的网络训练操作列表。 - graph:
tf.Graph
。要使用的 TensorFlow 图。默认值:默认 tf 图。 - clip_gradients:
float
。梯度裁剪。默认值:5.0。 - tensorboard_dir:
str
。Tensorboard 日志目录。默认值:"/tmp/tflearn_logs/". - tensorboard_verbose:
int
。详细级别。它支持
0 - Loss, Accuracy. (Best Speed)
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.(Best Visualization)
- checkpoint_path:
str
。存储模型检查点的路径。如果为 None,则不会保存模型检查点。默认值:None。 - best_checkpoint_path:
str
。当验证率达到当前训练阶段的最高点并且高于 best_val_accuracy 时,存储模型的路径。默认值:None。 - max_checkpoints:
int
或 None。最大检查点数。如果为 None,则没有限制。默认值:None。 - keep_checkpoint_every_n_hours:
float
。每个模型检查点之间的小时数。 - random_seed:
int
。随机种子,用于测试可重复性。默认值:None。 - session:
Session
。用于运行操作的会话。如果为 None,则将创建一个新的会话。注意:提供会话时,必须已经初始化变量,否则将引发错误。 - best_val_accuracy:
float
在将模型权重保存到 best_checkpoint_path 之前需要达到的最小验证精度。这允许用户跳过早期保存,并在继续训练重新加载的模型时设置最小保存点。默认值:0.0。
方法
fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])
使用馈送的数据字典训练网络。
示例
# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y},val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y},val_feed_dicts=0.1) # 10% of data used for validation
# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}],val_feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}])
参数
- feed_dicts:
dict
或dict
列表。用于将数据馈送到网络的字典。它遵循 Tensorflow 馈送字典规范:'{placeholder: data}'。如果是多个优化器,则需要一个字典列表,分别馈送到优化器。 - n_epoch:
int
。要运行的时期数。 - val_feed_dicts:
dict
、dict
列表、float
或float
列表。用于验证的数据。馈送字典遵循与上述feed_dicts
相同的规范。也可以提供一个float
来拆分训练数据以进行验证(请注意,这将打乱数据)。 - show_metric:
bool
。如果为 True,则将在每个步骤计算并显示准确率。可能会导致训练速度变慢。 - snapshot_step:
int
。如果不为 None,则网络将在每个提供的步骤进行快照(计算验证损失/准确率并保存模型,如果在Trainer
中指定了checkpoint_path
)。 - snapshot_epoch:
bool
。如果为 True,则在每个时期结束时对网络进行快照。 - shuffle_all:
bool
。如果为 True,则打乱所有数据批次(覆盖TrainOp
shuffle 参数行为)。 - dprep_dict: 以
Placeholder
为键、以DataPreprocessing
为值的dict
。对给定的占位符应用实时数据预处理(在训练和测试时应用)。 - daug_dict: 以
Placeholder
为键、以DataAugmentation
为值的dict
。对给定的占位符应用实时数据增强(仅在训练时应用)。 - excl_trainops:
TrainOp
列表。要从训练过程中排除的训练操作列表。 - run_id:
str
。当前运行的名称。用于 Tensorboard 显示。如果没有提供名称,则将生成一个随机名称。 - callbacks:
Callback
或list
。要在训练生命周期中使用的自定义回调
fit_batch (feed_dicts, dprep_dict=None, daug_dict=None)
使用单个批次训练网络。
参数
- feed_dicts:
dict
或dict
列表。用于将数据馈送到网络的字典。它遵循 Tensorflow 馈送字典规范:'{placeholder: data}'。如果是多个优化器,则需要一个字典列表,分别馈送到优化器。 - dprep_dict: 以
Placeholder
为键、以DataPreprocessing
为值的dict
。对给定的占位符应用实时数据预处理(在训练和测试时应用)。 - daug_dict: 以
Placeholder
为键、以DataAugmentation
为值的dict
。对给定的占位符应用实时数据增强(仅在训练时应用)。
restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)
恢复 Tensorflow 模型
参数
- model_file: 要恢复的 tensorflow 模型的路径
- trainable_variable_only: 如果为 True,则仅恢复可训练变量。
- variable_name_map: - 一个 (pattern, repl) 元组,提供一个正则表达式模式和替换,在从模型文件恢复之前应用于变量名称 -- 或者,一个函数 map_func,用于执行映射,调用方式为:name_in_file = map_func(existing_var_op_name) 该函数可以返回 None 以指示不恢复变量。
- scope_for_restore: 字符串,指定要限制的范围,在恢复变量时使用。- 还从 var 名称中删除范围名称前缀,以便在恢复时使用。
- create_new_session: 如果要保留当前会话,则设置为 False。设置为 True(默认值)以创建新会话并重新初始化所有变量。
- verbose : 设置为 True 以查看在使用 scope_for_restore 或 variable_name_map 时正在恢复的变量的打印输出。
save (model_file, global_step=None)
保存 Tensorflow 模型
参数
- model_file:
str
。tensorflow 模型的保存路径 - global_step:
int
。要附加到模型文件名的训练步骤(可选)。
训练操作
tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)
TrainOp 表示一组用于优化网络的操作。
TrainOp 旨在保存优化器的所有训练参数。然后,Trainer
类将实例化所有这些参数,并特别考虑网络的所有优化器(设置名称、范围...设置优化操作...)。
参数
- loss:
Tensor
。用于评估网络成本的损失操作。优化器将使用此成本函数来训练网络。 - optimizer:
Optimizer
。Tensorflow 优化器。用于训练网络的优化器。 - metric:
Tensor
。用于评估的度量张量。 - batch_size:
int
。馈送到此优化器的数据的批次大小。默认值:64。 - ema:
float
。指数移动平均。 - trainable_vars:
tf.Variable
列表。用于训练的可训练变量列表。默认值:所有可训练变量。 - shuffle:
bool
。打乱数据。 - step_tensor:
tf.Tensor
。保存训练步骤的变量。如果没有提供,则将创建它。尽早定义步骤张量可能对网络创建很有用,例如学习率衰减。 - validation_monitors:
Tensor
对象列表。要在验证期间计算的变量列表,这些变量也用于生成摘要以输出到 TensorBoard。例如,这可以用于在训练期间定期记录混淆矩阵或 AUC 度量。每个变量的秩应为 1,即形状为 [None]。 - validation_batch_size:
int
或 None。如果为int
,则指定用于验证数据馈送的批次大小;否则默认为与batch_size
相同。 - name:
str
。此类的名称(可选)。 - graph:
tf.Graph
。用于训练的 Tensorflow 图。默认值:默认 tf 图。
方法
initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)
初始化用于馈送训练过程的数据。它旨在在开始拟合数据之前由 Trainer
使用。
参数
- feed_dict:
dict
。要馈送的数据字典。 - val_feed_dict:
dict
或float
。要馈送的验证数据字典或验证拆分。 - dprep_dict:
dict
。数据预处理字典(以占位符为键,以相应的DataPreprocessing
对象为值)。 - daug_dict:
dict
。数据增强字典(以占位符为键,以相应的DataAugmentation
对象为值)。 - show_metric:
bool
。如果为 True,则在每个步骤显示准确率。 - summ_writer:
SummaryWriter
。用于 Tensorboard 日志记录的摘要写入器。
initialize_training_ops (i, session, tensorboard_verbose, clip_gradients)
初始化用于训练的所有操作。因为一个网络可以有多个优化器,所以分配了一个 id 'i' 来区分它们。这旨在由 Trainer
在初始化所有训练操作时使用。
参数
- i:
int
。此优化器训练过程 ID。 - session:
tf.Session
。用于训练网络的会话。 - tensorboard_verbose:
int
。日志详细级别。支持
0 - Loss, Accuracy.
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity..
- clip_gradients:
float
。用于裁剪梯度的选项。