Tensorflow 保留还原模型编程概论 gaunthan Posted on Jan 26 2018 ? Tensorflow ? > Tensorflow 提供了非常简单的 API 来保存和还原一个神经网络模型。持久化网络模型不仅降低/避免了程序异常退出时训练进度的损失,更使得我们可以复用已训练过的模型。 ## 概述 在学习深度学习时,我们的实验代码往往是“训练-输出结果-退出”这样的流程,并没有把训练得到的结果保存下来。在解决实际问题时,这种做法是十分危险的。Tensorflow 提供了一个 API 来持久化神经网络模型,它就是 tf.train.Saver 类。 ## 保存模型 下面的代码实现了持久化一个简单的 Tensorflow 模型的功能: ```py import tensorflow as tf a = tf.Variable(tf.constant(1.0, shape=[1]), name="a") b = tf.Variable(tf.constant(2.0, shape=[1]), name="b") result = a + b init_op = tf.global_variables_initializer() saver = tf.train.Saver() # 声明 tf.train.Saver 用于保存模型 with tf.Session() as sess: sess.run(init_op) saver.save(sess, "./checkpoint/model.ckpt") ``` 在这段代码中,通过 `saver.save()` 函数将 Tensorflow 模型保存到了 ./checkpoint/model.ckpt 文件中。执行该操作后,./checkpoint/ 目录下会多出几个文件: 文件名|作用 --|-- model.ckpt.meta|保存 Tensorflow 计算图的结构 model.ckpt.index| model.ckpt.data*|保存了 Tensorflow 程序中每一个变量的取值 checkpoint|保存了一个目录下所有的模型文件的清单 ## 恢复模型 下面这段代码加载了先前保存起来的网络模型: ```py import tensorflow as tf a = tf.Variable(tf.constant(1.0, shape=[1]), name="a") b = tf.Variable(tf.constant(2.0, shape=[1]), name="b") result = a + b saver = tf.train.Saver() # 声明 tf.train.Saver 用于保存模型 with tf.Session() as sess: saver.restore(sess, "./checkpoint/model.ckpt") print(sess.run(result)) ``` 注意到它和保存模型的代码几乎是一样的,但是其中少了初始化操作。 ## 直接加载持久化的图 如果不希望重新定义图上的运算,可以直接加载已经持久化的图: ```py import tensorflow as tf saver = tf.train.import_meta_graph("./checkpoint/model.ckpt/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "./checkpoint/model.ckpt") # 通过张量的名称来获取计算结果 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) ``` ## 部分保存/加载 先前给出的代码,默认保存和加载了计算图上定义的全部变量。然而,有时候只需要保存或加载一部分变量。可以通过在声明 tf.train.Saver 类时提供一个列表,指定需要保存或加载的变量。比如在加载模型的代码中使用 `saver = tf.train.Saver([a])` 命令来构建 Saver 类,那么只有变量 a 会被加载进来。 ## 变量重命名 有时候保存模型和加载模型的代码是由不同的人编写的,这就使得其中的变量命名很可能不同。例如,在上面的代码中我们保存了变量 a 和 b,但加载时可能编写了 `v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")` 这样的代码。此时直接使用 `tf.train.Saver()` 来加载模型会报变量找不到的错误。 为了解决这个问题,可以在构建 tf.tran.Saver 类时,向它传递一个字典,这个字典指定了保存时的变量名和需要加载的变量之间的对应关系: ```py saver = tf.train.Saver({"a": v1, "b": v2}) ``` 赏 Wechat Pay Alipay 研究生生存手册 Linux 安装 Anaconda 后引起的包问题