零基础学人工智能:TensorFlow 入门例子
识别手写图片
因为这个例子是 TensorFlow 官方的例子,不会说的太详细,会加入了一点个人的理解,因为TensorFlow提供了各种工具和库,帮助开发人员构建和训练基于神经网络的模型。TensorFlow 中最重要的概念是张量(Tensor),它代表了多维数组或矩阵,因此 TensorFlow 支持各种不同类型的计算,如线性回归、逻辑回归、卷积神经网络、循环神经网络等。所以帮我们极大减少了对数学与算法基础的要求。
准备数据
这里用来识别的手写图片大致是这样的,为了降低复杂度,每个图片是 28*28 大小。
但是直接丢图片给我们的模型,模型是不认识的,所以必须要对图片进行一些处理。
如果了解线性代数,大概知道图片的每个像素点其实可以表示为一个二维的矩阵,对图片做各种变换,比如翻转啊什么的就是对这个矩阵进行运算,于是我们的手写图片大概可以看成是这样的:
这个矩阵展开成一个向量,长度是 28*28=784。我们还需要另一个东西用来告诉模型我们认为这个图片是几,也就是给图片打个 label。这个 label 也不是随便打的,这里用一个类似有 10 个元素的数组,其中只有一个是 1,其它都是 0,哪位为 1 表示对应的图片是几,例如表示数字 8 的标签值就是 ([0,0,0,0,0,0,0,0,1,0])。
这些就是单张图片的数据处理,实际上为了高效的训练模型,会把图片数据和 label 数据分别打包到一起,也就是 MNIST 数据集了。
MNIST数据集
MNIST 数据集是一个入门级的计算机视觉数据集,官网是Yann LeCun's website。 我们不需要手动去下载这个数据集, 1.0 的 TensorFlow 会自动下载。
这个训练数据集有 55000 个图片数据,用张量的方式组织的,形状如 [55000,784],如下图:
还记得为啥是 784 吗,因为 28*28 的图片。
label 也是如此,[55000,10]:
这个数据集里面除了有训练用的数据之外,还有 10000 个测试模型准确度的数据。
整个数据集大概是这样的:
现在数据有了,来看下我们的模型。
Softmax 回归模型
Softmax 中文名叫归一化指数函数,这个模型可以用来给不同的对象分配概率。比如判断
的时候可能认为有 80% 是 9,有 5% 认为可能是 8,因为上面都有个圈。
我们对图片像素值进行加权求和。比如某个像素具有很强的证据说明这个图不是 1,则这个像素相应的权值为负数,相反,如果这个像素特别有利,则权值为正数。
如下图,红色区域代表负数权值,蓝色代表正数权值。
同时,还有一个偏置量(bias) 用来减小一些无关的干扰量。
Softmax 回归模型的原理大概就是这样,更多的推导过程,可以查阅一下官方文档,有比较详细的内容。
说了那么久,终于可以上代码了。
训练模型
具体引入的一些包这里就不一一列出来,主要是两个,一个是 tensorflow 本身,另一个是官方例子里面用来输入数据用的方法。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
之后就可以建立我们的模型。
# Create the modelx = tf.placeholder(tf.float32, [None, 784])W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))y = tf.matmul(x, W) + b
这里的代码都是类似占位符,要填了数据才有用。
- x 是从图片数据文件里面读来的,理解为一个常量,一个输入值,因为是 28*28 的图片,所以这里是 784;
- W 代表权重,因为有 784 个点,然后有 10 个数字的权重,所以是 [784, 10],模型运算过程中会不断调整这个值,可以理解为一个变量;
- b 代表偏置量,每个数字的偏置量都不同,所以这里是 10,模型运算过程中也会不断调整这个值,也是一个变量;
- y 基于前面的数据矩阵乘积计算。
tf.zeros 表示初始化为 0。
我们会需要一个东西来接受正确的输入,也就是放训练时准确的 label。
# Define loss and optimizery_ = tf.placeholder(tf.float32, [None, 10])
我们会用一个叫交叉熵的东西来衡量我们的预测的「惊讶」程度。
关于交叉熵,举个例子,我们平常写代码的时候,一按编译,一切顺利,程序跑起来了,我们就没那么「惊讶」,因为我们的代码是那么的优秀;而如果一按编译,整个就 Crash 了,我们很「惊讶」,一脸蒙逼的想,这怎么可能。
交叉熵感性的认识就是表达这个的,当输出的值和我们的期望是一致的时候,我们就「惊讶」值就比较低,当输出值不是我们期望的时候,「惊讶」值就比较高。
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
这里就用了 TensorFlow 实现的 softmax 模型来计算交叉熵。
交叉熵,就是我们想要尽量优化的值,让结果符合预期,不要让我们太「惊讶」。
TensorFlow 会自动使用反向传播算法(backpropagation algorithm) 来有效的确定变量是如何影响你想最小化的交叉熵。然后 TensorFlow 会用你选择的优化算法来不断地修改变量以降低交叉熵。
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
这里用了梯度下降算法(gradient descent algorithm)来优化交叉熵,这里是以 0.5 的速度来一点点的优化交叉熵。
之后就是初始化变量,以及启动 Session
sess = tf.InteractiveSession() tf.global_variables_initializer().run()
启动之后,开始训练!
# Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
这里训练 1000 次,每次随机找 100 个数据来练习,这里的 feed_dict={x: batch_xs, y_: batch_ys}
,就是我们前面那设置的两个留着占位的输入值。
到这里基本训练就完成了。
评估模型
训练完之后,我们来评估一下模型的准确度。
# Test trained modelcorrect_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print(sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))
tf.argmax 给出某个tensor对象在某一维上的其数据最大值所在的索引值。因为我们的 label 只有一个 1,所以 tf.argmax(_y, 1) 就是 label 的索引,也就是表示图片是几。把计算值和预测值 equal 一下就可以得出模型算的是否准确。
下面的 accuracy 计算的是一个整体的精确度。
这里填入的数据不是训练数据,是测试数据和测试 label。
最终结果,我的是 0.9151,91.51% 的准确度。官方说这个不太好,如果用一些更好的模型,比如多层卷积网络等,这个识别率可以到 99% 以上。
官方的例子到这里就结束了,虽然说识别了几万张图片,但是我一张像样的图片都没看到,于是我决定想办法拿这个模型真正找几个图片测试一下。
用模型测试
看下上面的例子,重点就是放测试数据进去这里,如果我们要拿图片测,需要先把图片变成相应格式的数据。
sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
看下这里 mnist 是从 tensorflow.examples.tutorials.mnist 中的 input_data 的 read_data_sets 方法中来的。
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
Python 就是好,有啥不懂看下源码。源码的在线地址在这里
打开找 read_data_sets 方法,发现:
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
在这个文件里面。
def read_data_sets(train_dir,fake_data=False,one_hot=False,dtype=dtypes.float32,reshape=True,validation_size=5000):.........train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)validation = DataSet(validation_images,validation_labels,dtype=dtype,reshape=reshape)test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)return base.Datasets(train=train, validation=validation, test=test)
可以看到,最后返回的都是是一个对象,而我们用的 feeddict={x: mnist.test.images, y: mnist.test.labels} 就是从这来的,是一个 DataSet 对象。这个对象也在这个文件里面。
class DataSet(object):def __init__(self,images,labels,fake_data=False,one_hot=False,dtype=dtypes.float32,reshape=True):"""Construct a DataSet.one_hot arg is used only if fake_data is true. `dtype` can be either`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into`[0, 1]`."""......
这个对象很长,我就只挑重点了,主要看构造方法。一定要传入的有 images 和 labels。其实这里已经比较明朗了,我们只要把单张图片弄成 mnist 格式,分别传入到这个 DataSet 里面,就可以得到我们要的数据。
网上查了下还真有,代码地址,对应的文章:www.jianshu.com/p/419557758…,文章讲的有点不清楚,需要针对 TensorFlow 1.0 版本以及实际目录情况做点修改。
直接上我修改后的代码:
from PIL import Image
from numpy import *def GetImage(filelist):width=28height=28value=zeros([1,width,height,1])value[0,0,0,0]=-1label=zeros([1,10])label[0,0]=-1for filename in filelist:img=array(Image.open(filename).convert("L"))width,height=shape(img);index=0tmp_value=zeros([1,width,height,1])for i in range(width):for j in range(height):tmp_value[0,i,j,0]=img[i,j]index+=1if(value[0,0,0,0]==-1):value=tmp_valueelse:value=concatenate((value,tmp_value))tmp_label=zeros([1,10])index=int(filename.strip().split('/')[2][0])print "input:",indextmp_label[0,index]=1if(label[0,0]==-1):label=tmp_labelelse:label=concatenate((label,tmp_label))return array(value),array(label)
这里读取图片依赖 PIL 这个库,由于 PIL 比较少维护了,可以用它的一个分支 Pillow 来代替。另外依赖 numpy 这个科学计算库,没装的要装一下。
这里就是把图片读取,并按 mnist 格式化,label 是取图片文件名的第一个字,所以图片要用数字开头命名。
如果懒得 PS 画或者手写的画,可以把测试数据集的数据给转回图片,实测成功,参考这篇文章:如何用python解析mnist图片
新建一个文件夹叫 test_num,里面图片如下,这里命名就是 label 值,可以看到 label 和图片是对应的:
开始测试:
print("Start Test Images")dir_name = "./test_num"files = glob2.glob(dir_name + "/*.png")cnt = len(files)for i in range(cnt):print(files[i])test_img, test_label = GetImage([files[i]])testDataSet = DataSet(test_img, test_label, dtype=tf.float32)res = accuracy.eval({x: testDataSet.images, y_: testDataSet.labels})print("output: ", res)print("----------"
这里用了 glob2 这个库来遍历以及过滤文件,需要安装,常规的遍历会把 Mac 上的 .DS_Store 文件也会遍历进去。
可以看到我们打的 label 和模型算出来的是相符的。
然后我们可以打乱文件名,把 9 说成 8,把 0 也说成 8:
可以看到,我们的 label 和模型算出来的是不相符的。
恭喜,到着你就完成了一次简单的人工智能之旅。
总结
从这个例子中我们可以大致知道 TensorFlow 的运行模式:
例子中是每次都要走一遍训练流程,实际上是可以用 tf.train.Saver() 来保存训练好的模型的。这个入门例子完成之后能对 TensorFlow 有个感性认识。
TensorFlow 没有那么神秘,没有我们想的那么复杂,也没有我们想的那么简单,并且还有很多数学知识要补充呢。
相关文章:

零基础学人工智能:TensorFlow 入门例子
识别手写图片 因为这个例子是 TensorFlow 官方的例子,不会说的太详细,会加入了一点个人的理解,因为TensorFlow提供了各种工具和库,帮助开发人员构建和训练基于神经网络的模型。TensorFlow 中最重要的概念是张量(Tenso…...
go从0到1项目实战体系二一:gin框架安装
(1). 设置公用的代理服务地址: 如果设置了全局可忽略. $ export GOPROXYhttps://goproxy.io // linux > go env可以查看 $ export GOPROXYhttps://goproxy.cn // linux国内镜像 $ set GOPROXYhttps://goproxy.io // windows(2). 创建以下目录: 请忘记GOPATH目录…...
运用JavaSE知识实现图书管理系统
目录 一.Main函数二.用户类三.普通用户类四.管理员类五.图书类六.书架类七.操作类1.操作接口2.增加操作3.删除操作4.查找操作5.展示操作6.借阅操作7.归还操作8.退出系统 总结 这篇图书管理系统是对JavaSE知识总结复习的一个小作业,检测自己对知识的掌握程度。 一.Ma…...
微信小程序生成一个天气查询的小程序
微信小程序生成一个天气查询的小程序 基本的页面结构和逻辑 页面结构:包括一个输入框和一个查询按钮。 页面逻辑:在用户输入城市名称后,点击查询按钮,跳转到天气详情页面,并将城市名称作为参数传递。 主要代码 index…...

Seata源码——TCC模式解析02
初始化 在SpringBoot启动的时候通过自动注入机制将GlobalTransactionScanner注入进ioc而GlobalTransactionScanner继承AbstractAutoProxyCreatorAbstract 在postProcessAfterInitialization阶段由子类创建代理TccActionInterceptor GlobalTransactionScanner protected Obje…...
缓存-Redis
Springboot使用Redis 引入pom依赖: <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>在application.yml、application-dev.yml中配置Redis的访…...

PADS Layout安全间距检查报错
问题: 在Pads Layout完成layout后,进行工具-验证设计安全间距检查时,差分对BAK_FIXCLK_100M_P / BAK_FIXCLK_100M_N的安全间距检查报错,最小为3.94mil,但是应该大于等于5mil;如下两张图: 检查&…...
ebpf基础篇(二) ----- ebpf前世今生
bpf 要追述ebpf的历史,就不得不提bpf. bpf(Berkeley Packet Filter)从早(1992年)诞生于类Unix系统中,用于数据包分析. 它提供了数据链路层的接口,可以在数据链路层发送和接收数据.如果网卡支持混杂模式,所有的数据包都可以被接收,即使这些数据包的目的地址是其它主机. BPF最为…...
我的一天:追求专业成长与生活平衡
早晨的序幕:奋斗的开始 今天的一天始于清晨的6点47分。实现了昨天的早睡早起的蜕变计划。洗漱完成之后,7点17分出门,7点33分我抵达公司,为新的一天做好准备。7点52分,我开始我的学习之旅。正如我所体会的,“…...

【动态规划】斐波那契数列模型
欢迎来到Cefler的博客😁 🕌博客主页:那个传说中的man的主页 🏠个人专栏:题目解析 🌎推荐文章:题目大解析(3) 前言 算法原理 1.状态表示 是什么?dp表(一维数组…...
机器人运动学分析与动力学分析主要作用
机器人运动学分析和动力学分析是两个重要的概念,它们在研究和设计工业机器人时起着关键作用。 1. 机器人运动学分析: 机器人运动学是研究机器人运动的科学,它涉及机器人的位置、速度、加速度和轨迹等方面。机器人运动学分析主要包括正解和逆…...

【Java 基础】33 JDBC
文章目录 1. 数据库连接1)加载驱动2)建立连接 2. 常见操作1)创建表2)插入数据3)查询数据4)使用 PreparedStatement5)事务管理 3. 注意事项总结 Java Database Connectivity(JDBC&…...

Unity中Shader缩放矩阵
文章目录 前言一、直接相乘缩放1、在属性面板定义一个四维变量,用xyz分别控制在xyz轴上的缩放2、在常量缓存区申明该变量3、在顶点着色器对其进行相乘,来缩放变换4、我们来看看效果 二、使用矩阵乘法代替直接相乘缩放的原理1、我们按如下格式得到缩放矩阵…...

Nessus详细安装-windows (保姆级教程)
Nessus描述 Nessus 是一款广泛使用的网络漏洞扫描工具。它由 Tenable Network Security 公司开发,旨在帮助组织评估其计算机系统和网络的安全性。 Nessus 可以执行自动化的漏洞扫描,通过扫描目标系统、识别和评估可能存在的安全漏洞和弱点。它可以检测…...

Stream流的简单使用
stream流的三类方法 获取Stream流 ○ 创建一条流水线,并把数据放到流水线上准备进行操作中间方法 ○ 流水线上的操作 ○ 一次操作完毕之后,还可以继续进行其他操作终结方法 ○ 一个Stream流只能有一个终结方法 ○ 是流水线上的最后一个操作 其实Stream流非常简单,只…...

智能优化算法应用:基于蛇优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码
智能优化算法应用:基于蛇优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于蛇优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蛇优化算法4.实验参数设定5.算法结果6.参考文…...
vue和react diff的详解和不同
diff算法 简述:第一次对比真实dom和虚拟树之间的同层差别,后面为对比新旧虚拟dom树之间的同层差别。 虚拟dom 简述:js对象形容模拟真实dom 具体: 1.虚拟dom是存在内存中的js对象,利用内存的高效率运算。虚拟dom属…...

智能优化算法应用:基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码
智能优化算法应用:基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于鹈鹕算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.鹈鹕算法4.实验参数设定5.算法结果6.参考文献7.MA…...

10:IIC通信
1:IIC通信 I2C总线(Inter IC BUS) 是由Philips公司开发的一种通用数据总线,应用广泛,下面是一些指标参数: 两根通信线:SCL(Serial Clock,串行时钟线)、SDA&a…...

互联网上门洗衣洗鞋小程序优势有哪些?
互联网洗鞋店小程序相较于传统洗鞋方式,具有以下优势; 1. 便捷性:用户只需通过手机即可随时随地下单并查询,省去了许多不必要的时间和精力。学生们无需走出宿舍或校园,就能轻松预约洗鞋并取件。 2. 精准定位࿱…...

Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

用docker来安装部署freeswitch记录
今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)
Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...
根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的----NTFS源代码分析--重要
根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的 第一部分: 0: kd> g Breakpoint 9 hit Ntfs!ReadIndexBuffer: f7173886 55 push ebp 0: kd> kc # 00 Ntfs!ReadIndexBuffer 01 Ntfs!FindFirstIndexEntry 02 Ntfs!NtfsUpda…...
Python竞赛环境搭建全攻略
Python环境搭建竞赛技术文章大纲 竞赛背景与意义 竞赛的目的与价值Python在竞赛中的应用场景环境搭建对竞赛效率的影响 竞赛环境需求分析 常见竞赛类型(算法、数据分析、机器学习等)不同竞赛对Python版本及库的要求硬件与操作系统的兼容性问题 Pyth…...
k8s从入门到放弃之HPA控制器
k8s从入门到放弃之HPA控制器 Kubernetes中的Horizontal Pod Autoscaler (HPA)控制器是一种用于自动扩展部署、副本集或复制控制器中Pod数量的机制。它可以根据观察到的CPU利用率(或其他自定义指标)来调整这些对象的规模,从而帮助应用程序在负…...
LLaMA-Factory 微调 Qwen2-VL 进行人脸情感识别(二)
在上一篇文章中,我们详细介绍了如何使用LLaMA-Factory框架对Qwen2-VL大模型进行微调,以实现人脸情感识别的功能。本篇文章将聚焦于微调完成后,如何调用这个模型进行人脸情感识别的具体代码实现,包括详细的步骤和注释。 模型调用步骤 环境准备:确保安装了必要的Python库。…...
WEB3全栈开发——面试专业技能点P4数据库
一、mysql2 原生驱动及其连接机制 概念介绍 mysql2 是 Node.js 环境中广泛使用的 MySQL 客户端库,基于 mysql 库改进而来,具有更好的性能、Promise 支持、流式查询、二进制数据处理能力等。 主要特点: 支持 Promise / async-await…...
算法刷题-回溯
今天给大家分享的还是一道关于dfs回溯的问题,对于这类问题大家还是要多刷和总结,总体难度还是偏大。 对于回溯问题有几个关键点: 1.首先对于这类回溯可以节点可以随机选择的问题,要做mian函数中循环调用dfs(i&#x…...
生成对抗网络(GAN)损失函数解读
GAN损失函数的形式: 以下是对每个部分的解读: 1. , :这个部分表示生成器(Generator)G的目标是最小化损失函数。 :判别器(Discriminator)D的目标是最大化损失函数。 GAN的训…...