训练器

tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)

用于处理任何 TensorFlow 图训练的通用类。它需要使用 TrainOp 来指定所有优化参数。

参数

  • train_ops: TrainOp 列表。用于执行优化的网络训练操作列表。
  • graph: tf.Graph。要使用的 TensorFlow 图。默认值:默认 tf 图。
  • clip_gradients: float。梯度裁剪。默认值:5.0。
  • tensorboard_dir: str。Tensorboard 日志目录。默认值:"/tmp/tflearn_logs/".
  • tensorboard_verbose: int。详细级别。它支持
0 - Loss, Accuracy. (Best Speed)
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.(Best Visualization)
  • checkpoint_path: str。存储模型检查点的路径。如果为 None,则不会保存模型检查点。默认值:None。
  • best_checkpoint_path: str。当验证率达到当前训练阶段的最高点并且高于 best_val_accuracy 时,存储模型的路径。默认值:None。
  • max_checkpoints: int 或 None。最大检查点数。如果为 None,则没有限制。默认值:None。
  • keep_checkpoint_every_n_hours: float。每个模型检查点之间的小时数。
  • random_seed: int。随机种子,用于测试可重复性。默认值:None。
  • session: Session。用于运行操作的会话。如果为 None,则将创建一个新的会话。注意:提供会话时,必须已经初始化变量,否则将引发错误。
  • best_val_accuracy: float 在将模型权重保存到 best_checkpoint_path 之前需要达到的最小验证精度。这允许用户跳过早期保存,并在继续训练重新加载的模型时设置最小保存点。默认值:0.0。

方法

fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])

使用馈送的数据字典训练网络。

示例
# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y},val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y},val_feed_dicts=0.1) # 10% of data used for validation

# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}],val_feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}])
参数
  • feed_dicts: dictdict 列表。用于将数据馈送到网络的字典。它遵循 Tensorflow 馈送字典规范:'{placeholder: data}'。如果是多个优化器,则需要一个字典列表,分别馈送到优化器。
  • n_epoch: int。要运行的时期数。
  • val_feed_dicts: dictdict 列表、floatfloat 列表。用于验证的数据。馈送字典遵循与上述 feed_dicts 相同的规范。也可以提供一个 float 来拆分训练数据以进行验证(请注意,这将打乱数据)。
  • show_metric: bool。如果为 True,则将在每个步骤计算并显示准确率。可能会导致训练速度变慢。
  • snapshot_step: int。如果不为 None,则网络将在每个提供的步骤进行快照(计算验证损失/准确率并保存模型,如果在 Trainer 中指定了 checkpoint_path)。
  • snapshot_epoch: bool。如果为 True,则在每个时期结束时对网络进行快照。
  • shuffle_all: bool。如果为 True,则打乱所有数据批次(覆盖 TrainOp shuffle 参数行为)。
  • dprep_dict: 以 Placeholder 为键、以 DataPreprocessing 为值的 dict。对给定的占位符应用实时数据预处理(在训练和测试时应用)。
  • daug_dict: 以 Placeholder 为键、以 DataAugmentation 为值的 dict。对给定的占位符应用实时数据增强(仅在训练时应用)。
  • excl_trainops: TrainOp 列表。要从训练过程中排除的训练操作列表。
  • run_id: str。当前运行的名称。用于 Tensorboard 显示。如果没有提供名称,则将生成一个随机名称。
  • callbacks: Callbacklist。要在训练生命周期中使用的自定义回调

fit_batch (feed_dicts, dprep_dict=None, daug_dict=None)

使用单个批次训练网络。

参数
  • feed_dicts: dictdict 列表。用于将数据馈送到网络的字典。它遵循 Tensorflow 馈送字典规范:'{placeholder: data}'。如果是多个优化器,则需要一个字典列表,分别馈送到优化器。
  • dprep_dict: 以 Placeholder 为键、以 DataPreprocessing 为值的 dict。对给定的占位符应用实时数据预处理(在训练和测试时应用)。
  • daug_dict: 以 Placeholder 为键、以 DataAugmentation 为值的 dict。对给定的占位符应用实时数据增强(仅在训练时应用)。

restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)

恢复 Tensorflow 模型

参数
  • model_file: 要恢复的 tensorflow 模型的路径
  • trainable_variable_only: 如果为 True,则仅恢复可训练变量。
  • variable_name_map: - 一个 (pattern, repl) 元组,提供一个正则表达式模式和替换,在从模型文件恢复之前应用于变量名称 -- 或者,一个函数 map_func,用于执行映射,调用方式为:name_in_file = map_func(existing_var_op_name) 该函数可以返回 None 以指示不恢复变量。
  • scope_for_restore: 字符串,指定要限制的范围,在恢复变量时使用。- 还从 var 名称中删除范围名称前缀,以便在恢复时使用。
  • create_new_session: 如果要保留当前会话,则设置为 False。设置为 True(默认值)以创建新会话并重新初始化所有变量。
  • verbose : 设置为 True 以查看在使用 scope_for_restore 或 variable_name_map 时正在恢复的变量的打印输出。

save (model_file, global_step=None)

保存 Tensorflow 模型

参数
  • model_file: str。tensorflow 模型的保存路径
  • global_step: int。要附加到模型文件名的训练步骤(可选)。

训练操作

tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)

TrainOp 表示一组用于优化网络的操作。

TrainOp 旨在保存优化器的所有训练参数。然后,Trainer 类将实例化所有这些参数,并特别考虑网络的所有优化器(设置名称、范围...设置优化操作...)。

参数

  • loss: Tensor。用于评估网络成本的损失操作。优化器将使用此成本函数来训练网络。
  • optimizer: Optimizer。Tensorflow 优化器。用于训练网络的优化器。
  • metric: Tensor。用于评估的度量张量。
  • batch_size: int。馈送到此优化器的数据的批次大小。默认值:64。
  • ema: float。指数移动平均。
  • trainable_vars: tf.Variable 列表。用于训练的可训练变量列表。默认值:所有可训练变量。
  • shuffle: bool。打乱数据。
  • step_tensor: tf.Tensor。保存训练步骤的变量。如果没有提供,则将创建它。尽早定义步骤张量可能对网络创建很有用,例如学习率衰减。
  • validation_monitors: Tensor 对象列表。要在验证期间计算的变量列表,这些变量也用于生成摘要以输出到 TensorBoard。例如,这可以用于在训练期间定期记录混淆矩阵或 AUC 度量。每个变量的秩应为 1,即形状为 [None]。
  • validation_batch_size: int 或 None。如果为 int,则指定用于验证数据馈送的批次大小;否则默认为与 batch_size 相同。
  • name: str。此类的名称(可选)。
  • graph: tf.Graph。用于训练的 Tensorflow 图。默认值:默认 tf 图。

方法

initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)

初始化用于馈送训练过程的数据。它旨在在开始拟合数据之前由 Trainer 使用。

参数
  • feed_dict: dict。要馈送的数据字典。
  • val_feed_dict: dictfloat。要馈送的验证数据字典或验证拆分。
  • dprep_dict: dict。数据预处理字典(以占位符为键,以相应的 DataPreprocessing 对象为值)。
  • daug_dict: dict。数据增强字典(以占位符为键,以相应的 DataAugmentation 对象为值)。
  • show_metric: bool。如果为 True,则在每个步骤显示准确率。
  • summ_writer: SummaryWriter。用于 Tensorboard 日志记录的摘要写入器。

initialize_training_ops (i, session, tensorboard_verbose, clip_gradients)

初始化用于训练的所有操作。因为一个网络可以有多个优化器,所以分配了一个 id 'i' 来区分它们。这旨在由 Trainer 在初始化所有训练操作时使用。

参数
  • i: int。此优化器训练过程 ID。
  • session: tf.Session。用于训练网络的会话。
  • tensorboard_verbose: int。日志详细级别。支持
0 - Loss, Accuracy.
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity..
  • clip_gradients: float。用于裁剪梯度的选项。