当前位置: 首页 > news >正文

Pytorch-MLP-CIFAR10

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as initclass MLP_cls(nn.Module):def __init__(self,in_dim=3*32*32):super(MLP_cls,self).__init__()self.lin1 = nn.Linear(in_dim,128)self.lin2 = nn.Linear(128,64)self.lin3 = nn.Linear(64,10)self.relu = nn.ReLU()init.xavier_uniform_(self.lin1.weight)init.xavier_uniform_(self.lin2.weight)init.xavier_uniform_(self.lin3.weight)def forward(self,x):x = x.view(-1,3*32*32)x = self.lin1(x)x = self.relu(x)x = self.lin2(x)x = self.relu(x)x = self.lin3(x)x = self.relu(x)return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls,CNN_clsseed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = MLP_cls()train_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('./data/', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_test, shuffle=True)optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()print("****************Begin Training****************")
net.train()
for epoch in range(epochs):run_loss = 0correct_num = 0for batch_idx, (data, target) in enumerate(train_loader):out = net(data)_,pred = torch.max(out,dim=1)optimizer.zero_grad()loss = criterion(out,target)loss.backward()run_loss += lossoptimizer.step()correct_num  += torch.sum(pred==target)print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):out = net(data)_,pred = torch.max(out,dim=1)test_loss += criterion(out,target)test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

CIFAR10是彩色图像,单个大小为3*32*32。所以view的时候后面展平。

运行图

在这里插入图片描述

相关文章:

Pytorch-MLP-CIFAR10

文章目录 model.pymain.py参数设置注意事项运行图 model.py import torch.nn as nn import torch.nn.functional as F import torch.nn.init as initclass MLP_cls(nn.Module):def __init__(self,in_dim3*32*32):super(MLP_cls,self).__init__()self.lin1 nn.Linear(in_dim,1…...

SQL2 查询多列

描述 题目:现在运营同学想要用户的设备id对应的性别、年龄和学校的数据,请你取出相应数据 示例:user_profile iddevice_idgenderageuniversityprovince12138male21北京大学Beijing23214male复旦大学Shanghai36543female20北京大学Beijing42…...

算法分享三个方面学习方法(做题经验,代码编写经验,比赛经验)

目录 0 . 前言:(遇到OI不要慌)(只要道路对了,就不怕遥远) 1. 做题经验谈 1.1 做题的目的 1.2 我对于算法比赛的题目的看法 1.2.1 类似题 1.2.2 套模型: 1.3 在训练过程中如何做题 1.4 一些建议&…...

爬虫 — 验证码反爬

目录 一、超级鹰二、图片验证模拟登录1、页面分析1.1、模拟用户正常登录流程1.2、识别图片里面的文字 2、代码实现 三、滑块模拟登录1、页面分析2、代码实现(通过对比像素获取缺口位置) 四、openCV1、简介2、代码3、案例 五、selenium 反爬六、百度智能云…...

视频图像处理算法opencv模块硬件设计图像颜色识别模块

1、Opencv简介 OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库,可以运行在Linux、Windows、Android和Mac OS操作系统上 它轻量级而且高效——由一系列 C 函数和少量 C 类构成,同时提供了Python、Rub…...

目标检测网络之Fast-RCNN

文章目录 Fast RCNN解决的问题Fast RCNN网络结构RoI pooling layer合并损失函数及其传播统一的损失函数损失函数的反向传播过程Fast RCNN的训练方法样本选择方法SGD参数设置多尺度图像训练SVD压缩全连接层对比实验对比实验使用到的网络结构VOC2010和VOC2012数据集结果VOC2007数…...

Golang Gorm 创建HOOK

创建的时候,在插入数据之前,想要做一些事情。钩子函数比较简单,就是实现before create的一个方法。 package mainimport ("gorm.io/driver/mysql""gorm.io/gorm" )type Student struct {ID int64Name string gorm:&q…...

计算机视觉的应用15-图片旋转验证码的角度计算模型的应用,解决旋转图片矫正问题

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用15-图片旋转验证码的角度计算模型的应用,解决旋转图片矫正问题,在CV领域,图片旋转验证码的角度计算模型被广泛应用于解决旋转图片矫正问题,有效解决机…...

【Seata】分布式事务问题和理论基础

目录 1.分布式事务问题 1.1本地事务 1.2分布式事务 2.理论基础 2.1CAP定理 2.1.1一致性 2.1.2可用性 2.1.3分区容错 2.1.4矛盾 2.2BASE理论 2.3解决分布式事务的思路 1.分布式事务问题 1.1本地事务 本地事务,也就是传统的单机事务。在传统数据库事务中…...

文件打包解包的方法

在很多情况下,软件需要隐藏一些图片,防止用户对其更改,替换。例如腾讯QQ里面的资源图片,哪怕你用Everything去搜索也搜索不到,那是因为腾讯QQ对这些资源图片进行了打包,当软件运行的时候解包获取资源图片。…...

npm 清缓存(重新安装node-modules)

安装node依赖包的会出现失败的情况,如下图所示: 此时 提示有些依赖树有冲突,根据提示 “ this command with --force or --legacy-peer-deps” 执行命令即可。 具体步骤如下: 1、先删除本地node-modules包 2、删掉page-loacl…...

sqlserver查询表中所有字段信息

精简 SELECT 字段名 a.name,主键 case when exists(SELECT 1 FROM sysobjects where xtypePK and parent_obja.id and name in (SELECT name FROM sysindexes WHERE indid in( SELECT indid FROM sysindexkeys WHERE id a.id AND colida.colid))) then √ else …...

二叉树的概念、存储及遍历

一、二叉树的概念 1、二叉树的定义 二叉树( binary tree)是 n 个结点的有限集合,该集合或为空集(空二叉树),或由一个根结点与两棵互不相交的,称为根结点的左子树、右子树的二叉树构成。 二叉树的…...

【面试题】智力题

文章目录 腾讯1000瓶毒药里面只有1瓶是有毒的,问需要多少只老鼠才能在24小时后试出那瓶有毒。有两根不规则的绳子,两根绳子从头烧到尾均需要一个小时,现在有一个45分钟的比赛,裁判员忘记带计时器,你能否通过烧绳子的方…...

【SpringBoot集成Redis + Session持久化存储到Redis】

目录 SpringBoot集成Redis 1.添加 redis 依赖 2.配置 redis 3.手动操作 redis Session持久化存储到Redis 1.添加依赖 2.修改redis配置 3.存储和读取String类型的代码 4.存储和读取对象类型的代码 5.序列化细节 SpringBoot集成Redis 1.添加 redis 依赖 …...

day49:QT day2,信号与槽、对话框

一、完善登录框 点击登录按钮后,判断账号(admin)和密码(123456)是否一致,如果匹配失败,则弹出错误对话框,文本内容“账号密码不匹配,是否重新登录”,给定两个…...

Meta分析核心技术

Meta分析是针对某一科研问题,根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法,对来源不同的研究成果进行收集、合并及定量统计分析的方法,最早出现于“循证医学”,现已广泛应用于农林生态,资源环境等方面。…...

Gof23设计模式之责任链模式

1.概述 责任链模式又名职责链模式,为了避免请求发送者与多个请求处理者耦合在一起,将所有请求的处理者通过前一对象记住其下一个对象的引用而连成一条链;当有请求发生时,可将请求沿着这条链传递,直到有对象处理它为止…...

数字孪生和元宇宙:打造未来的数字边界

数字孪生和元宇宙是近两年来被热议的两个概念,但由于技术的交叉两者也极易被混淆。本文希望带大家深入探讨一下这两者之间的关系,以及它们如何一起构建了数字时代的新格局。 1. 数字孪生的本质 数字孪生是一种虚拟模型,它通过数字手段对现实…...

【新版】系统架构设计师 - 软件架构设计<新版>

个人总结,仅供参考,欢迎加好友一起讨论 文章目录 架构 - 软件架构设计<新版>考点摘要概念架构的 4 1 视图架构描述语言ADL基于架构的软件开发方法ABSDABSD的开发模型ABSDMABSD(ABSDM模型)的开发过程 软件架…...

后进先出(LIFO)详解

LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子&#xff08…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案

问题描述&#xff1a;iview使用table 中type: "index",分页之后 &#xff0c;索引还是从1开始&#xff0c;试过绑定后台返回数据的id, 这种方法可行&#xff0c;就是后台返回数据的每个页面id都不完全是按照从1开始的升序&#xff0c;因此百度了下&#xff0c;找到了…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)

上一章用到了V2 的概念&#xff0c;其实 Fiori当中还有 V4&#xff0c;咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务)&#xff0c;代理中间件&#xff08;ui5-middleware-simpleproxy&#xff09;-CSDN博客…...

mac:大模型系列测试

0 MAC 前几天经过学生优惠以及国补17K入手了mac studio,然后这两天亲自测试其模型行运用能力如何&#xff0c;是否支持微调、推理速度等能力。下面进入正文。 1 mac 与 unsloth 按照下面的进行安装以及测试&#xff0c;是可以跑通文章里面的代码。训练速度也是很快的。 注意…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改&#xff0c;实现简单cat命令 输出信息到显示器&#xff0c;你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...

Java数组Arrays操作全攻略

Arrays类的概述 Java中的Arrays类位于java.util包中&#xff0c;提供了一系列静态方法用于操作数组&#xff08;如排序、搜索、填充、比较等&#xff09;。这些方法适用于基本类型数组和对象数组。 常用成员方法及代码示例 排序&#xff08;sort&#xff09; 对数组进行升序…...

spring boot使用HttpServletResponse实现sse后端流式输出消息

1.以前只是看过SSE的相关文章&#xff0c;没有具体实践&#xff0c;这次接入AI大模型使用到了流式输出&#xff0c;涉及到给前端流式返回&#xff0c;所以记录一下。 2.resp要设置为text/event-stream resp.setContentType("text/event-stream"); resp.setCharacter…...