TenorFlow多层感知机识别手写体
文章目录
- 数据准备
- 建立模型
- 建立输入层 x
- 建立隐藏层h1
- 建立隐藏层h2
- 建立输出层
- 定义训练方式
- 建立训练数据label真实值 placeholder
- 定义loss function
- 选择optimizer
- 定义评估模型的准确率
- 计算每一项数据是否正确预测
- 将计算预测正确结果,加总平均
- 开始训练
- 画出误差执行结果
- 画出准确率执行结果
- 评估模型的准确率
- 进行预测
- 找出预测错误
GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403
数据准备
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.from ._conv import register_converters as _register_convertersWARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images :', mnist.train.images.shape,'labels:' , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,' labels:' , mnist.validation.labels.shape)
print('test images :', mnist.test.images.shape,'labels:' , mnist.test.labels.shape)
train images : (55000, 784) labels: (55000, 10)
validation images: (5000, 784) labels: (5000, 10)
test images : (10000, 784) labels: (10000, 10)
建立模型
def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为NoneW = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重Wb = tf.Variable(tf.random_normal([1, output_dim]))XWb = tf.matmul(inputs, W) + bif activation is None:outputs = XWbelse:outputs = activation(XWb)return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,inputs=x ,activation=tf.nn.relu)
建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,inputs=h1 ,activation=tf.nn.relu)
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,inputs=h2,activation=None)
定义训练方式
建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict , labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \.minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化
定义评估模型的准确率
计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label , 1),tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
开始训练
trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):#执行15个训练周期#每个训练周期执行550批次训练for i in range(totalBatchs):batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})#使用验证数据计算准确率loss,acc = sess.run([loss_function,accuracy],feed_dict={x: mnist.validation.images, #验证数据的featuresy_label: mnist.validation.labels})#验证数据的labelepoch_list.append(epoch)loss_list.append(loss);accuracy_list.append(acc) print("Train Epoch:", '%02d' % (epoch+1), \"Loss=","{:.9f}".format(loss)," Accuracy=",acc)duration =time()-startTime
print("Train Finished takes:",duration)
Train Epoch: 01 Loss= 133.117172241 Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542 Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750 Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696 Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765 Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707 Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953 Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921 Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532 Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645 Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611 Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873 Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903 Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950 Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056 Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>
画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()
评估模型的准确率
print("Accuracy:", sess.run(accuracy,feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
Accuracy: 0.9643
进行预测
prediction_result=sess.run(tf.argmax(y_predict,1),feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):fig = plt.gcf()fig.set_size_inches(12, 14)if num>25: num=25 for i in range(0, num):ax=plt.subplot(5,5, 1+i)ax.imshow(np.reshape(images[idx],(28, 28)), cmap='binary')title= "label=" +str(np.argmax(labels[idx]))if len(prediction)>0:title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10) ax.set_xticks([]);ax.set_yticks([]) idx+=1 plt.show()
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0)
y_predict_Onehot=sess.run(y_predict,feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544 , -5329.589 , 1897.1707 , -3942.7764 , 347.9809 ,5513.258 , 6735.7153 , -5088.5273 , 649.2062 , 69.50408],dtype=float32)
找出预测错误
for i in range(400):if prediction_result[i]!=np.argmax(mnist.test.labels[i]):print("i="+str(i)+" label=",np.argmax(mnist.test.labels[i]),"predict=",prediction_result[i])
i=8 label= 5 predict= 6
i=18 label= 3 predict= 8
i=149 label= 2 predict= 4
i=151 label= 9 predict= 8
i=233 label= 8 predict= 7
i=241 label= 9 predict= 8
i=245 label= 3 predict= 5
i=247 label= 4 predict= 2
i=259 label= 6 predict= 0
i=320 label= 9 predict= 1
i=340 label= 5 predict= 3
i=381 label= 3 predict= 7
i=386 label= 6 predict= 5
sess.close()
相关文章:

TenorFlow多层感知机识别手写体
文章目录 数据准备建立模型建立输入层 x建立隐藏层h1建立隐藏层h2建立输出层 定义训练方式建立训练数据label真实值 placeholder定义loss function选择optimizer 定义评估模型的准确率计算每一项数据是否正确预测将计算预测正确结果,加总平均 开始训练画出误差执行结…...

Java基础(二十六):Java8 Stream流及Optional类
Java基础系列文章 Java基础(一):语言概述 Java基础(二):原码、反码、补码及进制之间的运算 Java基础(三):数据类型与进制 Java基础(四):逻辑运算符和位运算符 Java基础(五):流程控制语句 Java基础(六)࿱…...

qt - 19种精美软件样式
qt - 19种精美软件样式 一、效果演示二、核心程序三、下载链接 一、效果演示 二、核心程序 #include "mainwindow.h"#include <QtAdvancedStylesheet.h> #include <QmlStyleUrlInterceptor.h>#include "ui_mainwindow.h" #include <QDir&g…...
vue 使用docx库生成word表格文档
在Vue.js中生成Word表格文档,可以通过前端库来实现。这些库可以帮助我们轻松地将HTML表格转换为Word文档(通常是.docx格式)。以下是一些流行的前端库,它们可以用于在Vue项目中生成Word表格文档: docx…...

ElementUI table表格组件实现双击编辑单元格失去焦点还原,支持多单元格
在使用ElementUI table表格组件时有时需要双击单元格显示编辑状态,失去焦点时还原表格显示。 实现思路: 在数据中增加isFocus:false.控制是否显示在table中用cell-dblclick双击方法 先看效果: 上源码:在表格模板中用scope.row…...

Java基于SpringBoot+Vue的图书管理系统
博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…...
【云安全】Hypervisor与虚拟机
Hypervisor 也被称为虚拟机监视器(Virtual Machine Monitor,VMM),主要作用是让多个操作系统可以在同一台物理机上运行。 Type-1 Hypervisor 与 Typer-2 Hypervisor Type-1 Hypervisor 直接安装在物理服务器上,不依赖…...

JS文本加密方法探究
在前端开发中,有时候我们需要对敏感文本进行简单的加密,以提高安全性。本文将介绍一种基于 JavaScript 实现的文本加密方法,使用了 Base64、Unicode 和 ROT13 编码。 示例代码 function encodeText(text) {// Base64编码var base64Encoded …...

推荐彩虹知识付费商城免授权7.0源码
彩虹知识付费商城免授权7.0源码,最低配置环境 PHP7.2 1、上传源码到网站根目录,导入数据库文件:xydai.sql 2、修改数据库配置文件:/config.php 3、后台:/admin 账号:admin 密码:123456 4、前…...

【天衍系列 04】深入理解Flink的ElasticsearchSink组件:实时数据流如何无缝地流向Elasticsearch
文章目录 01 Elasticsearch Sink 基础概念02 Elasticsearch Sink 工作原理03 Elasticsearch Sink 核心组件04 Elasticsearch Sink 配置参数05 Elasticsearch Sink 依赖管理06 Elasticsearch Sink 初阶实战07 Elasticsearch Sink 进阶实战7.1 包结构 & 项目配置项目配置appl…...

一、ActiveMQ介绍
ActiveMQ介绍 一、JMS1.jms介绍2.jms消息传递模式3.JMS编码总体架构 二、消息中间件三、ActiveMQ介绍1.引入的原因1.1 原因1.2 遇到的问题1.3 解决思路 2.定义3.特点3.1 异步处理3.2 应用系统之间解耦3.3 实际-整体架构 4.作用 一、JMS 1.jms介绍 jms是java消息服务接口规范&…...
【牛客】寒假训练营1 I-It‘s bertrand paradox. Again! 题解
传送门:It’s bertrand paradox. Again! 标签:随机 题目大意 有两个人分别用两种方式在二维平面上随机生成1e5个圆,每个圆上的每一个点(x,y)都满足-100<x<100且-100<y<100,现在将某个人生成的1e5个圆的圆心和半径告…...

各种手型都合适,功能高度可定制,雷柏VT9PRO mini和VT9PRO游戏鼠标上手
去年雷柏推出了一系列支持4KHz回报率的鼠标,有着非常敏捷的反应速度,在游戏中操作体验十分出色。尤其是这系列4K鼠标不仅型号丰富,而且对玩家的操作习惯、手型适应也很好,像是VT9系列就主打轻巧,还有专门针对小手用户的…...
sql建库,建表基础操作
当涉及到SQL建库和建表操作时,以下是一个简单的示例: 1. 建库(创建数据库) sql复制代码 CREATE DATABASE mydatabase; 上述语句将创建一个名为mydatabase的数据库。 2. 选择数据库 在创建表之前,需要选择要在其中…...
算法训练营day32,贪心算法6
import "strconv" //738. 单调递增的数字 func monotoneIncreasingDigits(n int) int { str : strconv.Itoa(n) nums : []byte(str) length : len(nums) if length < 1 { return n } for i : length - 1; i > 0; i-- { //如果前一个数字比当前值大࿰…...

CTR之行为序列建模用户兴趣:DIN
在前面的文章中,已经介绍了很多关于推荐系统中CTR预估的相关技术,今天这篇文章也是延续这个主题。但不同的,重点是关于用户行为序列建模,阿里出品。 概要 论文:Deep Interest Network for Click-Through Rate Predict…...
Java使用Redis实现分页功能
分页功能实现应该是比较常见的,对于redis来说,近期刷题就发现了lrange、zrange这些指令,这个指令怎么使用呢? 我们接下来就来讲解下。 目录 指令简介lrangezrange Java实现Redis实现分页功能 指令简介 lrange lrange 是 Redis 中…...
Qt标准对话框设置
Qt标准对话框设置,设置字体、调色板、进度条等。 #include "mainwindow.h" #include "ui_mainwindow.h"MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this); }MainWindow::~MainWi…...

如何让Obsidian实现电脑端和安卓端同步
Obsidian是一款知名的笔记软件,支持Markdown语法,它允许用户在多个设备之间同步文件。要在安卓设备上实现同步,可以使用remote save插件,以下是具体操作步骤: 首先是安装电脑端的obsidian,然后依次下载obs…...
windows系统中jenkins构建报错提示“拒绝访问”
一.背景 之前徒弟在windows中安装的jenkins,运行的时候用的是java -jar jenkins.war来运行的。服务器只有1个盘符C盘。今天说构建错误了,问我修改了啥,我年前是修改过构建思路的。 二.问题分析 先看jenkins构建任务的日志,大概是xcopy命令执…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...
《Playwright:微软的自动化测试工具详解》
Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...

vulnyx Blogger writeup
信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...
NPOI Excel用OLE对象的形式插入文件附件以及插入图片
static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...
uniapp 字符包含的相关方法
在uniapp中,如果你想检查一个字符串是否包含另一个子字符串,你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的,但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...

群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...

保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!
目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...
《Offer来了:Java面试核心知识点精讲》大纲
文章目录 一、《Offer来了:Java面试核心知识点精讲》的典型大纲框架Java基础并发编程JVM原理数据库与缓存分布式架构系统设计二、《Offer来了:Java面试核心知识点精讲(原理篇)》技术文章大纲核心主题:Java基础原理与面试高频考点Java虚拟机(JVM)原理Java并发编程原理Jav…...

数据结构:泰勒展开式:霍纳法则(Horner‘s Rule)
目录 🔍 若用递归计算每一项,会发生什么? Horners Rule(霍纳法则) 第一步:我们从最原始的泰勒公式出发 第二步:从形式上重新观察展开式 🌟 第三步:引出霍纳法则&…...