当前位置: 首页 > news >正文

机器学习入门【经典的CIFAR10分类】

模型

神经网络采用下图
在这里插入图片描述

我使用之后发现迭代多了之后一直最高是正确率65%左右,然后我自己添加了一些Relu激活函数和正则化,现在正确率可以有80%左右。

模型代码

import torch
from torch import nnclass YmModel(nn.Module):def __init__(self):super(YmModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 64),nn.ReLU(),nn.Dropout(0.5),nn.Linear(64, 10),)def forward(self, x):return self.model(x)

训练

有一点要说明的是,数据集中并没有验证集,你可以从训练集扣个1w张出来

import torch
import torchvision
from torchvision import transformsfrom models.YMModel import YmModel
from torch.utils.data import DataLoadertransform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
print(len(train_loader), len(test_loader))print(len(train_dataset), len(test_dataset))model = YmModel()
#迭代次数
train_epochs = 300
#优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 损失函数
loss_fn = torch.nn.CrossEntropyLoss()train_epochs_step = 0
best_accuracy = 0.for epoch in range(train_epochs):model.train()print(f'Epoch is {epoch}')for images, labels in train_loader:outputs = model(images)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if train_epochs_step % 100 == 0:print(f'Train_Epoch is {train_epochs_step}\t Loss is {loss.item()}')train_epochs_step += 1train_epochs_step = 0with torch.no_grad():loss_running_total = 0.acc_running_total = 0.for images, labels in test_loader:outputs = model(images)loss = loss_fn(outputs, labels)loss_running_total += loss.item()acc_running_total += (outputs.argmax(1) == labels).sum().item()acc_running_total /= len(test_dataset)if acc_running_total > best_accuracy:best_accuracy = acc_running_totaltorch.save(model.state_dict(), './best_model.pth')print('accuracy is {}'.format(acc_running_total))print('total loss is {}'.format(loss_running_total))print('best accuracy is {}'.format(best_accuracy))

验证

import osimport numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transformsfrom models.TestColor import TextColor
from models.YMModel import YmModeltest_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)classes = ('airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
model = YmModel()model.load_state_dict(torch.load('best_model.pth'))model.eval()
with torch.no_grad():correct = 0.for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()print('Accuracy : {}'.format(100 * correct / len(test_dataset)))
folder_path = './images'
files_names = os.listdir(folder_path)
transform_test = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),
])for file_name in files_names:image_path = os.path.join(folder_path, file_name)image = Image.open(image_path)image = transform_test(image)image = np.reshape(image, [1, 3, 32, 32])output = model(image)_, predicted = torch.max(output, 1)source_name = os.path.splitext(file_name)[0]predicted_class = classes[predicted.item()]colors = TextColor.GREEN if predicted_class == source_name else TextColor.REDprint(f"Source is {TextColor.BLUE}{source_name}{TextColor.RESET}, and predicted is {colors}{predicted_class}{TextColor.RESET}")

结果

TextColor是自定义字体颜色的类,image中就是自己的图片。
结果如下:测试集的正确率有82.7%

在这里插入图片描述

相关文章:

机器学习入门【经典的CIFAR10分类】

模型 神经网络采用下图 我使用之后发现迭代多了之后一直最高是正确率65%左右,然后我自己添加了一些Relu激活函数和正则化,现在正确率可以有80%左右。 模型代码 import torch from torch import nnclass YmModel(nn.Module):def __init__(self):super(…...

01 安装

安装和卸载中,用户全部切换为root,一旦安装,普通用户也能使用 初期不进行用户管理,全部用root进行,使用mysql语句 1. 卸载内置环境 检查是否有mariadb存在,存在走a部分卸载 ps axj | grep mysql ps ajx |…...

AI 模型本地推理 - YYPOLOE - Python - Windows - GPU - 吸烟检测(目标检测)- 有配套资源直接上手实现

Python 运行 - GPU 推理 - windows 环境准备python 代码 环境准备 FastDeploy预编译库下载 conda config --add channels conda-forge && conda install cudatoolkit11.2 cudnn8.2 pip install fastdeploy_gpu_python-0.0.0-cp38-cp38-win_amd64.whlpython 代码 impo…...

全国媒体邀约,主流媒体到场出席采访报道

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 全国媒体邀约,确保主流媒体到场出席采访报道,可以带来一系列的好处,这些好处不仅能够增强活动的可见度,还能对品牌或组织的长期形象产生积…...

计算机视觉8 图像增广

图像增广(image augmentation)是通过对训练图像进行一系列随机改变,从而产生相似但又不同的训练样本的技术。 图像增广有以下两个主要作用: 扩大训练数据集的规模;随机改变训练样本可以降低模型对某些属性的依赖&#…...

Transformer中的自注意力是怎么实现的?

在Transformer模型中,自注意力(Self-Attention)是核心组件,用于捕捉输入序列中不同位置之间的关系。自注意力机制通过计算每个标记与其他所有标记之间的注意力权重,然后根据这些权重对输入序列进行加权求和&#xff0c…...

LabVIEW鼠标悬停在波形图上的曲线来自动显示相应点的坐标

步骤 创建事件结构: 打开LabVIEW,创建一个新的VI。 在前面板上添加一个Waveform Graph控件。 在后面板上添加一个While Loop和一个事件结构(Event Structure)。 配置事件结构,选择Waveform Graph作为事件源&#xf…...

操作系统发展简史(Unix/Linux 篇 + DOS/Windows 篇)+ Mac 与 Microsoft 之风云争霸

操作系统发展简史(Unix/Linux 篇) 说到操作系统,大家都不会陌生。我们天天都在接触操作系统 —— 用台式机或笔记本电脑,使用的是 windows 和 macOS 系统;用手机、平板电脑,则是 android(安卓&…...

钡铼分布式 IO 系统 OPC UA边缘计算耦合器BL205

深圳钡铼技术推出的BL205耦合器支持OPC UA Server功能,以服务器形式对外提供数据。符合IEC 62541工业自动化统一架构通讯标准,数据可以选择加密(X.509证书)、身份验证方式传送。 安全策略支持basic128rsa15、basic256、basic256s…...

实现了一个心理测试的小程序,微信小程序学习使用问题总结

1. 如何在跳转页面中传递参数 ,在 onLoad 方法中通过 options 接收 2. radio 如何获取选中的值? bindchange 方法 参数e, e.detail.value 。 如果想要获取其他属性,使用data-xx 指定,然后 e.target.dataset.xx 获取。 3. 不刷…...

vue是如何进行监听数据变化的?vue2和vue3分别是什么?vue3为什么要更换?

Vue如何进行监听数据变化的? Vue.js 通过其响应式系统来监听数据变化。这个系统允许你声明式地将数据和 DOM 绑定,一旦数据发生变化,相关的 DOM 将自动更新。Vue 使用以下机制来实现数据的监听和响应: 响应式数据:在 …...

数据结构day3

一、思维导图 二、 #include "seqlist.h"#include<myhead.h> int main(int argc, const char *argv[]) {//创建一个顺序表SeqListPtr L list_create();if(NULL L){return -1;}//调用添加函数list_add(L,123);list_add(L,435);list_add(L,856);list_add(L,65…...

免费的数字孪生平台助力产业创新,让新质生产力概念有据可依

关于新质生产力的概念&#xff0c;在如今传统企业现代化发展中被反复提及。 那到底什么是新质生产力&#xff1f;它与哪些行业存在联系&#xff0c;我们又该使用什么工具来加快新质生产力的发展呢&#xff1f;今天我将介绍一款为发展新质生产力而量身定做的数字孪生工具。 新…...

mtsys2 编译 qemu 记录

参考链接 下载 MSYS2 MSYS2 MSYS2 换源 进入目录\msys64\etc\pacman.d&#xff0c; 在文件mirrorlist.msys的前面插入 Server http://mirrors.ustc.edu.cn/msys2/msys/$arch在文件mirrorlist.mingw32的前面插入 Server http://mirrors.ustc.edu.cn/msys2/mingw/i686在…...

【Python数据分析】数据分析三剑客:NumPy、SciPy、Matplotlib中常用操作汇总

文章目录 NumPy常见操作汇总SciPy常见操作汇总Matplotlib常见操作汇总官方文档链接NumPy常见操作汇总 在Python的NumPy库中,有许多常用的知识点,这里列出了一些核心功能和常见操作: 类别函数或特性描述基础操作np.array创建数组np.shape获取数组形状np.dtype查看数组数据类…...

STM32智能家居电力管理系统教程

目录 引言环境准备智能家居电力管理系统基础代码实现&#xff1a;实现智能家居电力管理系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;电力管理与优化问题解决方案与优化收尾与总结 1. 引言 智能家居电…...

C# 邮件发送

创建邮件类 // 有static时候 类名&#xff0c;方法名// MyEmail.方法名/// <summary>/// 给目标发送邮箱/// </summary>/// <param name"maiTo"></param>/// <param name"title"></param>/// <param name"con…...

Kotlin 协程简化回调

suspend 和 suspendCoroutine 实现 suspendCoroutine函数必须在协程作用域或挂起函数中才能调用&#xff0c;它接收一个Lambda表达式参数&#xff0c;主要作用是将当前协程立即挂起&#xff0c;然后在一个普通的线程中执行Lambda表达式中的代码。Lambda表达式的参数列表上会传…...

帝王蝶算法(EBOA)及Python和MATLAB实现

帝王蝶算法&#xff08;Emperor Butterfly Optimization Algorithm&#xff0c;简称EBOA&#xff09;是一种启发式优化算法&#xff0c;灵感来源于蝴蝶群体中的帝王蝶&#xff08;Emperor Butterfly&#xff09;。该算法模拟了帝王蝶群体中帝王蝶和其他蝴蝶之间的交互行为&…...

【学术会议征稿】第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024)

第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024) 2024 6th International Conference on Frontier Technologies of Information and Computer 第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024)将在中国青岛举行&#xff0c;会期是2024年11月8-10日&#xff0c;为…...

Qwen3.5-9B+OpenClaw组合方案:3类高性价比自动化场景实测

Qwen3.5-9BOpenClaw组合方案&#xff1a;3类高性价比自动化场景实测 1. 为什么选择这个组合&#xff1f; 去年夏天&#xff0c;我花了整整两周时间在本地部署各种开源大模型&#xff0c;试图找到一个既能在预算内运行、又能稳定执行自动化任务的方案。经过反复测试&#xff0…...

AI 与大模型相关

一、 AI 与大模型相关 1.1 Agent&#xff08;智能体&#xff09; 定义&#xff1a;具备自主规划、工具调用、记忆管理、任务执行能力的 AI 实体&#xff0c;能主动完成复杂目标。 核心能力&#xff1a;拆解任务、调用 API / 工具、自主决策、持久记忆、后台执行。 区别&am…...

7个高级配置技巧:打造极致Markdown预览体验

7个高级配置技巧&#xff1a;打造极致Markdown预览体验 【免费下载链接】vscode-markdown-preview-enhanced One of the "BEST" markdown preview extensions for Visual Studio Code 项目地址: https://gitcode.com/gh_mirrors/vs/vscode-markdown-preview-enhanc…...

别光知道Levenshtein!Python实战:用Jaro-Winkler算法搞定人名地址模糊匹配

别光知道Levenshtein&#xff01;Python实战&#xff1a;用Jaro-Winkler算法搞定人名地址模糊匹配 在数据清洗和用户输入处理的场景中&#xff0c;字符串相似度计算是个绕不开的话题。当我们需要匹配"张三丰"和"张三風"时&#xff0c;传统的Levenshtein距离…...

15秒生成12个测试用例:AI写的测试比我写的还全

说实话&#xff0c;我一直是个"测试拖延症患者"。每次写完功能代码&#xff0c;心里都清楚应该补测试&#xff0c;但手就是敲不下去。想着"这个功能这么简单&#xff0c;不会有问题的"&#xff0c;然后安慰自己"等有空了再补"。结果呢&#xff1…...

从GOPATH到Go Mod:老项目迁移必知的5个文件结构陷阱

从GOPATH到Go Mod&#xff1a;老项目迁移必知的5个文件结构陷阱 当Golang社区在2018年推出Go Modules时&#xff0c;很少有人预料到这个看似简单的包管理工具会成为Go语言发展史上的分水岭。四年后的今天&#xff0c;仍有大量遗留项目困在GOPATH的泥潭中&#xff0c;而迁移过程…...

WorkBuddy杀疯了?一群AI专家帮我打工,我在微信里当赛博虾工头!

梦瑶 发自 凹非寺量子位 | 公众号 QbitAI到底是谁说&#xff0c;给老板打工自己就当不成老板的&#xff1f;又是谁说&#xff0c;龙虾不好用、还不听使唤的&#xff1f;反正这些事儿&#xff0c;现在跟我没啥关系了。毕竟现在的我&#xff0c;已经转头当起了「虾工头」&#xf…...

新版药典解读:生物制品生产用动物细胞基质的质量控制修订重点

2025年版《中国药典》已正式实施2个多月&#xff0c;其对生物制品生产用动物细胞基质的质量控制要求进行了重要修订。本次修订对生物制品生产企业和检测机构的影响路径和深度虽有差异&#xff0c;但都指向一个核心转变&#xff1a;从“遵循规定”到“证明科学性”。接下来&…...

如何将TaskWeaver与LangChain无缝集成:扩展AI代理能力边界的终极指南

如何将TaskWeaver与LangChain无缝集成&#xff1a;扩展AI代理能力边界的终极指南 【免费下载链接】TaskWeaver A code-first agent framework for seamlessly planning and executing data analytics tasks. 项目地址: https://gitcode.com/gh_mirrors/ta/TaskWeaver T…...

Ollama + DeepSeek + 芋道框架 + SearXNG 本地联网搜索完整教程

1. 环境准备与检查 在开始之前,请确保你的环境满足以下条件: 1.1 硬件要求 内存:建议至少8GB可用内存(运行7B模型需要约4-6GB) 硬盘:DeepSeek模型文件约4-5GB空间 CPU/GPU:如有NVIDIA GPU可加速推理(可选) 1.2 软件要求 操作系统:Windows 10/11、macOS、Linux均可 …...