全部課程
發(fā)布時(shí)間: 2019-01-03 23:52:17
8.1 實(shí)驗(yàn)介紹8.1.1 關(guān)于本實(shí)驗(yàn)本實(shí)驗(yàn)主要介紹如何保存模型和使用保存的模型,一般而言,訓(xùn)練好的模型都需要保存。8.1.2 實(shí)驗(yàn)?zāi)康睦斫馊绾伪4婺P汀?br>
理解如何載入模型,并使用。8.1.3 實(shí)驗(yàn)介紹本實(shí)驗(yàn)主要是基于前面的實(shí)驗(yàn)添加模型的保存及載入功能。在文件中生成模擬數(shù)據(jù)之后,加入對(duì)圖變量的充值,在 session 創(chuàng)建之前定義 saver 及保存路徑,在 session 種訓(xùn)練結(jié)束后,保存模型。8.1.4 實(shí)驗(yàn)步驟步驟 1 登陸華為云。
步驟 2 點(diǎn)擊右上方的控制臺(tái)。
步驟 3 選擇彈性云服務(wù)器,網(wǎng)頁(yè)中會(huì)顯示該彈性云的可進(jìn)行的操作,選擇遠(yuǎn)程登錄。即登錄到彈性云服務(wù)器。
步驟 4 輸入指令 ll,查看當(dāng)前目錄下的文件。
步驟 5 輸入命令 vi mnist_train.py,創(chuàng)建新的 Python 腳本。
步驟 6 輸入命令 i,進(jìn)入編輯模式開(kāi)始編輯,輸入腳本內(nèi)容。
步驟 7 輸入命令 :wq!,保存并退出。
步驟 8 輸入命令 cat mnist_train.py 查看代碼。
步驟 9 運(yùn)行測(cè)試。輸入命令 python3 mnist_train.py。
8.2 實(shí)驗(yàn)過(guò)程8.2.1 導(dǎo)入數(shù)據(jù)集# -*- coding: utf-8 -*- #!/usr/bin/env python
# 導(dǎo)入 mnist 數(shù)據(jù)庫(kù)
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf import os8.2.2 定義變量# 定義輸入變量
x = tf.placeholder(tf.float32, [None, 784])
# 定義參數(shù)
W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10]))
# 定義激勵(lì)函數(shù)
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 定義輸出變量
y_ = tf.placeholder(tf.float32, [None, 10])
# 定義成本函數(shù)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# 定義優(yōu)化函數(shù)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)8.2.3 初始化# 初始化變量
init = tf.global_variables_initializer()
# 定義會(huì)話(huà)
sess = tf.Session()8.2.4 運(yùn)行 session# 運(yùn)行初始化
sess.run(init)
# 定義模型保存對(duì)象
saver = tf.train.Saver()
# 循環(huán)訓(xùn)練 1000 次
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})
print("訓(xùn)練完成!")8.2.5 創(chuàng)建模型保存目錄# 創(chuàng)建模型保存目錄
model_dir = "mnist_model" model_name = "ckp"
if not os.path.exists(model_dir): os.mkdir(model_dir)8.2.6 保存模型# 保存模型
saver.save(sess, os.path.join(model_dir, model_name))
print("保存模型成功!")8.2.7 實(shí)驗(yàn)結(jié)果訓(xùn)練完成!
保存模型成功!
8.3 使用模型8.3.1 導(dǎo)入數(shù)據(jù)集# -*- coding: utf-8 -*- #!/usr/bin/env python
# 導(dǎo)入 mnist 數(shù)據(jù)庫(kù)
from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data",one_hot=True)
import tensorflow as tf8.3.2 創(chuàng)建回話(huà)# 創(chuàng)建會(huì)話(huà)
sess = tf.Session()
8.3.3 定義變量# 定義輸入變量
x = tf.placeholder(tf.float32, [None, 784])
# 定義參數(shù)
W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10]))
# 定義模型和激勵(lì)函數(shù)
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 定義模型保存對(duì)象
saver = tf.train.Saver([W, b])8.3.4 恢復(fù)模型# 恢復(fù)模型
saver.restore(sess, "mnist/ckp")
print("恢復(fù)模型成功!")
# 取出一個(gè)測(cè)試圖片
idx=0
img = mnist.test.images[idx]8.3.5 計(jì)算結(jié)果# 根據(jù)模型計(jì)算結(jié)果
ret = sess.run(y, feed_dict = {x : img.reshape(1, 784)})
print("計(jì)算模型結(jié)果成功!")
# 顯示測(cè)試結(jié)果
print("預(yù)測(cè)結(jié)果:%d"%(ret.argmax()))
print("實(shí)際結(jié)果:%d"%(mnist.test.labels[idx].argmax()))8.3.6 實(shí)驗(yàn)結(jié)果恢復(fù)模型成功!
計(jì)算模型結(jié)果成功! 預(yù)測(cè)結(jié)果:7
實(shí)際結(jié)果:7
8.4 實(shí)例描述
本實(shí)驗(yàn)主要是保存以后模型,方便后續(xù)的模型載入與使用,這里輸出的是中間狀態(tài) cost 損
上一篇: {HTML5}DOM節(jié)點(diǎn)操作-第一節(jié)
下一篇: {springboot}工程文件介紹