tensorflow保存和恢复模型saver.restore
本文只对一些细节点做补充,大体的步骤就不详述了
保存模型
① 首先我使用的是tensorflow-gpu 1.4.0
② 这个版本生成的ckpt文件是这样的:
其中.meta存放的是网络模型和所有的变量;
.index 和.data一起存放变量数据
-0 -500表示checkpoint点
③ 保存的配置(一定细看代码注释!!!)
1 2 3 4 5 6 7 | import tensorflow as tf w1 = tf.Variable(变量的初始化, name = 'w1' ) w2 = tf.Variable(变量的初始化, name = 'w2' ) saver = tf.train.Saver([w1,w2],max_to_keep = 5 , keep_checkpoint_every_n_hours = 2 ) # 这里是细节部分,可以指定保存的变量,每两小时保存最近的5个模型 sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, './checkpoint_dir/MyModel' ,global_step = step,write_meta_graph = False )) # 因为模型没必要多次保存,所以写为False |
恢复模型(一定细看代码注释!!!)
代码:
1 2 3 4 5 6 | import tensorflow as tf with tf.Session() as sess: saver = tf.train.import_meta_graph(模型路径) # 模型路径中必须指定到具体的模型下如:xx.ckpt-500.meta,且一般来讲,所有模型都是一样的,如果没有改变模型的条件下。 # 下面的restore就是在当前的sess下恢复了所有的变量 saver.restore(sess,数据路径) # 数据路径也必须指定到具体某个模型的数据,但创建这个路径的方法很多,比如调用最后一个保存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,并且这两个是等效的,如果是xx.ckpt-0.data,就是第一个模型的数据 print (sess.run( 'w1:0' )) # 这里的w1必须加上:0 |
tensorflow里的,保存和恢复模型的方式
重点在于,第一个文件用于 训练,保存图meta和训练好的参数data(后缀),在另一个文件中导入这个图和训练好的参数,用于预测或者接着训练。
大大减少了另一个文件里的 重复
第一种情况
产生变量的代码和恢复变量的代码在同一个文件时,可以直接如下调用:
1 2 3 4 5 6 7 8 9 10 | # 建模型 saver = tf.train.Saver() with tf.Session() as sess: # 存模型,注意此处的model是文件名,不是路径 saver.save(sess, "/tmp/model" ) with tf.Session() as sess: # 恢复模型 saver.restore(sess, "/tmp/model" ) |
第二种情况
不想在另一个文件中,把产生变量的 一大堆代码重敲一遍,可以直接从保存好的 meta文件和data文件中恢复出来
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | #!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/9/9 20:49 # @Author : ZZL # @File : 保存检查点文件,并恢复.py import tensorflow as tf # Saving contents and operations. v1 = tf.placeholder(tf.float32, name = "v1" ) v2 = tf.placeholder(tf.float32, name = "v2" ) v3 = tf.multiply(v1, v2) vx = tf.Variable( 10.0 , name = "vx" ) v4 = tf.add(v3, vx, name = "v4" ) saver = tf.train.Saver([vx]) with tf.Session() as sess: with tf.device( '/cpu:0' ): sess.run(tf.global_variables_initializer()) sess.run(vx.assign(tf.add(vx, vx))) result = sess.run(v4, feed_dict = {v1: 12.0 , v2: 3.3 }) print (result) print (saver.save(sess, "./model_ex1" )) # 该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给对“restore()”的调用。 |
1 2 3 4 5 6 7 8 9 10 11 12 | #!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/9/9 20:54 # @Author : ZZL # @File : 恢复文件.py import tensorflow as tf saver = tf.train.import_meta_graph( "./model_ex1.meta" ) sess = tf.Session() saver.restore(sess, "./model_ex1" ) result = sess.run( "v4:0" , feed_dict = { "v1:0" : 12.0 , "v2:0" : 3.3 }) print (result) |
先来个空图,loaded_graph,在会话中,导入之前构建好的图的文件 后缀 meta,loader.restore(sess, save_model_path)
在当前的loaded_graph中,导入构建好的图和图上的变量值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | def test_model(): test_features, test_labels = pickle.load( open ( 'preprocess_test.p' , mode = 'rb' )) loaded_graph = tf.Graph() # # print( loaded_graph) # print(tf.get_default_graph()) # with tf.Session(graph = loaded_graph) as sess: # 读取模型 loader = tf.train.import_meta_graph(save_model_path + '.meta' ) print (loader) loader.restore(sess, save_model_path) print (tf.get_default_graph()) # # 从已经读入的模型中 获取tensors loaded_x = loaded_graph.get_tensor_by_name( 'x:0' ) loaded_y = loaded_graph.get_tensor_by_name( 'y:0' ) loaded_keep_prob = loaded_graph.get_tensor_by_name( 'keep_prob:0' ) loaded_logits = loaded_graph.get_tensor_by_name( 'logits:0' ) loaded_acc = loaded_graph.get_tensor_by_name( 'accuracy:0' ) # 获取每个batch的准确率,再求平均值,这样可以节约内存 test_batch_acc_total = 0 test_batch_count = 0 for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size): test_batch_acc_total + = sess.run( loaded_acc, feed_dict = {loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0 }) test_batch_count + = 1 |
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持IT俱乐部。