当前位置: 首页 > news >正文

Pytorch-MLP-CIFAR10

文章目录

  • model.py
  • main.py
  • 参数设置
  • 注意事项
  • 运行图

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as initclass MLP_cls(nn.Module):def __init__(self,in_dim=3*32*32):super(MLP_cls,self).__init__()self.lin1 = nn.Linear(in_dim,128)self.lin2 = nn.Linear(128,64)self.lin3 = nn.Linear(64,10)self.relu = nn.ReLU()init.xavier_uniform_(self.lin1.weight)init.xavier_uniform_(self.lin2.weight)init.xavier_uniform_(self.lin3.weight)def forward(self,x):x = x.view(-1,3*32*32)x = self.lin1(x)x = self.relu(x)x = self.lin2(x)x = self.relu(x)x = self.lin3(x)x = self.relu(x)return x

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls,CNN_clsseed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = MLP_cls()train_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('./data/', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])),batch_size=batch_size_test, shuffle=True)optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()print("****************Begin Training****************")
net.train()
for epoch in range(epochs):run_loss = 0correct_num = 0for batch_idx, (data, target) in enumerate(train_loader):out = net(data)_,pred = torch.max(out,dim=1)optimizer.zero_grad()loss = criterion(out,target)loss.backward()run_loss += lossoptimizer.step()correct_num  += torch.sum(pred==target)print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):out = net(data)_,pred = torch.max(out,dim=1)test_loss += criterion(out,target)test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

CIFAR10是彩色图像,单个大小为3*32*32。所以view的时候后面展平。

运行图

在这里插入图片描述

相关文章:

Pytorch-MLP-CIFAR10

文章目录 model.pymain.py参数设置注意事项运行图 model.py import torch.nn as nn import torch.nn.functional as F import torch.nn.init as initclass MLP_cls(nn.Module):def __init__(self,in_dim3*32*32):super(MLP_cls,self).__init__()self.lin1 nn.Linear(in_dim,1…...

SQL2 查询多列

描述 题目:现在运营同学想要用户的设备id对应的性别、年龄和学校的数据,请你取出相应数据 示例:user_profile iddevice_idgenderageuniversityprovince12138male21北京大学Beijing23214male复旦大学Shanghai36543female20北京大学Beijing42…...

算法分享三个方面学习方法(做题经验,代码编写经验,比赛经验)

目录 0 . 前言:(遇到OI不要慌)(只要道路对了,就不怕遥远) 1. 做题经验谈 1.1 做题的目的 1.2 我对于算法比赛的题目的看法 1.2.1 类似题 1.2.2 套模型: 1.3 在训练过程中如何做题 1.4 一些建议&…...

爬虫 — 验证码反爬

目录 一、超级鹰二、图片验证模拟登录1、页面分析1.1、模拟用户正常登录流程1.2、识别图片里面的文字 2、代码实现 三、滑块模拟登录1、页面分析2、代码实现(通过对比像素获取缺口位置) 四、openCV1、简介2、代码3、案例 五、selenium 反爬六、百度智能云…...

视频图像处理算法opencv模块硬件设计图像颜色识别模块

1、Opencv简介 OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库,可以运行在Linux、Windows、Android和Mac OS操作系统上 它轻量级而且高效——由一系列 C 函数和少量 C 类构成,同时提供了Python、Rub…...

目标检测网络之Fast-RCNN

文章目录 Fast RCNN解决的问题Fast RCNN网络结构RoI pooling layer合并损失函数及其传播统一的损失函数损失函数的反向传播过程Fast RCNN的训练方法样本选择方法SGD参数设置多尺度图像训练SVD压缩全连接层对比实验对比实验使用到的网络结构VOC2010和VOC2012数据集结果VOC2007数…...

Golang Gorm 创建HOOK

创建的时候,在插入数据之前,想要做一些事情。钩子函数比较简单,就是实现before create的一个方法。 package mainimport ("gorm.io/driver/mysql""gorm.io/gorm" )type Student struct {ID int64Name string gorm:&q…...

计算机视觉的应用15-图片旋转验证码的角度计算模型的应用,解决旋转图片矫正问题

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用15-图片旋转验证码的角度计算模型的应用,解决旋转图片矫正问题,在CV领域,图片旋转验证码的角度计算模型被广泛应用于解决旋转图片矫正问题,有效解决机…...

【Seata】分布式事务问题和理论基础

目录 1.分布式事务问题 1.1本地事务 1.2分布式事务 2.理论基础 2.1CAP定理 2.1.1一致性 2.1.2可用性 2.1.3分区容错 2.1.4矛盾 2.2BASE理论 2.3解决分布式事务的思路 1.分布式事务问题 1.1本地事务 本地事务,也就是传统的单机事务。在传统数据库事务中…...

文件打包解包的方法

在很多情况下,软件需要隐藏一些图片,防止用户对其更改,替换。例如腾讯QQ里面的资源图片,哪怕你用Everything去搜索也搜索不到,那是因为腾讯QQ对这些资源图片进行了打包,当软件运行的时候解包获取资源图片。…...

npm 清缓存(重新安装node-modules)

安装node依赖包的会出现失败的情况,如下图所示: 此时 提示有些依赖树有冲突,根据提示 “ this command with --force or --legacy-peer-deps” 执行命令即可。 具体步骤如下: 1、先删除本地node-modules包 2、删掉page-loacl…...

sqlserver查询表中所有字段信息

精简 SELECT 字段名 a.name,主键 case when exists(SELECT 1 FROM sysobjects where xtypePK and parent_obja.id and name in (SELECT name FROM sysindexes WHERE indid in( SELECT indid FROM sysindexkeys WHERE id a.id AND colida.colid))) then √ else …...

二叉树的概念、存储及遍历

一、二叉树的概念 1、二叉树的定义 二叉树( binary tree)是 n 个结点的有限集合,该集合或为空集(空二叉树),或由一个根结点与两棵互不相交的,称为根结点的左子树、右子树的二叉树构成。 二叉树的…...

【面试题】智力题

文章目录 腾讯1000瓶毒药里面只有1瓶是有毒的,问需要多少只老鼠才能在24小时后试出那瓶有毒。有两根不规则的绳子,两根绳子从头烧到尾均需要一个小时,现在有一个45分钟的比赛,裁判员忘记带计时器,你能否通过烧绳子的方…...

【SpringBoot集成Redis + Session持久化存储到Redis】

目录 SpringBoot集成Redis 1.添加 redis 依赖 2.配置 redis 3.手动操作 redis Session持久化存储到Redis 1.添加依赖 2.修改redis配置 3.存储和读取String类型的代码 4.存储和读取对象类型的代码 5.序列化细节 SpringBoot集成Redis 1.添加 redis 依赖 …...

day49:QT day2,信号与槽、对话框

一、完善登录框 点击登录按钮后,判断账号(admin)和密码(123456)是否一致,如果匹配失败,则弹出错误对话框,文本内容“账号密码不匹配,是否重新登录”,给定两个…...

Meta分析核心技术

Meta分析是针对某一科研问题,根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法,对来源不同的研究成果进行收集、合并及定量统计分析的方法,最早出现于“循证医学”,现已广泛应用于农林生态,资源环境等方面。…...

Gof23设计模式之责任链模式

1.概述 责任链模式又名职责链模式,为了避免请求发送者与多个请求处理者耦合在一起,将所有请求的处理者通过前一对象记住其下一个对象的引用而连成一条链;当有请求发生时,可将请求沿着这条链传递,直到有对象处理它为止…...

数字孪生和元宇宙:打造未来的数字边界

数字孪生和元宇宙是近两年来被热议的两个概念,但由于技术的交叉两者也极易被混淆。本文希望带大家深入探讨一下这两者之间的关系,以及它们如何一起构建了数字时代的新格局。 1. 数字孪生的本质 数字孪生是一种虚拟模型,它通过数字手段对现实…...

【新版】系统架构设计师 - 软件架构设计<新版>

个人总结,仅供参考,欢迎加好友一起讨论 文章目录 架构 - 软件架构设计<新版>考点摘要概念架构的 4 1 视图架构描述语言ADL基于架构的软件开发方法ABSDABSD的开发模型ABSDMABSD(ABSDM模型)的开发过程 软件架…...

【机器视觉】单目测距——运动结构恢复

ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛&#xf…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...

Frozen-Flask :将 Flask 应用“冻结”为静态文件

Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是:将一个 Flask Web 应用生成成纯静态 HTML 文件,从而可以部署到静态网站托管服务上,如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...

Android15默认授权浮窗权限

我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题

分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

LabVIEW双光子成像系统技术

双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...

深入浅出Diffusion模型:从原理到实践的全方位教程

I. 引言:生成式AI的黎明 – Diffusion模型是什么? 近年来,生成式人工智能(Generative AI)领域取得了爆炸性的进展,模型能够根据简单的文本提示创作出逼真的图像、连贯的文本,乃至更多令人惊叹的…...

Leetcode33( 搜索旋转排序数组)

题目表述 整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 < k < nums.length&#xff09;上进行了 旋转&#xff0c;使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...