当前位置: 首页 > news >正文

使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的
语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧

2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果的概率)

1.先把二维数组转为一维
2.通过公式得到节点个数和值
3…同2
4.通过节点得到概率(softmax归一化公式)
5.对比模型的和 差值=原始概率-目标结果概率
6.不断优化原来模型的概率
5.激活函数,激活某个节点的函数,可以引入非线性的(因为所有问题不可能是线性的比如 很少图片识别一定可以识别出绝对的正方形,他可能中间有一定弯曲或者线在中心短开了)

在这里插入图片描述
在这里插入图片描述

3.训练的代码
//环境python3.8 最好使用conda进行版本管理,不然每个版本都可能不兼容,到处碰壁

 #安装依赖pip install numpy torch torchvision matplotlib

#文件夹结构,图片一定要是28x28的
在这里插入图片描述

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderclass Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return x#导入数据
def get_data_loader(is_train):#张量,多维数组to_tensor = transforms.Compose([transforms.ToTensor()])# 下载数据集 下载目录data_set = MNIST("", is_train, transform=to_tensor, download=True)#一个批次15张,顺序打乱return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)#评估准确率
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():#按批次取数据for (x, y) in test_data:#计算神经网络预测值outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):#比较预测结果和测试集结果if torch.argmax(output) == y[i]:#统计正确预测结果数n_correct += 1#统计全部预测结果n_total += 1#返回准确率=正确/全部的return n_correct / n_totaldef main():#加载训练集train_data = get_data_loader(is_train=True)#加载测试集test_data = get_data_loader(is_train=False)#初始化神经网络net = Net()#打印测试网络的准确率 0.1print("initial accuracy:", evaluate(test_data, net))#训练神经网络optimizer = torch.optim.Adam(net.parameters(), lr=0.001)#重复利用数据集 2次for epoch in range(100):for (x, y) in train_data:#初始化 固定写法net.zero_grad()#正向传播output = net.forward(x.view(-1, 28 * 28))#计算差值loss = torch.nn.functional.nll_loss(output, y)#反向误差传播loss.backward()#优化网络参数optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))# #使用3张图片进行预测# for (n, (x, _)) in enumerate(test_data):#     if n > 3:#         break#     predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))#     plt.figure(n)#     plt.imshow(x[0].view(28, 28))#     plt.title("prediction: " + str(int(predict)))# plt.show()image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakpredict = torch.argmax(net.forward(x.view(-1, 28 * 28)))plt.figure(n)plt.imshow(x[0].permute(1, 2, 0))plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

#运行结果 弹框出现图片和识别结果

4.测试电脑的cuda是否安装成功,不成功不能运行下面的代码

import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('CUDA version:', torch.version.cuda)
print('PyTorch version:', torch.__version__)

5.在gpu上运行,需要去官网下载cuda安装
https://developer.nvidia.com/cuda-toolkit-archive
#并且需要安装和torch对应的版本,我的电脑是1660ti的所以安装了10.2的cuda
#安装torchgpu版本

pip install torch==1.9.0+cu102 -f
https://download.pytorch.org/whl/cu102/torch_stable.html

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torchvision.datasets.folder import ImageFolderdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28 * 28, 64)self.fc2 = torch.nn.Linear(64, 64)self.fc3 = torch.nn.Linear(64, 64)self.fc4 = torch.nn.Linear(64, 10)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = torch.nn.functional.relu(self.fc3(x))x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)return xdef get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = MNIST("", is_train, transform=to_tensor, download=True)return DataLoader(data_set, batch_size=15, shuffle=True)def get_image_loader(folder_path):to_tensor = transforms.Compose([transforms.ToTensor()])data_set = ImageFolder(folder_path, transform=to_tensor)return DataLoader(data_set, batch_size=1)def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():for (x, y) in test_data:x, y = x.to(device), y.to(device)outputs = net.forward(x.view(-1, 28 * 28))for i, output in enumerate(outputs):if torch.argmax(output.cpu()) == y[i].cpu():n_correct += 1n_total += 1return n_correct / n_totaldef main():train_data = get_data_loader(is_train=True)test_data = get_data_loader(is_train=False)net = Net().to(device)print("initial accuracy:", evaluate(test_data, net))optimizer = torch.optim.Adam(net.parameters(), lr=0.001)for epoch in range(100):for (x, y) in train_data:x, y = x.to(device), y.to(device)net.zero_grad()output = net.forward(x.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(output, y)loss.backward()optimizer.step()print("epoch", epoch, "accuracy:", evaluate(test_data, net))image_loader = get_image_loader("aa")for (n, (x, _)) in enumerate(image_loader):if n > 2:breakx = x.to(device)predict = torch.argmax(net.forward(x.view(-1, 28 * 28)).cpu())plt.figure(n)plt.imshow(x[0].permute(1, 2, 0).cpu())plt.title("prediction: " + str(int(predict)))plt.show()if __name__ == "__main__":main()

相关文章:

使用pytorch利用神经网络原理进行图片的训练(持续学习中....)

1.做这件事的目的 语言只是工具,使用python训练图片数据,最终会得到.pth的训练文件,java有使用这个文件进行图片识别的工具,顺便整合,我觉得Neo4J正确率太低了,草莓都能识别成为苹果,而且速度慢,不能持续识别视频帧 2.什么是神经网络?(其实就是数学的排列组合最终得到统计结果…...

2023年中国合成云母行业现状及市场格局分析[图]

合成云母是一种通过化工原料经高温熔融冷却析晶而制得的单斜晶系矿物,属于典型的层状硅酸盐,许多性能都优于天然云母,如合成云母的耐温高达1200℃以上,而天然白云母在550℃下就会开始分解,金云母则在800℃开始分解。除…...

Vue3+Vite实现工程化,插值表达式和v-text以及v-html

1、插值表达式 插值表达式最基本的数据绑定形式是文本插值,它使用的是"Mustache"语法,即 双大括号{{}} 插值表达式是将数据 渲染 到元素的指定位置的手段之一插值表达式 不绝对依赖标签,其位置相对自由插值表达式中支持javascript的…...

艾泊宇产品战略:灵感于鬼屋,掌握打造卓越用户体验的关键要素

在当今的商业环境中,用户体验已经成为产品成功的关键因素。 无论是线上产品还是实体产品,用户体验都是决定用户是否愿意使用和推荐该产品的关键因素。 那么,艾泊宇产品战略理论告诉大家,如何做好用户体验? 我们可以…...

深度学习环境配置(Anaconda+pytorch+pycharm+cuda)

NVIDIA驱动安装 首先查看电脑的显卡版本,步骤为:此电脑右击-->管理-->设备管理器-->显示适配器。就可以看到电脑显卡的版本了。 然后按照电脑信息,到地址 去安装相应的驱动,Notebooks是笔记本的意思,然后下…...

不是说人工智能是风口吗,那为什么工作还那么难找?

最近确实有很多媒体、机构渲染人工智能可以拿高薪,这在行业内也是事实,但前提是你有足够的竞争力,真的懂人工智能。 首先,人工智能岗位技能要求高,人工智能是一个涵盖了多个学科领域的综合性学科,包括数学、…...

new Vue() 发生了什么

前言: 在Vue.js中,当你创建一个新的Vue实例时,通过 new Vue() 发生了一系列重要的操作,包括Vue实例的初始化、数据绑定、模板编译等。这个过程是Vue应用的核心,本文将深入探讨new Vue()发生了什么以及其原理,提供示例…...

【算法】二叉树的存储与遍历模板

二叉树的存储与遍历 const int N 1e6 10;// 二叉树的存储,l数组为左节点,r数组为右结点 int l[N], r[N]; // 存储节点的数据 char w[N]; // 节点的下标指针 int idx 0;// 先序创建 int pre_create(int n) {cin >> w[n];if (w[n] #) return -1;l[n] pre_create(idx)…...

【Go学习之 go mod】gomod小白入门,在github上发布自己的项目(项目初始化、项目发布、项目版本升级等)

参考 Go语言基础之包 | 李文周的博客Go mod的使用、发布、升级 | weiGo Module如何发布v2及以上版本1.2.7. go mod命令 — 新溪-gordon V1.7.9 文档golang go 包管理工具 go mod的详细介绍-腾讯云开发者社区-腾讯云Go Mod 常见错误的原因 | walker的博客 项目案例 oceanweav…...

79基于matlab的大米粒中杂质识别

基于matlab的大米粒中杂质识别,数据可更换自己的,程序已调通,可直接运行。 79matlab图像处理杂质识别 (xiaohongshu.com)...

Vue 项目实战——如何在页面中展示 PDF 文件以及 PDFObject 插件实战

文章目录 📋前言🎯使用 HTML 标签🧩 embed 标签🧩 object标签🧩 iframe标签🧩完整代码 🎯使用 PDFObject 插件🧩为什么使用 PDFObject 插件(AI翻译)&#x1f…...

系列六、ThreadLocal内存泄露案例

一、ThreadLocal内存泄露案例 /*** Author : 一叶浮萍归大海* Date: 2023/11/22 10:56* Description: 写一段代码导致内存泄露* VM Options:-Xms20m -Xmx20m -Xmn10m -XX:PrintGCDetails* 说明:内存泄露最终会导致内存溢出*/ public class ThreadLocalO…...

Java学习笔记44——Stream流

Stream流 体验Stream流Stream流的生成方式ColLection体系的集合可以使用默认方法stream ()生成流Map体系的集合间接的生成流数组可以通过stream接口的静态方法of (T... values)生成流 Stream流的中间操作方法Stream<T> filter(Predicate predicate)Stream<T>limit(…...

excel表格忘记密码,如何找回?

找回和去除Excel表格密码的方法非常简单。具体步骤如下&#xff1a;第一步百度搜索【 密码帝官网 】&#xff0c;第二步点击“立即开始”在用户中心上传文件即可。这个方法既安全又简单&#xff0c;不需要下载任何软件&#xff0c;而且可以在手机和电脑上都使用。密码帝官网支持…...

IDEA版SSM入门到实战(Maven+MyBatis+Spring+SpringMVC) -Mybatis初识和框架搭建

第一章 初识Mybatis 1.1 框架概述 生活中“框架” 买房子笔记本电脑 程序中框架【代码半成品】 Mybatis框架&#xff1a;持久化层框架【dao层】SpringMVC框架&#xff1a;控制层框架【Servlet层】Spring框架&#xff1a;全能… 1.2 Mybatis简介 Mybatis是一个半自动化持久化…...

差分放大器工作原理(差分放大器和功率放大器区别)

差分放大器是一种特殊的放大器&#xff0c;它可以将两个输入信号的差异放大输出。其工作原理基于差分放大器的电路结构和差分输入特性。 一、差分放大器电路结构 差分放大器一般由四个基本电路组成&#xff1a;正反馈网络、反相输入端、共模抑制电路和差分输入端。其中&#xf…...

SystemV

a...

LiteOS同步实验(实现生产者-消费者问题)

效果如下图&#xff1a; 给大家解释一下上述效果&#xff1a;在左侧&#xff08;顶格&#xff09;的是生产者&#xff08;Producer&#xff09;&#xff1b;在右侧&#xff08;空格&#xff09;的是消费者&#xff08;Consumer&#xff09;。生产者有1个&#xff0c;代号为“0”…...

redis的性能管理和雪崩

redis的性能管理 redis的数据是缓存在内存当中的 系统巡检&#xff1a; 硬件巡检、数据库、nginx、redis、docker、k8s 运维人员必须要关注的redis指标 在日常巡检中需要经常查看这些指标使用情况 info memory #查看redis使用内存的指标 used_memory:11285512 #数据占用的…...

python:关于函数内 * 和 / 是什么意思?

总结&#xff1a;如果你希望调用者使用函数时一定不能使用关键字传参&#xff0c;要求它使用位置进行传参&#xff0c;那么就可以把这些参数放在 / 的前面即可&#xff1b;如果你希望调用者使用函数时一定要使用某些参数&#xff0c;且必须是关键字传参时&#xff0c;那么就可以…...

如何用 watchEffect 实现根据参数自动获取数据?代码简化干货

用 watchEffect 实现参数变化自动重拉&#xff0c;核心是将请求逻辑写在回调中并直接读取响应式依赖&#xff08;如 route.params.id、searchKey.value&#xff09;&#xff0c;Vue 自动追踪&#xff1b;需封装请求函数但不可提前解构响应式值&#xff1b;可同步控制 loading/e…...

为什么很多公司服务器一多,运维反而越来越“失控”?

为什么很多公司服务器一多,运维反而越来越“失控”? 很多人刚入行运维的时候。 总觉得: 运维 = 装系统 + 部署服务 + 改配置后来进了真正的大型互联网公司才发现: 根本不是这么回事。 真正的大规模运维现场,经常是这样的: 凌晨 3 点。 报警群疯狂闪烁。 Promethe…...

基于人工智能优化算法的宽带多频功率放大器【附代码】

✅ 博主简介&#xff1a;擅长数据搜集与处理、建模仿真、程序设计、仿真代码、论文写作与指导&#xff0c;毕业论文、期刊论文经验交流。 ✅ 如需沟通交流&#xff0c;扫描文章底部二维码。&#xff08;1&#xff09;电路-电磁场联合仿真自动优化框架&#xff1a;提出了一种直接…...

筑牢水域安全防线:那些值得深思的防溺水之问

每到夏季&#xff0c;溺水事故便进入高发期&#xff0c;一条条鲜活生命的逝去&#xff0c;给无数家庭带来无法磨灭的伤痛。溺水已成为未成年人意外伤害致死的主要原因之一&#xff0c;面对频发的悲剧&#xff0c;我们不得不静下心来&#xff0c;追问那些关乎生命安全的核心问题…...

如何快速上手OpenBoardView:5个实用技巧与完整操作指南

如何快速上手OpenBoardView&#xff1a;5个实用技巧与完整操作指南 【免费下载链接】OpenBoardView View .brd files 项目地址: https://gitcode.com/gh_mirrors/op/OpenBoardView OpenBoardView是一款功能强大的开源电路板设计文件查看工具&#xff0c;专为替代传统的&…...

AI Agent 入门课:RAG 不是检索外挂,而是 Agent 的知识闭环

在企业知识问答里&#xff0c;最常见的失败并不是“完全搜不到”&#xff0c;而是第一次搜到的内容看起来相关&#xff0c;答案也写得流畅&#xff0c;结论却经不起复核。用户问一句“帮我总结这份文档”&#xff0c;普通 RAG 往往会先搜一批材料&#xff0c;再把结果塞回上下文…...

如何在 Taotoken 平台快速获取并管理你的 API Key

如何在 Taotoken 平台快速获取并管理你的 API Key 1. 注册与登录 Taotoken 平台 要开始使用 Taotoken 的服务&#xff0c;首先需要注册一个账号。访问 Taotoken 官方网站完成注册流程&#xff0c;使用邮箱验证后即可登录控制台。登录后你将看到仪表盘界面&#xff0c;这里提供…...

别再复制粘贴了!手把手教你用C语言实现CRC-16 IBM校验(附四种代码对比与性能分析)

CRC-16 IBM校验实战指南&#xff1a;从原理到四种高效C语言实现 在嵌入式系统和通信协议开发中&#xff0c;数据完整性校验是确保信息可靠传输的基石。CRC-16 IBM作为工业界广泛采用的校验算法&#xff0c;其独特的多项式处理和位反序特性使其在Modbus等协议中表现优异。但网上…...

别再只用plt.grid(True)了!Matplotlib网格线自定义的5个实用技巧(附代码)

别再只用plt.grid(True)了&#xff01;Matplotlib网格线自定义的5个实用技巧&#xff08;附代码&#xff09; 如果你还在用plt.grid(True)来简单开启网格线&#xff0c;那可能错过了Matplotlib一半的美学潜力。网格线不只是背景装饰&#xff0c;它能引导视线、强化数据对比、甚…...

AI代码沙盒:安全执行AI生成代码的容器化实践

1. 项目概述&#xff1a;AI时代的代码沙盒最近在GitHub上看到一个挺有意思的项目&#xff0c;叫typper-io/ai-code-sandbox。光看名字&#xff0c;你大概能猜到它是个跟AI和代码执行环境相关的工具。简单来说&#xff0c;这是一个专门为AI应用设计的、安全隔离的代码执行环境&a…...