翼度科技»论坛 编程开发 python 查看内容

使用tensorflow保存和恢复模型saver.restore

9

主题

9

帖子

27

积分

新手上路

Rank: 1

积分
27
tensorflow保存和恢复模型saver.restore

本文只对一些细节点做补充,大体的步骤就不详述了

保存模型

① 首先我使用的是tensorflow-gpu 1.4.0
② 这个版本生成的ckpt文件是这样的:

其中.meta存放的是网络模型和所有的变量;
.index 和.data一起存放变量数据
-0 -500表示checkpoint点
③ 保存的配置(一定细看代码注释!!!)
  1. import tensorflow as tf
  2. w1 = tf.Variable(变量的初始化, name='w1')
  3. w2 = tf.Variable(变量的初始化, name='w2')
  4. saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)   # 这里是细节部分,可以指定保存的变量,每两小时保存最近的5个模型
  5. sess = tf.Session()
  6. sess.run(tf.global_variables_initializer())
  7. saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False))   # 因为模型没必要多次保存,所以写为False
复制代码
恢复模型(一定细看代码注释!!!)

代码:
  1. import tensorflow as tf
  2. with tf.Session() as sess:   
  3.     saver = tf.train.import_meta_graph(模型路径)  # 模型路径中必须指定到具体的模型下如:xx.ckpt-500.meta,且一般来讲,所有模型都是一样的,如果没有改变模型的条件下。
  4.     # 下面的restore就是在当前的sess下恢复了所有的变量
  5.     saver.restore(sess,数据路径)  # 数据路径也必须指定到具体某个模型的数据,但创建这个路径的方法很多,比如调用最后一个保存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,并且这两个是等效的,如果是xx.ckpt-0.data,就是第一个模型的数据
  6.     print(sess.run('w1:0'))  # 这里的w1必须加上:0
复制代码
tensorflow里的,保存和恢复模型的方式

重点在于,第一个文件用于 训练,保存图meta和训练好的参数data(后缀),在另一个文件中导入这个图和训练好的参数,用于预测或者接着训练。
大大减少了另一个文件里的 重复

第一种情况

产生变量的代码和恢复变量的代码在同一个文件时,可以直接如下调用:
  1. # 建模型
  2. saver = tf.train.Saver()

  3. with tf.Session() as sess:
  4.     # 存模型,注意此处的model是文件名,不是路径
  5.     saver.save(sess, "/tmp/model")

  6. with tf.Session() as sess:
  7.     # 恢复模型
  8.     saver.restore(sess, "/tmp/model")
复制代码
第二种情况

不想在另一个文件中,把产生变量的 一大堆代码重敲一遍,可以直接从保存好的 meta文件和data文件中恢复出来
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time    : 2019/9/9 20:49
  4. # @Author  : ZZL
  5. # @File    : 保存检查点文件,并恢复.py
  6. import tensorflow as tf
  7. # Saving contents and operations.
  8. v1 = tf.placeholder(tf.float32, name="v1")
  9. v2 = tf.placeholder(tf.float32, name="v2")
  10. v3 = tf.multiply(v1, v2)
  11. vx = tf.Variable(10.0, name="vx")
  12. v4 = tf.add(v3, vx, name="v4")
  13. saver = tf.train.Saver([vx])
  14. with tf.Session() as sess:
  15.     with tf.device('/cpu:0'):
  16.         sess.run(tf.global_variables_initializer())
  17.         sess.run(vx.assign(tf.add(vx, vx)))
  18.         result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
  19.         print(result)
  20.         print(saver.save(sess, "./model_ex1"))  # 该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给对“restore()”的调用。
复制代码
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time    : 2019/9/9 20:54
  4. # @Author  : ZZL
  5. # @File    : 恢复文件.py
  6. import  tensorflow as tf

  7. saver = tf.train.import_meta_graph("./model_ex1.meta")
  8. sess = tf.Session()
  9. saver.restore(sess, "./model_ex1")
  10. result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
  11. print(result)
复制代码
先来个空图,loaded_graph,在会话中,导入之前构建好的图的文件 后缀 meta,loader.restore(sess, save_model_path)
在当前的loaded_graph中,导入构建好的图和图上的变量值。
  1. def test_model():

  2.     test_features, test_labels = pickle.load(open('preprocess_test.p', mode='rb'))
  3.     loaded_graph = tf.Graph()  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
  4. #     print( loaded_graph)
  5. #     print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017C9A0C0C50>
  6.     with tf.Session(graph=loaded_graph) as sess:
  7.         # 读取模型
  8.         loader = tf.train.import_meta_graph(save_model_path + '.meta')
  9.         print(loader)
  10.         loader.restore(sess, save_model_path)

  11.         print(tf.get_default_graph())  # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
  12.         # 从已经读入的模型中 获取tensors
  13.         loaded_x = loaded_graph.get_tensor_by_name('x:0')
  14.         loaded_y = loaded_graph.get_tensor_by_name('y:0')
  15.         loaded_keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0')
  16.         loaded_logits = loaded_graph.get_tensor_by_name('logits:0')
  17.         loaded_acc = loaded_graph.get_tensor_by_name('accuracy:0')
  18.         
  19.         # 获取每个batch的准确率,再求平均值,这样可以节约内存
  20.         test_batch_acc_total = 0
  21.         test_batch_count = 0
  22.         
  23.         for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size):
  24.             test_batch_acc_total += sess.run(
  25.                 loaded_acc,
  26.                 feed_dict={loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0})
  27.             test_batch_count += 1
复制代码
总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

来源:https://www.jb51.net/python/316248oz1.htm
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x

举报 回复 使用道具