当前位置: 首页 > news >正文

LeNet

概念

代码

model

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()  # super()继承父类的构造函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x): x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

forward:定义正向传播的过程。

ReLU:激活哈数

观察网络中的参数传递:发现传递的都是channel通道数,最后output在softmax函数里展开的也是展开的通道数。

train

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))im = Image.open('1.jpg').convert('RGB')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()# predict = torch.softmax(outputs,dim=1)# print(predict)# tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,# 3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])print(classes[int(predict)])if __name__ == '__main__':main()

知识点:

增加新的维度: 

im = torch.unsqueeze(im, dim=0)  # [N, C, H, W] 

predict = torch.max(outputs, dim=1)[1].numpy():

这一行代码使用torch.max()函数找到outputs张量在第一个维度上的最大值,并返回最大值和对应的索引。dim=1表示在第一个维度上进行最大值的计算,即对每个样本的输出进行比较。[1]表示返回最大值对应的索引。最后,.numpy()将结果转换为NumPy数组。 

更换:

predict = torch.softmax(outputs,dim=1)

print:tensor([[9.9884e-01, 1.9386e-04, 3.8757e-04, 2.0671e-05, 2.5372e-04, 3.6199e-05,
         3.7643e-05, 1.7624e-04, 2.0138e-05, 3.4801e-05]])

Pytorch使用

相关文章:

LeNet

概念 代码 model import torch.nn as nn import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__() # super()继承父类的构造函数self.conv1 nn.Conv2d(3, 16, 5)self.pool1 nn.MaxPool2d(2, 2)self.conv2 nn.Conv2d(16…...

JavaScript 简单理解原型和创建实例时 new 操作符的执行操作

function Person(){// 构造函数// 当函数创建,prototype 属性指向一个原型对象时,在默认情况下,// 这个原型对象将会获得一个 constructor 属性,这个属性是一个指针,指向 prototype 所在的函数对象。 } // 为原型对象添…...

生成对抗网络——研讨会

时隔一年,再跟着李沐大师学习了GAN之后,仍旧没能在离散优化中实现通用的应用,实在惭愧,借着组内研讨会的机会,再队GAN的前世今生做一个简单的综述。 GAN产生的背景 目前与GAN相关的应用 去reddit社区的机器学习板块…...

Ubuntu 20.04 安装 mysql8 LTS

Ubuntu 20.04 安装 mysql8 LTS sudo apt-get update sudo apt-get install mysql-server -y mysql --version mysql Ver 8.0.35-0ubuntu0.20.04.1 for Linux on x86_64 ((Ubuntu)) Ubuntu20.04 是自带了 MySQL8. 几版本的,低于 20.04 则默认安装是 MySQL5.7.33…...

蓝桥杯:货物摆放

小蓝有一个超大的仓库,可以摆放很多货物。 现在,小蓝有 n 箱货物要摆放在仓库,每箱货物都是规则的正方体。小蓝规定了长、宽、高三个互相垂直的方向,每箱货物的边都必须严格平行于长、宽、高。 小蓝希望所有的货物最终摆成一个大…...

ganache部署智能合约报错VM Exception while processing transaction: invalid opcode

这是因为编译的字节码不正确,ganache和remix编译时需要选择相同的evm version 如下图所示: remix: ganache: 确保两者都选择london或者其他evm,只要确保EVM一致就可以正确编译并部署, 不会再出现VM Exception while processing…...

金融银行业更适合申请哪种SSL证书?

在当今数字化时代,金融行业的重要性日益增加。越来越多的金融交易和敏感信息在线进行,金融银行机构必须采取必要的措施来保护客户数据的安全。SSL证书作为一种重要的安全技术工具,可以帮助金融银行机构加密数据传输,验证网站身份&…...

文心一言API(高级版)使用

文心一言API高级版使用 一、百度文心一言API(高级版)二、使用步骤1、接口2、请求参数3、请求参数示例4、接口 返回示例 三、 如何获取appKey和uid1、申请appKey:2、获取appKey和uid 四、重要说明 一、百度文心一言API(高级版) 基于百度文心一言语言大模型的智能文本对话AI机器…...

C# 任务并行类库Parallel调用示例

写在前面 Task Parallel Library 是微软.NET框架基础类库(BCL)中的一个,主要目的是为了简化并行编程,可以实现在不同的处理器上并行处理不同任务,以提升运行效率。Parallel常用的方法有For/ForEach/Invoke三个静态方法…...

2024年江苏省职业院校技能大赛信息安全管理与评估 第二阶段学生组(样卷)

2024年江苏省职业院校技能大赛信息安全管理与评估 第二阶段学生组(样卷) 竞赛项目赛题 本文件为信息安全管理与评估项目竞赛-第二阶段样题,内容包括:网络安全事件响应、数字取证调查、应用程序安全。 本次比赛时间为180分钟。 …...

飞天使-linux操作的一些技巧与知识点3

http工作原理 http1.0 协议 使用的是短连接,建立一次tcp连接,发起一次http的请求,结束,tcp断开 http1.1 协议使用的是长连接,建立一次tcp的连接,发起多次http的请求,结束,tcp断开ngi…...

Appium获取toast方法封装

一、前置说明 toast消失的很快,并且通过uiautomatorviewer也不能获取到它的定位信息,如下图: 二、操作步骤 toast的class name值为android.widget.Toast,虽然toast消失的很快,但是它终究是在Dom结构中出现过&…...

Google Guava简析

Google Guava 是Google开源的一个Java类库,对基本类库做了扩充。感觉最大的价值点在于其 集合类、Cache和String工具类。 github项目地址:GitHub - google/guava: Google core libraries for Java github文档地址:Home google/guava Wiki …...

反序列化漏洞详解(二)

目录 pop链前置知识,魔术方法触发规则 pop构造链解释(开始烧脑了) 字符串逃逸基础 字符减少 字符串逃逸基础 字符增加 实例获取flag 字符串增多逃逸 字符串减少逃逸 延续反序列化漏洞(一)的内容 pop链前置知识,魔术方法触…...

React全站框架Next.js使用入门

Next.js是一个基于React的服务器端渲染框架,它可以帮助我们快速构建React应用程序,并具有以下优势: 1. 支持服务器端渲染,提高页面渲染速度和SEO; 2. 自带webpack开发环境,实现即插即用的特性;…...

【操作系统笔记】-文件系统

引言 之前已经学习过数据在内存中是如何表示,如何存储,但是这些存储在PC断电后数据便消失。因此我们需要一个可以持久化存储并且容量远远大于内存的结构,这一篇我们将学习,文件是如何被组织和操作的,这是一个操作系统…...

第二十一章 网络通信

计算机网络实现了堕胎计算机间的互联,使得它们彼此之间能够进行数据交流。网络应用程序就是再已连接的不同计算机上运行的程序,这些程序借助于网络协议,相互之间可以交换数据,编写网络应用程序前,首先必须明确网络协议…...

【漏洞复现】万户协同办公平台ezoffice wpsservlet接口存在任意文件上传漏洞 附POC

漏洞描述 万户ezOFFICE集团版协同平台以工作流程、知识管理、沟通交流和辅助办公四大核心应用 万户ezOFFICE协同管理平台是一个综合信息基础应用平台。 万户协同办公平台ezoffice wpsservlet接口存在任意文件上传漏洞。 免责声明 技术文章仅供参考,任何个人和组织使用网络应…...

【uniapp】小程序中input输入框的placeholder-class不生效解决办法

问题描述 uniapp微信小程序&#xff0c;使用input组件时&#xff0c;想要改变提示词 placeholder 的样式&#xff0c;但是使用placeholder-class 改变不了 如下&#xff1a; <input type"text" placeholder"搜索" placeholder-class"placeholde…...

SimplePIR——目前最快单服务器匿踪查询方案

一、介绍 这篇论文旨在实现高效的单服务器隐私信息检索&#xff08;PIR&#xff09;方案&#xff0c;以解决在保护用户隐私的同时快速检索数据库的问题。为了实现这一目标&#xff0c;论文提出了两种新的PIR方案&#xff1a;SimplePIR和DoublePIR。这两种方案的实现基于学习与错…...

COMSOL多场耦合地应力平衡开挖与衬砌支护案例:带衬砌与钢衬支护的实践研究

COMSOL 地应力平衡后开挖及衬砌支护案例&#xff08;带衬砌、钢衬&#xff09;隧道开挖模拟最头疼的就是初始地应力场的平衡问题。前些天用COMSOL折腾了个带衬砌支护的案例&#xff0c;今天把关键步骤拆开说说。咱们直接从地应力平衡开始&#xff0c;到开挖后钢衬安装一气呵成。…...

基于CubeMX与HAL库:STM32F302串口重定向Printf的工程化实践

1. 为什么需要串口重定向Printf 在嵌入式开发中&#xff0c;调试信息输出是排查问题的生命线。想象一下你正在调试一个复杂的传感器数据采集系统&#xff0c;突然发现数据异常&#xff0c;这时候如果能像在PC上编程一样直接printf("当前温度值&#xff1a;%f", temp…...

用PLECS和C代码手把手教你实现数字滤波(附完整工程文件)

用PLECS和C代码实现数字滤波的工程实践指南 在电力电子和电机控制领域&#xff0c;数字滤波技术是实现信号处理的关键环节。无论是消除高频噪声还是提取特定频段的信号成分&#xff0c;一个设计良好的数字滤波器都能显著提升系统性能。本文将带您从理论到实践&#xff0c;通过P…...

CORS跨域问题终极指南:从XMLHttpRequest到Nginx代理的完整解决方案

CORS跨域问题终极指南&#xff1a;从XMLHttpRequest到Nginx代理的完整解决方案 第一次在控制台看到那个鲜红的CORS错误时&#xff0c;我正为一个紧急项目赶工。凌晨三点的咖啡已经凉了&#xff0c;而浏览器的报错信息像一堵墙横在我和 deadline 之间。相信每个全栈开发者都经历…...

【毕业设计】微信小程序文创商城-从真实支付到模拟支付的实现与优化

1. 微信小程序文创商城支付功能概述 做毕业设计选择微信小程序文创商城是个不错的选题&#xff0c;尤其是支付功能的实现&#xff0c;既能锻炼技术能力&#xff0c;又很实用。我去年指导过几个类似的项目&#xff0c;发现学生们最头疼的就是支付模块。真实支付需要营业执照和公…...

Flutter 3.24.x项目升级AGP 8.6适配Android 15,我踩过的坑和完整配置清单

Flutter 3.24.x项目升级AGP 8.6适配Android 15实战指南 上周在给公司核心项目做技术栈升级时&#xff0c;我花了整整三天时间才把Flutter 3.24.x项目成功迁移到AGP 8.6并适配Android 15&#xff08;API 35&#xff09;。这过程中踩过的坑比预想中多得多——从Gradle版本冲突到n…...

效率倍增:用快马生成openclaw在ubuntu的一键部署与docker化脚本

最近在折腾一个开源项目openclaw的部署&#xff0c;发现每次在Ubuntu服务器上手动安装配置特别费时间。作为一个懒人程序员&#xff0c;我决定研究下怎么把整个流程自动化&#xff0c;结果发现用InsCode(快马)平台可以轻松搞定这件事&#xff0c;效率直接翻倍。 传统部署方式的…...

GLM-4.1V-9B-Base入门必看:中文提问技巧——如何写出高稳定度问题

GLM-4.1V-9B-Base入门必看&#xff1a;中文提问技巧——如何写出高稳定度问题 1. 认识GLM-4.1V-9B-Base GLM-4.1V-9B-Base是智谱开源的视觉多模态理解模型&#xff0c;专门用于处理图像内容识别、场景描述、目标问答等中文视觉理解任务。与普通聊天模型不同&#xff0c;它更擅…...

避坑指南:在YOLOv5-7.0中融合BiFPN时,如何平衡P2检测头带来的精度与速度损耗?

YOLOv5-7.0中BiFPN与P2检测头的精度与速度平衡实战 当你在无人机航拍画面中寻找几毫米大小的电子元件时&#xff0c;或者在显微镜图像中定位细胞核位置时&#xff0c;传统目标检测模型的性能往往会大打折扣。这正是微小目标检测技术大显身手的场景——而YOLOv5作为工业界最受欢…...

ThinkPHP8 + Swoole6 实战:从宝塔面板到进程守护,手把手搭建稳定WebSocket服务

ThinkPHP8 Swoole6 生产级WebSocket服务部署指南 当实时通信成为现代应用的标配&#xff0c;如何将WebSocket服务稳定部署到生产环境就成了开发者必须掌握的技能。不同于本地开发环境&#xff0c;线上部署需要考虑服务器配置、进程守护、负载均衡等一系列复杂因素。本文将带你…...