Copyright © 2022-2024 aizws.net · 网站版本: v1.2.6·内部版本: v1.23.3·
页面加载耗时 0.00 毫秒·物理内存 62.2MB ·虚拟内存 1300.8MB
欢迎来到 AI 中文社区(简称 AI 中文社),这里是学习交流 AI 人工智能技术的中文社区。 为了更好的体验,本站推荐使用 Chrome 浏览器。
在神经网络学习中slim常用函数与如何训练、保存模型文章里已经讲述了如何使用slim训练出来一个模型,这篇文章将会讲述如何预测。
载入模型的过程主要分为以下四步:
1、建立会话Session;
2、将img_input的placeholder传入网络,建立网络结构;
3、初始化所有变量;
4、利用saver对象restore载入所有参数。
这里要注意的重点是,在利用saver对象restore载入所有参数之前,必须要建立网络结构,因为网络结构对应着cpkt文件中的参数。
(网络层具有对应的名称scope。)
在运行实验代码前,可以直接下载代码,因为存在许多依赖的文件
import tensorflow as tf import numpy as np from nets import Net from tensorflow.examples.tutorials.mnist import input_data def compute_accuracy(x_data,y_data): global prediction y_pre = sess.run(prediction,feed_dict={img_input:x_data}) correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) result = sess.run(accuracy,feed_dict = {img_input:x_data}) return result mnist = input_data.read_data_sets("MNIST_data",one_hot = "true") slim = tf.contrib.slim # img_input的placeholder img_input = tf.placeholder(tf.float32, shape = (None, 784)) img_reshape = tf.reshape(img_input,shape = (-1,28,28,1)) # 载入模型 sess = tf.Session() Conv_Net = Net.Conv_Net() # 将img_input的placeholder传入网络 prediction = Conv_Net.net(img_reshape) # 载入模型 ckpt_filename = './logs/model.ckpt-20000' # 初始化所有变量 sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() # 恢复 saver.restore(sess, ckpt_filename) print(compute_accuracy(mnist.test.images,mnist.test.labels))
运行结果为:
0.9921
以上就是python神经网络tensorflow利用训练好的模型进行预测的详细内容,更多关于tensorflow模型预测的资料请关注编程教程其它相关文章!
在SSD的框架中,除去tfrecord处理是非常重要的一环之外,slim框架的使用也是非常重要的一环,于是我开始学习slim啦 slim是什么slim的英文本意是苗条的意思,其实在 ...