AlexNet(pytorch)
AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%+提升到 80%+
该网络的亮点在于:
(1)首次利用 GPU 进行网络加速训练。
(2)使用了 ReLU 激活函数,而不是传统的 Sigmoid 激活函数以及 Tanh 激活函数。
(3)使用了 LRN 局部响应归一化。
(4)在全连接层的前两层中使用了 Dropout 随机失活神经元操作,以减少过拟合。
模型:

模型参数表:

model.py
import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), # input[3, 224, 224] output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[48, 27, 27]: (55-3+0)/4 + 1=27nn.Conv2d(48, 128, kernel_size=5, padding=2), # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1), # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1), # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)
train.py
import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdmfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))#前期的网络还是用的Normalize标准化,之后的网络会用到BN批标准化data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../../")) # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data") # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)#注意这里的数据加载还是直接用的torchvision.datasets.ImageFolder加载,#并不需要定义数据加载的脚本,可能是数据比较简单吧#定义数据集时候直接定义数据处理方法,之后torch.utils.data.DataLoader加载数据集加载时候直接调用这里定义的数据处理参数的方法#train文件夹下还有五种花的文件夹,这个具体处理看下面的代码,可能是ImageFolder直接加载文件夹里的图片文件train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])#训练集图片的个数train_num = len(train_dataset)#train_dataset.class_to_idx 是一个字典,将类别名称映射到相应的索引。#下行注释就是flower_list具体内容# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}# cla_dict是一个反转字典,将原始字典 flower_list 的键和值进行交换flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# json.dumps() 将 cla_dict 转换为格式化的 JSON 字符串。# 最后,将 JSON 字符串写入名为 class_indices.json 的文件中# indent 参数表示有几类json_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32#这个代码片段的目的是为了确定在并行计算时使用的最大工作进程数,并确保不超过系统的逻辑 CPU 核心数量和其他限制nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()## def imshow(img):# img = img / 2 + 0.5 # unnormalize# npimg = img.numpy()# plt.imshow(np.transpose(npimg, (1, 2, 0)))# plt.show()## print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))# imshow(utils.make_grid(test_image))net = AlexNet(num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()# pata = list(net.parameters())optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 10save_path = './AlexNet.pth'best_acc = 0.0#一个epoch训练多少批次的数据,一批数据32个CWH,即32张图片train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0#这段代码使用了 tqdm 库来创建一个进度条,用于迭代训练数据集 train_loader 中的批次数据#file=sys.stdout 的作用是将进度条的输出定向到标准输出流,即将进度条显示在终端窗口中train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()#更新进度条的描述信息,显示当前训练的轮数、总轮数和损失值#这个loss是批次损失,在进度条上显示出来train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# 验证是训练完一个epoch后进行在验证集上验证,验证准确率net.eval()acc = 0.0 # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:#val_bar 的类型是 tqdm.tqdm,它是 tqdm 库中的一个类。该类提供了迭代器的功能,# 可以用于包装迭代器对象,并在循环中显示进度条和相关信息val_images, val_labels = val_dataoutputs = net(val_images.to(device)) #outputs:[batch_size,num_classes]predict_y = torch.max(outputs, dim=1)[1] #torch.max 返回的第一个元素是张量数值,第二个是对应的索引acc += torch.eq(predict_y, val_labels.to(device)).sum().item()#验证完后计算验证集里所有的正确个数/总个数val_accurate = acc / val_num#总损失/训练总批次,求得平均每批的损失print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()
训练过程:
using cuda:0 device.
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
train epoch[1/10] loss:1.215: 100%|██████████| 104/104 [00:23<00:00, 4.38it/s]
100%|██████████| 91/91 [00:15<00:00, 5.73it/s]
[epoch 1] train_loss: 1.342 val_accuracy: 0.478
train epoch[2/10] loss:1.111: 100%|██████████| 104/104 [00:19<00:00, 5.30it/s]
100%|██████████| 91/91 [00:15<00:00, 5.75it/s]
[epoch 2] train_loss: 1.183 val_accuracy: 0.533
train epoch[3/10] loss:1.252: 100%|██████████| 104/104 [00:19<00:00, 5.30it/s]
100%|██████████| 91/91 [00:15<00:00, 5.75it/s]
[epoch 3] train_loss: 1.097 val_accuracy: 0.604
train epoch[4/10] loss:0.730: 100%|██████████| 104/104 [00:19<00:00, 5.32it/s]
100%|██████████| 91/91 [00:15<00:00, 5.74it/s]
[epoch 4] train_loss: 1.025 val_accuracy: 0.607
train epoch[5/10] loss:0.961: 100%|██████████| 104/104 [00:19<00:00, 5.28it/s]
100%|██████████| 91/91 [00:16<00:00, 5.65it/s]
[epoch 5] train_loss: 0.941 val_accuracy: 0.676
train epoch[6/10] loss:0.853: 100%|██████████| 104/104 [00:19<00:00, 5.31it/s]
100%|██████████| 91/91 [00:15<00:00, 5.82it/s]
[epoch 6] train_loss: 0.915 val_accuracy: 0.659
train epoch[7/10] loss:1.032: 100%|██████████| 104/104 [00:19<00:00, 5.34it/s]
100%|██████████| 91/91 [00:15<00:00, 5.82it/s]
[epoch 7] train_loss: 0.864 val_accuracy: 0.684
train epoch[8/10] loss:0.704: 100%|██████████| 104/104 [00:19<00:00, 5.32it/s]
100%|██████████| 91/91 [00:15<00:00, 5.80it/s]
[epoch 8] train_loss: 0.842 val_accuracy: 0.706
train epoch[9/10] loss:1.279: 100%|██████████| 104/104 [00:19<00:00, 5.30it/s]
100%|██████████| 91/91 [00:15<00:00, 5.83it/s]
[epoch 9] train_loss: 0.825 val_accuracy: 0.714
train epoch[10/10] loss:0.796: 100%|██████████| 104/104 [00:19<00:00, 5.31it/s]
100%|██████████| 91/91 [00:15<00:00, 5.82it/s]
[epoch 10] train_loss: 0.801 val_accuracy: 0.703
Finished TrainingProcess finished with exit code 0
predict.py:
import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "./test.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = AlexNet(num_classes=5).to(device)# load model weightsweights_path = "./AlexNet.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)#torch.load() 函数会根据路径加载模型的权重,并返回一个包含模型参数的字典#load_state_dict() 函数将加载的模型参数字典应用到 model 中,从而将预训练模型的参数加载到 model 中model.load_state_dict(torch.load(weights_path))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10} prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()
预测结果:
我感觉pycharm的plt显示并不是特别明了

class: daisy prob: 4.2e-06
class: dandelion prob: 9.61e-07
class: roses prob: 0.000773
class: sunflowers prob: 1.28e-05
class: tulips prob: 0.999
相关文章:
AlexNet(pytorch)
AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的 70%提升到 80% 该网络的亮点在于: (1)首次利用 GPU 进行网络加速训练。 ÿ…...
【单调栈 】LeetCode321:拼接最大数
作者推荐 【动态规划】【广度优先搜索】LeetCode:2617 网格图中最少访问的格子数 本文涉及的知识点 单调栈 题目 给定长度分别为 m 和 n 的两个数组,其元素由 0-9 构成,表示两个自然数各位上的数字。现在从这两个数组中选出 k (k < m n) 个数字…...
TikTok与虚拟现实的完美交融:全新娱乐时代的开启
TikTok,这个风靡全球的短视频平台,与虚拟现实(VR)技术的深度结合,为用户呈现了一场全新的娱乐盛宴。虚拟现实技术为TikTok带来了更丰富、更沉浸的用户体验,标志着全新娱乐时代的开启。本文将深入探讨TikTok…...
PXI/PCIe/VPX机箱 ARM|x86 + FPGA测试测量板卡解决方案
PXI便携式测控系统是一种基于PXI总线的便携式测试测控系统,它填补了现有台式及机架式仪器在外场测控和便携测控应用上的空白,在军工国防、航空航天、兵器电子、船舶舰载等各个领域的外场测控场合和科学试验研究场合都有广泛的应用。由于PXI便携式测控系统…...
ES6 面试题 | 12.精选 ES6 面试题
🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…...
【linux】Debian不能运行sudo的解决
一、问题: sudo: 没有找到有效的 sudoers 资源,退出 sudo: 初始化审计插件 sudoers_audit 出错 二、可用的方法: 出现 "sudo: 没有找到有效的 sudoers 资源,退出" 和 "sudo: 初始化审计插件 sudoers_audit 出错&q…...
讲解ThinkPHP的链式操作
数据库提供的链式操作方法,可以有效的提高数据存取的代码清晰度和开发效率,并且支持所有的CURD操作。 使用也比较简单,假如我们现在要查询一个User表的满足状态为1的前10条记录,并希望按照用户的创建时间排序 Db::table(think_u…...
Java技术栈 —— 微服务框架Spring Cloud —— Ruoyi-Cloud 学习(二)
RuoYi项目开发过程 一、登录功能(鉴权模块)1.1 后端部分1.1.1 什么是JWT?1.1.2 什么是Base64?为什么需要它?1.1.3 SpringBoot注解解析1.1.4 依赖注入和控制反转1.1.5 什么是Restful?1.1.6 Log4j 2、Logpack、SLF4j日志框架1.1.7 如何将项目打包成指定bytecode字节…...
如何进行软件测试和测试驱动开发(TDD)?
1. 软件测试概述 1.1 什么是软件测试? 软件测试是一种评估系统的过程,目的是发现潜在的错误或缺陷。通过对软件进行测试,开发者和测试人员可以确定软件是否符合预期的需求、功能是否正常运行,以及系统是否足够稳定和可靠。 1.2…...
linux 开机启动流程
1.打开电源 2.BIOS 有时间和启动方式 3.启动Systemd 其pid为1 4.挂载引导分区 /boot 5.启动各种服务 如rc.local...
Mybatis 动态SQL的插入操作
需求 : 根据用户的输入情况进行插入 动态SQL:根据需求动态拼接SQL 用户往表中插入数据,有的数据可能不想插入,比如不想让别人知道自己的性别,性别就为空 insert into userinfo(username,password,age,gender,phone) values(?,?,?,?,?); insert into userinfo(username,…...
共建开源新里程:北京航空航天大学OpenHarmony技术俱乐部正式揭牌成立
12月11日,由OpenAtom OpenHarmony(以下简称“OpenHarmony”)项目群技术指导委员会(以下简称“TSC”)和北京航空航天大学共同举办的“OpenHarmony软件工程研讨会暨北京航空航天大学OpenHarmony技术俱乐部成立仪式”在京圆满落幕。 现场大合影 活动当天,多位重量级嘉宾出席了此次…...
企业微信机器人发送文本、图片、文件、markdown、图文信息
import requests import base64 import hashlib import json # 机器人地址的key值 key"811a1652-60e8-4f51-a1d9-231783399ad2" def path2base64(path):"""文件转换为base64:param path: 文件路径:return:"""with open(path, "rb…...
智能优化算法应用:基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码
智能优化算法应用:基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.天牛须算法4.实验参数设定5.算法结果6.参考文…...
【Hive】【Hadoop】工作中常操作的笔记-随时添加
文章目录 1、Hive 复制一个表:2、字段级操作3、hdfs 文件统计 1、Hive 复制一个表: 直接Copy文件 create table new_table like table_name;hdfs dfs -get /apps/hive/warehouse/ods.db/table_nameload data local inpath /路径 into table new_table;修复表: m…...
DIY电脑装机机箱风扇安装方法
作为第一次自己diy一台电脑主机的我,在经历了众多的坑中今天来说一下如何安装机箱风扇的问题 一、风扇的数量 1、i3 xx50显卡 就用一个cpu散热风扇即可 2、i5 xx60 一个cpu散热风扇 一个风扇即可 3、i7 xx70 一个cpu散热 4个风扇即可 4、i9 xx80 就需要7个以…...
基础算法(4):排序(4)冒泡排序
1.冒泡排序(BubbleSort)实现 算法步骤:比较相邻的元素。如果第一个比第二个大,就交换。 对每一对相邻元素作同样的工作,从开始第一对到结尾的最后一对。 这步做完后,最后的元素会是最大的数。 针对所有的元素重复以上的步骤&#…...
鸿蒙开发之网络请求
//需要导入http头文件 import http from ohos.net.http//请求地址url: string http://apis.juhe.cn/simpleWeather/queryText(this.message).maxFontSize(50).minFontSize(10).fontWeight(FontWeight.Bold).onClick(() > {console.log(请求开始)let req http.createHttp()…...
PrimDiffusion:3D 人类生成的体积基元扩散模型NeurIPS 2023
NeurIPS2023 ,这是一种用于 3D 人体生成的体积基元扩散模型,可通过离体拓扑实现明确的姿势、视图和形状控制。 PrimDiffusion 对一组紧凑地代表 3D 人体的基元执行扩散和去噪过程。这种生成建模可以实现明确的姿势、视图和形状控制,并能够在…...
时序预测 | Python实现LSTM-Attention-XGBoost组合模型电力需求预测
时序预测 | Python实现LSTM-Attention-XGBoost组合模型电力需求预测 目录 时序预测 | Python实现LSTM-Attention-XGBoost组合模型电力需求预测预测效果基本描述程序设计参考资料预测效果 基本描述 该数据集因其每小时的用电量数据以及 TSO 对消耗和定价的相应预测而值得注意,从…...
51c自动驾驶~合集58
我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留,CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制(CCA-Attention),…...
突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合
强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...
YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...
Golang dig框架与GraphQL的完美结合
将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...
vulnyx Blogger writeup
信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...
redis和redission的区别
Redis 和 Redisson 是两个密切相关但又本质不同的技术,它们扮演着完全不同的角色: Redis: 内存数据库/数据结构存储 本质: 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能: 提供丰…...
Ubuntu系统多网卡多相机IP设置方法
目录 1、硬件情况 2、如何设置网卡和相机IP 2.1 万兆网卡连接交换机,交换机再连相机 2.1.1 网卡设置 2.1.2 相机设置 2.3 万兆网卡直连相机 1、硬件情况 2个网卡n个相机 电脑系统信息,系统版本:Ubuntu22.04.5 LTS;内核版本…...
一些实用的chrome扩展0x01
简介 浏览器扩展程序有助于自动化任务、查找隐藏的漏洞、隐藏自身痕迹。以下列出了一些必备扩展程序,无论是测试应用程序、搜寻漏洞还是收集情报,它们都能提升工作流程。 FoxyProxy 代理管理工具,此扩展简化了使用代理(如 Burp…...
