使用 AlexNet 实现图片分类 | PyTorch 深度学习实战
前一篇文章,CNN 卷积神经网络处理图片任务 | PyTorch 深度学习实战
本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started
本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】
使用 AlexNet 实现图片分类
- 经典卷积网络
- AlexNet 特点
- 实验代码
- 实验结果
- Links
经典卷积网络
以下是卷积神经网络发展的里程碑:

- AlexNet 在各项比赛中,比其它算法好,证明了深度神经网络算法的优越性和前景
- VGGNet 则使用比 AlexNet 更深更宽的网络,取得了比 AlexNet 还好的成绩
- GoogLeNet 效果则比 VGGNet 更好
- ResNet 引入残差模块,解决了深度网络训练中的退化问题,超越之前的模型
- DenseNet 模型采用密集连接的结构,使模型具有更好的鲁棒性
AlexNet 特点

AlexNet 结构
![![[../assets/media/screenshot_20250208184431.png]]](https://i-blog.csdnimg.cn/direct/0c31fdef1a234aeb86b5162ff6237f54.png)
更多详细介绍,阅读作者 Paper 论文, ImageNet Classification with Deep Convolutional Neural Networks
视频资源:9年后重读深度学习奠基作之一:AlexNet【上】【论文精读】
实验代码
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision import datasets
import torch.nn as nn
import torch
import numpy as np# configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_rootdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, "data")################################
# 定义 dataset loader
################################
def get_train_valid_loader(data_dir,batch_size,augment,random_seed,valid_size=0.1,shuffle=True):# 正则化图片的参数,mean 和 std 的值来自于 imagenet 的数据统计normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],std=[0.2023, 0.1994, 0.2010],)# define transformsvalid_transform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])if augment:# 数据增强train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize,])else:train_transform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])# load the datasettrain_dataset = datasets.CIFAR10(root=data_dir, train=True,download=True, transform=train_transform,)valid_dataset = datasets.CIFAR10(root=data_dir, train=True,download=True, transform=valid_transform,)num_train = len(train_dataset)indices = list(range(num_train))split = int(np.floor(valid_size * num_train))if shuffle:np.random.seed(random_seed)np.random.shuffle(indices)train_idx, valid_idx = indices[split:], indices[:split]train_sampler = SubsetRandomSampler(train_idx)valid_sampler = SubsetRandomSampler(valid_idx)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler)return (train_loader, valid_loader)def get_test_loader(data_dir,batch_size,shuffle=True):normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225],)# define transformtransform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),normalize,])dataset = datasets.CIFAR10(root=data_dir, train=False,download=True, transform=transform,)data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)return data_loader# CIFAR10 dataset
CIFAR10_data = os.path.join(data_rootdir, "CIFAR10")
train_loader, valid_loader = get_train_valid_loader(data_dir= CIFAR10_data, batch_size = 64,augment = False, random_seed = 1)test_loader = get_test_loader(data_dir= CIFAR10_data,batch_size = 64)################################
# 定义 Model
################################
class AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),nn.BatchNorm2d(96),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.layer2 = nn.Sequential(nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.layer3 = nn.Sequential(nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.layer4 = nn.Sequential(nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(384),nn.ReLU())self.layer5 = nn.Sequential(nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2))self.fc = nn.Sequential(nn.Dropout(0.5),nn.Linear(9216, 4096),nn.ReLU())self.fc1 = nn.Sequential(nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU())self.fc2= nn.Sequential(nn.Linear(4096, num_classes))def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = out.reshape(out.size(0), -1)out = self.fc(out)out = self.fc1(out)out = self.fc2(out)return out######################################
# Setting Hyperparameters
######################################
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.005model = AlexNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9) # Train the model
total_step = len(train_loader)######################################
# Training
######################################
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader): # Move tensors to the configured deviceimages = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Validationwith torch.no_grad():correct = 0total = 0for images, labels in valid_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()del images, labels, outputsprint('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total))# Now, we see how our model performs on unseen data
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()del images, labels, outputsprint('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))

实验结果
以上代码在 NVIDIA GeForce RTX 2050 WDDM 显存上训练和测试,大约花了半个小时时间。
最终,测试集上的准确率达到了 82.24% 。

Links
- Writing AlexNet from Scratch in PyTorch
- ImageNet Classification with Deep Convolutional
Neural Networks - Conv2d API in PyTorch
- 9年后重读深度学习奠基作之一:AlexNet【上】【论文精读】
相关文章:
使用 AlexNet 实现图片分类 | PyTorch 深度学习实战
前一篇文章,CNN 卷积神经网络处理图片任务 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】 使用 AlexNet 实现图片分类…...
Linux系统引导与服务管理
目录 一、Linux引导过程 1、引导过程概述 1.1、BIOS开机自检 1.2、MBR读取 1.3、加载引导加载程序(GRUB) 1.4、内核加载 1.5、初始化进程(init) 二、服务 2.1、服务类型 2.2、服务管理工具 三、运行级别 四、systemd …...
【Hadoop】大数据权限管理工具Ranger2.1.0编译
目录 编辑一、下载 ranger源码并编译 二、报错信息 报错1 报错2 报错3 报错4 一、下载 ranger源码并编译 ranger官网 https://ranger.apache.org/download.html 由于Ranger不提供二进制安装包,故需要maven编译。安装其它依赖: yum install gcc …...
宝珀(Blancpain):传承近三百年的机械制表传奇(中英双语)
宝珀(Blancpain):传承近三百年的机械制表传奇 在钟表行业中,宝珀(Blancpain) 作为世界上最古老的制表品牌,一直以其卓越的机械工艺、复杂功能腕表和对创新的坚持而闻名。自 1735 年成立以来&am…...
【Linux】Linux命令:crontab
目录 1、作用2、命令使用格式3、常用参数说明4、时程表4.1 格式4.2 常见问题处理 5、示例 1、作用 crontab命令用于对用户的时程表进行查看、删除、修改等操作。 用户的时程表是用于记录着要定期执行的程序。当安装完Linux操作系统启动后, cron服务会定期执行时程表…...
C++ 使用CURL开源库实现Http/Https的get/post请求进行字串和文件传输
CURL开源库介绍 CURL 是一个功能强大的开源库,用于在各种平台上进行网络数据传输。它支持众多的网络协议,像 HTTP、HTTPS、FTP、SMTP 等,能让开发者方便地在程序里实现与远程服务器的通信。 CURL 可以在 Windows、Linux、macOS 等多种操作系…...
浙江大华社招面试
下面是我之前社招面试大华时,面得是嵌入式Linux系统工程师,下面是我初试所被问到的问题分享给大家 毕业之后工作负责过哪些产品,工作负责哪些内容 Camera相关 1、调试sensor是多少像素 2、板子上怎么连接sensor 3、几LINE 4、每个LINE的data rate 是多少 ,单位是什么 5、图…...
多对多的增删改查
一 : 增 随机单号: /*** 文档就绪函数*/$(function () {//随机单号let number Math.floor(Math.random()*(9999-10001)1000);//取随机单号的值 固定格式输出$("#docNo").val(BSnumber);}) 开单日期: //处理开单日期$("#invoiceDate").val(new Date().to…...
vscode设置保存时自动缩进和格式化
参考博客 如何在 VSCode 中自动缩进你的代码 | Linux 中国 省流 使用 Ctrl Shift P 来打开命令模式,搜索 Open User Settings 并按下回车你需要搜索 Auto Indent,并在 “编辑器:自动缩进(Editor: Auto Indent)” 中选择 “全部(Full)”P…...
【练习】PAT 乙 1074 宇宙无敌加法器
题目 地球人习惯使用十进制数,并且默认一个数字的每一位都是十进制的。而在PAT星人开挂的世界里,每个数字的每一位都是不同进制的,这种神奇的数字称为“PAT数”。每个PAT星人都必须熟记各位数字的进制表,例如“……0527”就表示最…...
探店小程序:解锁商业新生态,定制未来
在数字化浪潮席卷全球的今天,商业的边界正在被重新定义。随着移动互联网技术的飞速发展,探店小程序作为一种新兴的商业模式,正以其独特的优势迅速成为连接商家与消费者的桥梁。我们刚刚为一家客户成功交付了一款集分销、分润、商业模式定制开…...
计算机视觉核心任务
1. 计算机视频重要分类 计算机视觉的重要任务可以大致分为以下几类: 1. 图像分类(Image Classification) 识别图像属于哪个类别,例如猫、狗、汽车等。 应用场景:物品识别、人脸识别、医疗影像分类。代表模型&#…...
【人工智能】如何在VSCode中使用DeepSeek?
文章目录 前言一、准备工作二、安装DeepSeek插件步骤1、扩展图标搜索DeepSeep2、安装DeepSeek插件3、使用测试DeepSeekBito文心一言 结论 前言 介绍在VSCode中调用DeepSeek插件工具,可以进行对话、编码。 一、准备工作 确保已经安装好了VSCode软件。 二、安装D…...
机器学习 - 进一步理解最大似然估计和高斯分布的关系
一、高斯分布得到的是一个概率吗? 高斯分布(也称为正态分布)描述的是随机变量在某范围内取值的概率分布情况。其概率密度函数(PDF)为: 其中,μ 是均值,σ 是标准差。 需要注意的是…...
Office/WPS接入DeepSeek等多个AI工具,开启办公新模式!
在现代职场中,Office办公套件已成为工作和学习的必备工具,其功能强大但复杂,熟练掌握需要系统的学习。为了简化操作,使每个人都能轻松使用各种功能,市场上涌现出各类办公插件。这些插件不仅提升了用户体验,…...
如何在Android Studio中开发一个简单的Android应用?
Android Studio是开发Android应用的官方集成开发环境(IDE),它提供了许多强大的功能,使得开发者能够高效地创建Android应用。如果你是Android开发的初学者,本文将引导你如何在Android Studio中开发一个简单的Android应用…...
第40天:Web开发-JS应用VueJS框架Vite构建启动打包渲染XSS源码泄露代码审计
#知识点 1、安全开发-VueJS-搭建启动&打包安全 2、安全开发-VueJS-源码泄漏&代码审计 一、Vue搭建创建项目启动项目 1、Vue 框架搭建->基于nodejs搭建,安装nodejs即可 参考:https://cn.vuejs.org/ 已安装18.3或更高版本的Node.js 2、Vue 创建…...
996引擎-问题处理:三职业改单职业
996引擎-问题处理:三职业改单职业 问题解决方案顺便补充点单性别设置补充:可视化配置表参考资料问题 目前的版本: 引擎版本号:2024.8.7.0 三端配套客户端:3.40.9 传统PC客户端:23.12.07 配套数据库:64_24.8.7.0此版本需要通过可视化配置表...
Lua语言的云计算
Lua语言在云计算中的应用 随着信息技术的迅猛发展,云计算已经成为现代计算的重要组成部分。云计算通过互联网将计算资源(如服务器、存储、数据库、网络等)进行动态调配和高效利用,极大地提高了资源利用率与开发效率。在众多编程语…...
[数据结构] Set的使用与注意事项
目录 Set的说明 常见方法说明 注意事项 TreeSet使用案例 Set的说明 Set与Map主要的不同有两点: Set是继承自Collection的接口类,Set中只存储了Key. 常见方法说明 方法解释boolean add(E e)添加元素,但重复元素不会被添加成功void clear()清空集合boolean contains(Object…...
安当SLA操作系统登录双因素认证:全方位保障Windows系统登录安全
一、产品概述 在当今数字化时代,Windows系统面临着诸多安全挑战,如弱口令问题等。安当SLA(System Login Agent)作为一款强大的双因素登录认证产品,通过支持OTP动态口令和USBKey硬件令牌认证,有效解决多种W…...
Java学习进阶路线
Java基础 Java Web 前端HTML/css/js,J2EE(Servlet/jsp),数据库(Mysql/oracle) Java开发框架 Spring MVC/Mybatis/Herbernate/maven 《Java编程思想》 深入了解java基础 Java设计模式 《Effective j…...
操作系统|ARM和X86的区别,存储,指令集
文章目录 主频寄存器寄存器在硬件中的体现是什么寄存器的基本特性硬件实现寄存器类型 内存和寄存器的区别内存(Memory)和磁盘(Disk)指令的执行ARM Cortex-M3与Thumb-2指令集Thumb-2 与流水线虚拟地址指令的执行 多核CPU芯片间的通…...
Mp4视频播放机无法播放视频-批量修改视频分辨率(帧宽、帧高)
背景 家人有一台夏新多功能 视频播放器(夏新多功能 视频播放器),用来播放广场舞。下载了一些广场舞视频, 只有部分视频可以播放,其他视频均无法播放,判断应该不是帧速率和数据速率的限制, 分析可能是播放器不支持帧高度大于720的视频。由于视频文件较多,需要借助视频编…...
日语学习-日语知识点小记-构建基础-JLPT-N4&N5阶段(2):どれ・どの・どんな :区别 等
日语学习-日语知识点小记-构建基础-JLPT-N4&N5阶段(2):どれ・どの・どんな :区别 等 1、前言(1)情况说明(2)工程师的信仰2、知识点(1)知识点な形容詞(けいようし) と い形容詞(けいようし):并列修饰(2)知识点どれ・どの・どんな :区别(3)知识点は &…...
【浏览器多开】Google Chrome 谷歌浏览器分身术
谷歌浏览器分身术(多开): 复制已有谷歌浏览器图标—>右键–>属性的目标栏中,添加 --user-data-dir自定义文件夹路径 参数。 例如: C:\MySpace\02Installed\Chrome\Chrome-bin\99.0.4844.51\chrome.exe –user-d…...
《LeetCode Hot100》 Day01
Day01 轮转数组 思路: (1) 使用O(1) 空间复杂度解决,就需要原地解决,不能创建新的数组。 (2) 先整体反转数组,再反转前k个数,再反转剩下的数。即可完整本题。 &…...
【图片合并转换PDF】如何将每个文件夹下的图片转化成PDF并合并成一个文件?下面基于C++的方式教你实现
医院在为患者进行诊断和治疗过程中,会产生大量的医学影像图片,如 X 光片、CT 扫描图、MRI 图像等。这些图片通常会按照检查时间或者检查项目存放在不同的文件夹中。为了方便医生查阅和患者病历的长期保存,需要将每个患者文件夹下的图片合并成…...
uniapp实现人脸识别(不使用三方插件)
uniapp实现人脸识别 内容简介功能实现上传身份证进行人脸比对 遇到的问题 内容简介 1.拍摄/相册将身份证照片上传到接口进行图片解析 2.使用live-pusher组件拍摄人脸照片,上传接口与身份证人脸进行比对 功能实现 上传身份证 先看下效果 点击按钮调用chooseImage…...
2025全新JSP简约博客平台-免费开源
前言 最近收到不少同学期末作业的需求,都还是JSP的老技术,介于现在很多网上可以找到的JSP现有项目,要么就是很老好几年前的,要么就是搞了一通不仅乱码还各自报错失败的,总之就是资源有限,于是我花了一星期…...
