TensorFlow 训练模型的保存&加载

Skr
5个月前 阅读 158 点赞 2

什么是Tensorflow的模型


Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:


  • “.meta”:包含了计算图的结构
  • “.data”:包含了变量的值
  • “.index”:确认checkpoint
  • “checkpiont”:一个protocol buffer,包含了最近的一些checkpoints


存储一个Tensorflow的模型


当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。


saver.save(sess,'my_test_model')


sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。

具体的实例:


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()


执行上述语句,我们会同级目录下看到新增的文件:


my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta


如果网络架构更改了,Tensorflow会重写上述的文件。

如果我们想要每1000步保存一次,那么需要更改语句:


saver.save(sess, 'my_test_model', global_step=1000)


那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新创建.meta文件。 如果不写步数,默认每次迭代保存一次。

如果我们要仅仅保留最近4次创建的模型,并且每两个小时存储一次模型,可以进行下面的操作:


# saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)


如果我们在tf.train.Saver()中不指定任何参数,那么Tensorflow会默认保存所有的变量。假设我们只想保留部分变量或者collection,那么需要显式地表明需要保留的对象。当创建tf.train.Saver()对象时,使用一个包含有关变量的list或者字典声明。比如:


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()


导入一个训练好的模型


如果我们要导入一个训练好的模型,需要做以下两步:


创建一个网络


使用函数:


saver = tf.train.import_meta_graph('my_test_model-1000.meta')


把存储在my_test_model-1000.meta加载到saver当中。这个操作知识会把在.meta文件中定义的网络追加到当前网络的后面,我们仍然需要加载原来网络的参数数值。


加载参数


操作如下:


with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    new_saver.restore(sess, tf.train.lasters_checkpoint('./'))


在这之后,w1和w2的数据就会被重新加载进来。


对导入的模型进行的操作


现在,学着加载模型,把模型用于预测、训练甚至更改模型的架构。现在构造一个简单的网络模型,保存并重新导入。注意一点:tf.placeholder的数据不会被保存 !!!!

先定义训练文件:


import tensorflow as tf

# 定义用于恢复变量的例子
w1 = tf.placeholder(dtype=tf.float32, name="w1")
w2 = tf.placeholder(dtype=tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# 定义用于恢复操作的例子   w4=w3*b1,w3=(w1+w2)*b1
w3 = tf.add(w1, w2, name="part_op")
w4 = tf.multiply(w3, b1, name="op_to_restore")

sess = tf.Session()
sess.run(tf.global_variables_initializer())  # 时刻记着,要初始化

saver = tf.train.Saver()

print(sess.run(w4, feed_dict))  # 24.0

saver.save(sess, './my_test_model', global_step=1000)

sess.close()


定义加载文件:


import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# w4=w3*b1,w3=(w1+w2)*b1
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")  # 60.0

print(sess.run(op_to_restore, feed_dict))

sess.close()


当导入模型的时候,不但需要恢复计算图和相关的参数,而且需要重新对tf.placeholder喂数据。通过graph.get_tensor_by_name获取保存的操作和占位符。如果我们想要使用网络计算,仅需要给不同的占位符添加不同的数据即可。

如果我们想要对原来的网络添加更多的层数并接着训练它,可以按照下面的步骤处理:


import tensorflow as tf

sess = tf.Session()
# 恢复计算图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 获取占位符
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# 恢复操作
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
# 增加新的操作
add_on_op = tf.multiply(op_to_restore, 2.0)
# 别忘了喂数据
print(sess.run(add_on_op, feed_dict))

sess.close()


由此可以看出,只需要把原来的操作加载完毕后,当成一个输出数据接入新的网络即可。

也可以把原来网络的一部分加载 到新的网络中,比如下面的操作:

先更改之前的一行代码


w3 = tf.add(w1, w2, name="part_op")


加载操作:


import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 14.0}

w3 = graph.get_tensor_by_name("part_op:0")

op = tf.multiply(w3, 4)
print(sess.run(op, feed_dict))  # 108.0
sess.close()


使用SavedModel的格式


SavedMode类把Saver类进行了一个更高层的封装,开发效率可能会更高,但是暂时没有前一种方法常用。Saver类更看重对变量的封装, 而SavedModel更看重压缩封装保存所有有用的信息。

保存操作:


import tensorflow as tf

tf.reset_default_graph()

w1 = tf.Variable(1.0, name="w1")
w2 = tf.Variable(2.0, name="w2")
w3 = tf.multiply(w1, w2, name="w3")

builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(w3)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.tag_constants.TRAINING],
                                         signature_def_map=None,
                                         assets_collection=None)
builder.save()


读取操作:


import tensorflow as tf

with tf.Session() as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],
                               './SavedModel')

    w1 = sess.run('w1:0')
    w2 = sess.run('w2:0')
    w3 = sess.run('w3:0')

    print(w1, w2, w3)


Enjoy your coding!


| 2
登录后可评论,马上登录吧~
评论 ( 0 )

还没有人评论...

相关推荐