Saving and Loading TensorFlow Models
Saving and loading models
Install tensorflow-datasets and import the dependencies:
%pip install tensorflow-datasets
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
tf.__version__
'2.3.0'
mirrored_strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:There are non-GPU devices in tf.distribute.Strategy, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
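Create a strategy that distributes the variables and the graph
How does tf.distribute.MirroredStrategy work?
All variables and the model graph are replicated on the replicas.
The input is distributed evenly across the replicas.
Each replica computes the loss and gradients for the input it received.
The gradients are synced across all replicas by summing them.
After the sync, the same update is applied to the copies of the variables on each replica.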
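The gradient synchronization described in the list above can be sketched with the lower-level strategy.run and strategy.reduce APIs. This is a minimal illustration, not what model.fit does internally: the per-replica value is a hypothetical stand-in for a gradient (the replica id plus one), chosen only so that the summed result is easy to check.

```python
import tensorflow as tf

# Uses every visible GPU, or falls back to a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

def replica_fn():
    # Runs once on each replica. The returned value is a hypothetical
    # stand-in for that replica's locally computed gradient.
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0

per_replica_grads = strategy.run(replica_fn)
# Sum the per-replica values, the same reduction used to combine gradients across replicas.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_grads, axis=None)

n = strategy.num_replicas_in_sync
print("replicas:", n, "summed value:", float(total))
```

On a CPU-only machine there is a single replica and the summed value is 1.0. With n replicas it is 1 + 2 + ... + n.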
def get_data():
    datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']
    BUFFER_SIZE = 10000
    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label
    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
    return train_dataset, eval_dataset
def get_model():
    with mirrored_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10)
        ])
        model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(),
                      metrics=['accuracy'])
        return model
Train the model
model = get_model()
train_dataset, eval_dataset = get_data()
Downloading and preparing dataset mnist/3.0.1 (download: Unknown size, generated: Unknown size, total: Unknown size) to C:\Users\tensorflow_datasets\mnist\3.0.1...
Shuffling and writing examples to C:\Users\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-train.tfrecord
Shuffling and writing examples to C:\Users\tensorflow_datasets\mnist\3.0.1.incompleteI371SH\mnist-test.tfrecord
Dataset mnist downloaded and prepared to C:\Users\tensorflow_datasets\mnist\3.0.1. Subsequent calls will reuse this data.
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\data\ops\multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.data.Iterator.get_next_as_optional() instead.
938/938 [] - 18s 19ms/step - loss: 0.2089 - accuracy: 0.9389
Epoch 2/2
938/938 [] - 19s 20ms/step - loss: 0.0689 - accuracy: 0.9798
<tensorflow.python.keras.callbacks.History at 0x2b4c939e278>
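Save and load the model
Now that we have a simple model to work with, let's look at the saving and loading APIs. Two sets of APIs are available:
High-level (Keras): model.save and tf.keras.models.load_model
Low-level: tf.saved_model.save and tf.saved_model.load
Using the Keras API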
keras_model_path = "./tmp/keras_save"
model.save(keras_model_path)
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\tracking\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./tmp/keras_save\assets
The model was saved successfully. The saved files look like this:
└───tmp
    └───keras_save
        ├───assets
        ├───variables
        │   ├───variables.data-00000-of-00001
        │   └───variables.index
        └───saved_model.pb
Next, restore the model:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
The restored model can continue training:
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [] - 16s 17ms/step - loss: 0.0494 - accuracy: 0.0989
Epoch 2/2
938/938 [] - 16s 17ms/step - loss: 0.0353 - accuracy: 0.0989
<tensorflow.python.keras.callbacks.History at 0x2b4c5b5a128>
Now load the model and train it under a tf.distribute.Strategy:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
    restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
    restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [] - 16s 17ms/step - loss: 0.0501 - accuracy: 0.0990
Epoch 2/2
938/938 [] - 16s 17ms/step - loss: 0.0354 - accuracy: 0.0989
restored_keras_model_ds.predict
<tensorflow.python.keras.engine.sequential.Sequential at 0x2b4c75ea2b0>
Using the tf.saved_model API
Now for the lower-level API. Saving works much like it does with Keras:
model = get_model()
saved_model_path = "./tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: ./tmp/tf_save\assets
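You can load the model with tf.saved_model.load(). However, because it is a lower-level API (and therefore has a wider range of use cases), it does not return a Keras model. Instead, it returns an object containing functions that can be used for inference.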
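The object returned by tf.saved_model.load is easiest to explore on something small. The following minimal sketch uses a hypothetical tf.Module called Scale (with a made-up weight of 3.0) instead of the Keras model above. Its loaded.signatures maps signature names to functions that take keyword arguments and return a dictionary of named outputs: "serving_default" is the default key, the input name x comes from the Python argument name, and output_0 is the name TensorFlow gives a single unnamed output.

```python
import tempfile
import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical toy model: y = w * x."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x

module = Scale()
export_dir = tempfile.mkdtemp()
# Export __call__ as the default serving signature ("serving_default").
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(export_dir)    # a restored object, not a Scale instance
infer = loaded.signatures["serving_default"]
outputs = infer(x=tf.constant([1.0, 2.0]))  # signature functions take keyword arguments
print(list(outputs.keys()))
print(outputs["output_0"].numpy())
```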
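You can also load the model and run inference in a distributed manner:
DEFAULT_FUNCTION_KEY = "serving_default"  # default signature key of a SavedModel
predict_dataset = eval_dataset.map(lambda image, label: image)  # images only, labels dropped
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
    dist_predict_dataset = another_strategy.experimental_distribute_dataset(
        predict_dataset)
    # Calling the function in a distributed manner
    for batch in dist_predict_dataset:
        another_strategy.run(inference_func, args=(batch,))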
WARNING:tensorflow:There are non-GPU devices in tf.distribute.Strategy, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap call_for_each_replica or experimental_run or experimental_run_v2 inside a tf.function to get the best performance.
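Saving checkpoints
Checkpoints capture the exact values of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model, so they are typically only useful when the source code that will use the saved parameter values is available.
The SavedModel format, on the other hand, includes a serialized description of the computation defined by the model in addition to the parameter values (the checkpoint). Models in this format are independent of the source code that created them. This makes them suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).
This section covers the APIs for writing and reading checkpoints.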
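The difference between the two formats can be seen in a short sketch, assuming a hypothetical tf.Module called Linear (y = w * x + b). Restoring the checkpoint needs a fresh Linear object built from the same class, while the SavedModel can be loaded and called without the class, because the traced computation is stored alongside the values.

```python
import os
import tempfile
import tensorflow as tf

class Linear(tf.Module):
    """Hypothetical toy model: y = w * x + b."""

    def __init__(self, w=0.0, b=0.0):
        super().__init__()
        self.w = tf.Variable(w)
        self.b = tf.Variable(b)

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x + self.b

trained = Linear(w=2.0, b=1.0)
tmp = tempfile.mkdtemp()

# Checkpoint: variable values only. Restoring needs a Python object created
# from the same source code, here a second Linear instance.
ckpt_path = tf.train.Checkpoint(model=trained).write(os.path.join(tmp, "ckpt"))
fresh = Linear()
tf.train.Checkpoint(model=fresh).restore(ckpt_path)

# SavedModel: values plus the traced computation. Loading does not use the
# Linear class, yet the restored object is still callable.
tf.saved_model.save(trained, os.path.join(tmp, "saved"))
restored = tf.saved_model.load(os.path.join(tmp, "saved"))

print(float(fresh(tf.constant(3.0))), float(restored(tf.constant(3.0))))
```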
Setup
class Net(tf.keras.Model):
    # A simple linear model
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = tf.keras.layers.Dense(5)
    def call(self, x):
        return self.l1(x)
net = Net()
net.save_weights('./tmp/easy_checkpoint')
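Writing checkpoints
The persistent state of a TensorFlow model is stored in tf.Variable objects. These can be constructed directly, but are often created through high-level APIs such as tf.keras.layers or tf.keras.Model.
The easiest way to manage variables is to attach them to Python objects and then reference those objects.
Subclasses of tf.train.Checkpoint, tf.keras.layers.Layer and tf.keras.Model automatically track variables assigned to their attributes. The Net class above is an example: it constructs a simple linear model, and a checkpoint written from it contains the values of all of the model's variables.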
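Attribute tracking can be checked directly. The sketch below uses tf.Module, which tracks variables assigned to attributes in the same way; Inner and Outer are hypothetical classes. The checkpoint keys mirror the attribute path from the checkpoint object down to each variable.

```python
import os
import tempfile
import tensorflow as tf

class Inner(tf.Module):
    """Hypothetical layer-like module holding one variable."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable([[1.0, 2.0]])

class Outer(tf.Module):
    """Hypothetical model that tracks Inner just by assigning it to an attribute."""

    def __init__(self):
        super().__init__()
        self.l1 = Inner()
        self.bias = tf.Variable([0.5, 0.5])

demo_net = Outer()
# Both variables are found by walking the attributes; nothing is registered by hand.
print(len(demo_net.trainable_variables))

prefix = tf.train.Checkpoint(net=demo_net).write(
    os.path.join(tempfile.mkdtemp(), "tracked"))
keys = [name for name, _ in tf.train.list_variables(prefix)]
print(keys)
```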
You can easily save a model checkpoint with Model.save_weights.
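A minimal save_weights/load_weights round trip could look like this. Assumption: the file name ends in .weights.h5, which selects the HDF5 format in tf.keras 2.x and is the suffix newer Keras releases require; the suffix-less path in the Setup code writes the TensorFlow checkpoint format instead. build_model is a hypothetical helper.

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

def build_model():
    # Hypothetical helper: one Dense(5) layer, like the Net class above.
    model = tf.keras.Sequential([tf.keras.layers.Dense(5)])
    model(tf.zeros([1, 1]))  # call once so the weights exist before saving or loading
    return model

x_demo = tf.constant([[1.0], [2.0]])
source_model = build_model()
weights_path = os.path.join(tempfile.mkdtemp(), "easy.weights.h5")
source_model.save_weights(weights_path)

target_model = build_model()  # fresh random initialization
target_model.load_weights(weights_path)
print(np.allclose(source_model(x_demo).numpy(), target_model(x_demo).numpy()))
```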
Manual checkpointing
def toy_dataset():
    inputs = tf.range(10.)[:, None]
    labels = inputs * 5. + tf.range(5.)[None, :]
    return tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
    # One gradient step of `net` on `example` with `optimizer`; returns the loss.
    with tf.GradientTape() as tape:
        output = net(example['x'])
        loss = tf.reduce_mean(tf.abs(output - example['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
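Creating the checkpoint objects
To checkpoint manually, you need a tf.train.Checkpoint object. The objects you want to checkpoint are set as attributes on this object.
A tf.train.CheckpointManager can also be helpful for managing multiple checkpoints.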
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tmp/tf_ckpts', max_to_keep=3)
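The bookkeeping that tf.train.CheckpointManager does can be seen in isolation. In this sketch the names are prefixed with demo_ (hypothetical) so they do not overwrite ckpt and manager above; five saves with max_to_keep=2 leave only the two newest checkpoints, numbered from the checkpoint's save counter.

```python
import os
import tempfile
import tensorflow as tf

demo_step = tf.Variable(0)
demo_ckpt = tf.train.Checkpoint(step=demo_step)
demo_manager = tf.train.CheckpointManager(demo_ckpt, tempfile.mkdtemp(), max_to_keep=2)

for _ in range(5):
    demo_step.assign_add(1)
    # Writes ckpt-1, ckpt-2, ... and deletes checkpoints beyond the newest two.
    demo_manager.save()

print([os.path.basename(p) for p in demo_manager.checkpoints])
print(demo_manager.latest_checkpoint)
```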
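Training and checkpointing the model
The training loop below first restores the latest checkpoint, if there is one. It then calls the training step on each batch of data in a loop and periodically writes checkpoints to disk. The model (net) and the optimizer (opt) have already been gathered into the tf.train.Checkpoint object above.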
def train_and_checkpoint(net, manager):
    # Resume from the latest checkpoint; a None path leaves the fresh values in place.
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    for _ in range(50):
        example = next(iterator)
        loss = train_step(net, example, opt)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            save_path = manager.save()
            print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
            print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tmp/tf_ckpts\ckpt-1
loss 29.78
Saved checkpoint for step 20: ./tmp/tf_ckpts\ckpt-2
loss 23.19
Saved checkpoint for step 30: ./tmp/tf_ckpts\ckpt-3
loss 16.63
Saved checkpoint for step 40: ./tmp/tf_ckpts\ckpt-4
loss 10.17
Saved checkpoint for step 50: ./tmp/tf_ckpts\ckpt-5
loss 4.09
Restoring and continuing training
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tmp/tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tmp/tf_ckpts\ckpt-10
Saved checkpoint for step 110: ./tmp/tf_ckpts\ckpt-11
loss 0.27
Saved checkpoint for step 120: ./tmp/tf_ckpts\ckpt-12
loss 0.20
Saved checkpoint for step 130: ./tmp/tf_ckpts\ckpt-13
loss 0.16
Saved checkpoint for step 140: ./tmp/tf_ckpts\ckpt-14
loss 0.21
Saved checkpoint for step 150: ./tmp/tf_ckpts\ckpt-15
loss 0.20
print(manager.checkpoints)
['./tmp/tf_ckpts\ckpt-13', './tmp/tf_ckpts\ckpt-14', './tmp/tf_ckpts\ckpt-15']
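Paths such as './tmp/tf_ckpts/ckpt-10' are not files on disk. Instead, they are prefixes for an index file and one or more data files that contain the variable values. These prefixes are grouped together in a single checkpoint file ('./tmp/tf_ckpts/checkpoint'), where the CheckpointManager saves its state.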
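The prefix layout can be verified with a throwaway checkpoint in a temporary directory (the demo_ names are hypothetical): the prefix returned by CheckpointManager.save is not itself a file, but prefix.index and at least one prefix.data-... shard exist next to the checkpoint state file.

```python
import os
import tempfile
import tensorflow as tf

demo_dir = tempfile.mkdtemp()
demo_ckpt = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0, 3.0]))
demo_manager = tf.train.CheckpointManager(demo_ckpt, demo_dir, max_to_keep=3)
prefix = demo_manager.save()  # a path prefix such as <demo_dir>/ckpt-1

files = sorted(os.listdir(demo_dir))
print(os.path.exists(prefix))  # the prefix itself is not a file on disk
print(files)                   # index file, data shard(s) and the "checkpoint" state file
```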
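Inspecting checkpoints manually
tf.train.list_variables lists the checkpoint keys and the shapes of the variables in a checkpoint. Checkpoint keys are paths in the object graph; for example, net/l1/kernel is the kernel variable of the l1 layer of the net object.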
tf.train.list_variables(tf.train.latest_checkpoint('./tmp/tf_ckpts'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
[1, 5]),
('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
[1, 5]),
('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
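Beyond names and shapes, tf.train.load_checkpoint returns a reader that can fetch individual values without rebuilding the model. A minimal sketch with a hypothetical two-variable checkpoint:

```python
import os
import tempfile
import tensorflow as tf

demo_ckpt = tf.train.Checkpoint(step=tf.Variable(42, dtype=tf.int64),
                                bias=tf.Variable([0.1, 0.2]))
prefix = demo_ckpt.write(os.path.join(tempfile.mkdtemp(), "inspect"))

reader = tf.train.load_checkpoint(prefix)
shapes = reader.get_variable_to_shape_map()  # key -> shape, as in tf.train.list_variables
step_value = reader.get_tensor("step/.ATTRIBUTES/VARIABLE_VALUE")  # returned as a NumPy value
print(sorted(shapes))
print(step_value)
```

The keys follow the same attribute-path/.ATTRIBUTES/VARIABLE_VALUE pattern as the listing above.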
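Saving object-based checkpoints with Estimator
By default, Estimators save checkpoints with variable names rather than with the object graph described in the previous sections. tf.train.Checkpoint will accept name-based checkpoints, but variable names may change when parts of a model are moved outside of the Estimator's model_fn. Saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.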
import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
    net = Net()
    opt = tf.keras.optimizers.Adam(0.1)
    ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                               optimizer=opt, net=net)
    with tf.GradientTape() as tape:
        output = net(features['x'])
        loss = tf.reduce_mean(tf.abs(output - features['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    return tf.estimator.EstimatorSpec(
        mode,
        loss=loss,
        train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                          ckpt.step.assign_add(1)),
        # Tell the Estimator to save "ckpt" in an object-based format.
        scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tmp/tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tmp/tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec(
WARNING:tensorflow:From D:\Anaconda3\envs\py36_tf2\lib\site-packages\tensorflow\python\training\training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.505075, step = 1
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tmp/tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 36.96539.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x2b4ccb5f588>
tf.train.latest_checkpoint('./tmp/tf_estimator_example')
'./tmp/tf_estimator_example\model.ckpt-10'
tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tmp/tf_estimator_example/'))
ckpt.step.numpy()
10