当前位置: 首页 > article >正文

ResNet-50——pytorch版

声明本文为365天深度学习训练营中的学习记录博客原作者K同学啊先验知识ResNet残差网络根据网络层数可以分为ResNet-18、ResNet-34、ResNet-50、ResNet-101等他和普通的CNN网络不同的地方就是他提出了一个残差的概念解决了卷积网络在深度加深时候的梯度爆炸、梯度消失的问题梯度消失在反向传播的过程中随着网络层数的增加前几层的梯度值会变得非常小接近于0梯度爆炸在反向传播过程中随着网络层数增加梯度值变得非常大导致网络权重更新幅度过大模型无法收敛BN层的提出虽然在一定情况下解决了这个问题但是加了BN会带来网络变得更复杂更不容易收敛的影响。故何凯明证明了一个问题就是只要有合适的网络结构深的网络肯定比浅的好残差网络孕育而生。主要使用的残差单元有这两种分别是两层的浅残差和3层的深残差。我的环境Python版本3.8.10PyTorch版本2.4.1cpuTorchvision版本0.19.1cpu学习记录由于是刚回归Pytorch的第一篇我会尽量讲细一点整体流程跟tensorflow差不多的都是初始化GPU,数据集划分处理网络选择训练测试函数撰写然后正式训练最后成果可视化1.设置GPUdevice torch.device(cuda if torch.cuda.is_available() else cpu)2.数据导入在数据导入的时候我们要设置一个transform来对数据进行处理train_transforms transforms.Compose([ #尺寸调节 transforms.Resize((224, 224)), #totensor类型 transforms.ToTensor(), #归一化 transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])pytorch在使用时要记得将图片变成tensor类型然后就可以调用datasets.ImageFolder来对其进行图片批量处理了这个函数还会生成dataset格式就是将每个图像根据文件名自动生成标签total_dataset datasets.ImageFolder(data_dir, transformtrain_transforms)后面经过数据集比例划分就可以调用DataLoader来生产测试集和训练集了train_size int(0.8 * len(total_dataset)) test_size len(total_dataset) - train_size train_dataset, test_dataset torch.utils.data.random_split(total_dataset, [train_size, test_size]) batch_size 4 train_dl torch.utils.data.DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue) test_dl torch.utils.data.DataLoader(test_dataset, batch_sizebatch_size, shuffleFalse)3.网络选择这次我们选的是自建ResNet-50网络如图所示因为中间反复使用卷积块和恒等块故我们需要先将他们俩给定义了恒等块三层卷积后与原始值相加后通过一个激活函数class IdentityBlock(nn.Module): def __init__(self, in_channels, filters, kernel_size): super(IdentityBlock, self).__init__() #filters输入是个数组 f1, f2, f3 filters self.conv1 nn.Sequential( #biasFalse, 不使用偏执函数有批归一化了 nn.Conv2d(in_channels, f1, kernel_size1, stride1, padding0, biasFalse), nn.BatchNorm2d(f1), #节约内存 nn.ReLU(inplaceTrue) ) self.conv2 nn.Sequential( nn.Conv2d(f1, f2, kernel_sizekernel_size, stride1, paddingsame, biasFalse), nn.BatchNorm2d(f2), nn.ReLU(inplaceTrue) ) self.conv3 nn.Sequential( nn.Conv2d(f2, f3, kernel_size1, stride1, padding0, biasFalse), nn.BatchNorm2d(f3) ) #先加后激活保留通路的线性特征 self.relu nn.ReLU(inplaceTrue) def forward(self, x): identity x out self.conv1(x) out self.conv2(out) out self.conv3(out) #先加后激活 out identity out self.relu(out) return out卷积块一条路卷三次一条路卷一次最后相加后激活class ConvBlock(nn.Module): def __init__(self, in_channels, filters, kernel_size, stride2): super(ConvBlock, self).__init__() f1, f2, f3 filters self.conv1 nn.Sequential( nn.Conv2d(in_channels, f1, kernel_size1, stridestride, padding0, biasFalse), nn.BatchNorm2d(f1), nn.ReLU(inplaceTrue) ) self.conv2 nn.Sequential( nn.Conv2d(f1, f2, kernel_sizekernel_size, stride1, paddingsame, biasFalse), nn.BatchNorm2d(f2), nn.ReLU(inplaceTrue) ) self.conv3 nn.Sequential( nn.Conv2d(f2, f3, kernel_size1, stride1, padding0, biasFalse), nn.BatchNorm2d(f3) ) self.shortcut nn.Sequential( nn.Conv2d(in_channels, f3, kernel_size1, stridestride, padding0, biasFalse), nn.BatchNorm2d(f3) ) self.relu nn.ReLU(inplaceTrue) def forward(self, x): identity x out self.conv1(x) out self.conv2(out) out self.conv3(out) shortcut self.shortcut(identity) out shortcut out self.relu(out) return out接下来就可以写resnet-50了class ResNet50(nn.Module): def __init__(self, num_classes3): super(ResNet50, self).__init__() self.conv1 nn.Sequential( nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse, padding_modezeros), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size3, stride2, padding1) ) self.conv2 nn.Sequential( ConvBlock(64, [64, 64, 256], kernel_size3, stride1), IdentityBlock(256, [64, 64, 256], kernel_size3), IdentityBlock(256, [64, 64, 256], kernel_size3) ) self.conv3 nn.Sequential( ConvBlock(256, [128, 128, 512], kernel_size3), IdentityBlock(512, [128, 128, 512], kernel_size3), IdentityBlock(512, [128, 128, 512], kernel_size3), IdentityBlock(512, [128, 128, 512], kernel_size3) ) self.conv4 nn.Sequential( ConvBlock(512, [256, 256, 1024], kernel_size3), IdentityBlock(1024, [256, 256, 1024], kernel_size3), IdentityBlock(1024, [256, 256, 1024], kernel_size3), IdentityBlock(1024, [256, 256, 1024], kernel_size3), IdentityBlock(1024, [256, 256, 1024], kernel_size3), IdentityBlock(1024, [256, 256, 1024], kernel_size3) ) self.conv5 nn.Sequential( ConvBlock(1024, [512, 512, 2048], kernel_size3), IdentityBlock(2048, [512, 512, 2048], kernel_size3), IdentityBlock(2048, [512, 512, 2048], kernel_size3) ) self.avgpool nn.AvgPool2d(kernel_size7, stride7, padding0) self.fc nn.Linear(2048, num_classes) def forward(self, x): x self.conv1(x) x self.conv2(x) x self.conv3(x) x self.conv4(x) x self.conv5(x) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x model ResNet50(num_classes3).to(device)看一眼参数吧import torchsummary as summary summary.summary(model, (3, 224, 224))50层是指卷积层开头的和结尾的fc层4.训练测试函数训练测试函数核心在于前向传播与反向传播和todevicedef train(dataloader, model, loss_fn, optimizer): size len(dataloader.dataset) num_batches len(dataloader) train_loss, correct 0, 0 for X, y in dataloader: X, y X.to(device), y.to(device) # 前向传播 pred model(X) loss loss_fn(pred, y) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() train_loss loss.item() correct (pred.argmax(1) y).type(torch.float).sum().item() train_loss / num_batches correct / size return train_loss, correctdef test(dataloader, model, loss_fn): size len(dataloader.dataset) num_batches len(dataloader) test_loss, correct 0, 0 with torch.no_grad(): for X, y in dataloader: X, y X.to(device), y.to(device) pred model(X) test_loss loss_fn(pred, y).item() correct (pred.argmax(1) y).type(torch.float).sum().item() test_loss / num_batches correct / size return test_loss, correct5.开始训练选择后优化器和损失函数就可以开始训练了import copy optimizer torch.optim.AdamW(model.parameters(), lr0.0001) loss_fn nn.CrossEntropyLoss() num_epochs 10 train_loss_history [] train_acc_history [] test_loss_history [] test_acc_history [] best_acc 0.0 for epoch in range(num_epochs): model.train() train_loss, train_acc train(train_dl, model, loss_fn, optimizer) train_loss_history.append(train_loss) train_acc_history.append(train_acc) model.eval() test_loss, test_acc test(test_dl, model, loss_fn) test_loss_history.append(test_loss) test_acc_history.append(test_acc) if test_acc best_acc: best_acc test_acc best_model_wts copy.deepcopy(model.state_dict()) lr optimizer.state_dict()[param_groups][0][lr] template Epoch [{}/{}], LR: {:.6f}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f} print(template.format(epoch1, num_epochs, lr, train_loss, train_acc, test_loss, test_acc)) PATH ../model/resnet50.pth torch.save(best_model_wts, PATH) print(DONE)6.数据可视化import matplotlib.pyplot as plt import warnings warnings.filterwarnings(ignore) plt.rcParams[font.sans-serif] [SimHei] plt.rcParams[axes.unicode_minus] False plt.rcParams[figure.dpi] 100 from datetime import datetime current_time datetime.now().strftime(%Y-%m-%d %H:%M:%S) epochs range(1, num_epochs 1) plt.figure(figsize(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs, train_loss_history, label训练损失) plt.plot(epochs, test_loss_history, label测试损失) plt.title(训练和测试损失) plt.xlabel(current_time) plt.subplot(1, 2, 2) plt.plot(epochs, train_acc_history, label训练准确率) plt.plot(epochs, test_acc_history, label测试准确率) plt.title(训练和测试准确率) plt.xlabel(current_time) plt.legend() plt.show()总结这一节我们重新回到了Pytorch环境进行训练在该环境下可以调用OpenCV等图像处理库。并且了解了ResNet 这个经典的CNN网络结构并完成了他的搭建与训练。

相关文章:

ResNet-50——pytorch版

声明: 🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 先验知识: ResNet残差网络,根据网络层数可以分为(ResNet-18、ResNet-34、ResNet-50、ResNet-101等&…...

保姆级教程:用RV1126开发板和RKISP Tuner搞定ISP黑电平(BLC)校准(附避坑指南)

RV1126开发板ISP黑电平校准实战指南:从原理到避坑全解析 当你第一次拿到RV1126开发板,准备调试图像质量时,黑电平校准(BLC)往往是第一个需要攻克的难关。作为ISP处理流水线的第一道工序,BLC校准的质量直接影响后续所有图像处理效果…...

农村的爸爸拉肚子多年,幸好有它的出现

#东海阿泰宁#基石菌酪酸梭菌#肠易激...

AI时代工程师的超级进化论

AI时代工程师的Superpowers进化论技术文章大纲技术背景与趋势AI对传统工程领域的冲击与重构工程师核心能力的变迁:从编码到系统设计数据驱动与自动化工具对生产力的解放Superpowers 1:数据思维与AI协作能力数据敏感度:从业务需求到数据建模的…...

2026年电子商务论文降AI工具推荐:用户行为分析和商业模式部分

2026年电子商务论文降AI工具推荐:用户行为分析和商业模式部分 在知乎看了很多帖子,在论坛翻了很多评测,最后用的是嘎嘎降AI(www.aigcleaner.com)。 价格4.8元一篇,实测知网从67%降到6%。电子商务论文降AI…...

【Hermes系列7】我把 Hermes 接入了 Jenkins:回归测试从 3 天到 30 分钟

01 这是 Hermes 系列的第 7 篇,也是企业落地关键篇。前 6 篇我们解决了:本地跑通、场景实战、工程化。但真实企业里,还有一个绕不开的问题:你本地跑得再好,怎么让团队每个人都用上?怎么保证每天按时执行&a…...

Linux CFS 的 nr_switches:上下文切换次数统计

简介在Linux内核的进程调度体系中,完全公平调度器(Completely Fair Scheduler, CFS)自2.6.23版本引入以来,一直是通用操作系统环境下的默认调度策略。对于从事系统性能优化、容器化资源管控或实时系统设计的工程师而言&#xff0c…...

基于Python的网购平台管理系统毕业设计

博主介绍:✌ 专注于Java,python,✌关注✌私信我✌具体的问题,我会尽力帮助你。一、研究目的本研究旨在设计并实现一个基于Python的网购平台管理系统,以满足现代电子商务环境下对高效、安全、便捷的网购体验的需求。具体研究目的如下&#xff…...

某上市炼化企业人才培养及引进成功案例纪实

某上市炼化企业人才培养及引进成功案例纪实——从“熬年限”到“凭能力”,以人才机制创新支撑战略转型【客户行业】炼化行业;民营企业【问题类型】人才引进;梯队建设【客户背景】该企业是国内领先的民营炼化一体化企业,业务涵盖原…...

基于Python的影城会员管理系统

博主介绍:✌ 专注于Java,python,✌关注✌私信我✌具体的问题,我会尽力帮助你。一、研究目的本研究旨在设计并实现一套基于Python的影城会员管理系统,以满足现代影城在会员管理方面的需求。具体研究目的如下: 首先,通过…...

告别玄学调试:用J-Flash给STM32芯片“洗个澡”,解决RT-Thread Studio下载疑难杂症

嵌入式开发实战:用J-Flash彻底解决STM32下载异常问题 当你满怀期待地点击"下载"按钮,RT-Thread Studio却无情地显示"执行完毕"而板子毫无反应时,那种挫败感每个嵌入式开发者都深有体会。更令人抓狂的是,编译器…...

从SVM到凸优化:对偶问题的数学之美

1. 从SVM到凸优化:理解对偶问题的必要性 第一次接触支持向量机(SVM)时,很多人都会被其中复杂的数学推导劝退。特别是当算法从原始问题转换到对偶问题时,总会有种"为什么要绕这么大圈子"的困惑。我在教学过程…...

Kotlin的Flow背压策略:Buffer、Conflate、Drop对比

Kotlin的Flow背压策略:Buffer、Conflate、Drop对比 在异步数据流处理中,背压(Backpressure)是一个常见问题,即生产者的数据生成速度超过消费者的处理能力。Kotlin的Flow提供了三种背压策略:Buffer、Confla…...

基于STM32与VS1053的智能音乐播放器设计与实现

1. 项目背景与核心功能 每次在地铁上看到有人用复古MP3听歌,我都会想起学生时代攒钱买的第一台音乐播放器。如今虽然手机听歌很方便,但自己动手做一个能解码多种格式的智能音乐播放器,依然是电子爱好者心中的"白月光"。这次我们要用…...

国产IDE崛起?实测MounRiver Studio:用它开发CH32V103/CH32F103全流程(附串口调试技巧)

国产IDE实战评测:MounRiver Studio开发RISC-V/ARM双核MCU全指南 第一次接触MounRiver Studio(MRS)是在一个嵌入式技术交流群,几位同行对这款国产IDE的评价褒贬不一。作为长期使用Keil和IAR的开发者,我对"国产IDE能…...

2026年3月 GESP CCF编程能力等级认证图形化编程一级真题

答案和更多内容请查看网站:【试卷中心 -----> CCF GESP ----> 图形化/Scratch ----> 一级】 网站链接 青少年软件编程历年真题模拟题实时更新 GESP CCF编程能力等级认证 图形化/Scratch一级真题 一、单选题 1. 在2026年春晚的《武BOT》节目中&#…...

多模态游戏AI不是升级,是重定义:2026奇点大会发布的《实时语义-物理耦合引擎》标准草案(全球首次公开)

第一章:多模态游戏AI不是升级,是重定义 2026奇点智能技术大会(https://ml-summit.org) 传统游戏AI长期依赖预设规则与有限状态机(FSM),或基于单一模态(如数值化行为树)进行决策。而多模态游戏A…...

破解Google SynthID:AI水印逆向工程

这是一个非常有趣且具有技术深度的项目。基于你提供的 GitHub 项目地址,reverse-SynthID 是一个旨在“逆向工程” Google SynthID 水印技术的开源尝试。 简单来说,它试图解决一个核心问题:如果 AI 生成的图片被植入了肉眼不可见的水印&#x…...

WebToEpub:5分钟免费将网页小说转为EPUB电子书的终极指南

WebToEpub:5分钟免费将网页小说转为EPUB电子书的终极指南 【免费下载链接】WebToEpub A simple Chrome (and Firefox) Extension that converts Web Novels (and other web pages) into an EPUB. 项目地址: https://gitcode.com/gh_mirrors/we/WebToEpub 还在…...

如何永久保存微信聊天记录?终极免费工具使用指南

如何永久保存微信聊天记录?终极免费工具使用指南 【免费下载链接】WeChatMsg 提取微信聊天记录,将其导出成HTML、Word、CSV文档永久保存,对聊天记录进行分析生成年度聊天报告 项目地址: https://gitcode.com/GitHub_Trending/we/WeChatMsg …...

Python 自动化办公:批量提取 Excel 表格中的特定数据

在日常办公中,我们常常会遇到需要从大量 Excel 表格中提取特定数据的情况。手动操作不仅效率低下,还容易出错。借助 Python 强大的库,我们可以轻松实现自动化提取,提高工作效率。需求分析 假设我们有一个包含多个 Excel 文件的文件…...

AEUX终极指南:5分钟掌握Figma/Sketch到After Effects的无缝转换

AEUX终极指南:5分钟掌握Figma/Sketch到After Effects的无缝转换 【免费下载链接】AEUX Editable After Effects layers from Sketch artboards 项目地址: https://gitcode.com/gh_mirrors/ae/AEUX 如果你是一名UI/UX设计师或动效设计师,一定经历过…...

Mac长期连移动硬盘,修改这4个关键设置,避免伤盘

很多人用Mac时,会长期外接移动硬盘存资料、剪视频或者做备份,觉得插着不拔很方便。但其实macOS默认的不少设置,长期下来会悄悄损耗硬盘,轻则频繁掉线、读写变慢,重则直接坏道、数据丢失。 今天就结合2026年macOS最新系…...

多模态大模型容灾备份策略(NASA级冗余设计白皮书首次公开)

第一章:多模态大模型容灾备份策略 2026奇点智能技术大会(https://ml-summit.org) 多模态大模型(如LLaVA-X、Qwen-VL、Fuyu-8B)在训练与推理阶段依赖海量参数、跨模态对齐权重及动态缓存状态,其容灾备份需超越传统单模态模型的快照…...

3个实用技巧快速解决城通网盘下载限速问题

3个实用技巧快速解决城通网盘下载限速问题 【免费下载链接】ctfileGet 获取城通网盘一次性直连地址 项目地址: https://gitcode.com/gh_mirrors/ct/ctfileGet 你是否曾经为了下载城通网盘上的文件而苦苦等待?面对几十KB/s的下载速度,看着进度条缓…...

大模型发展史

人工智能是一场跨越数十年、由一系列关键突破所驱动的波澜壮阔的史诗。回顾其历程,我们可以清晰地看到三个特征鲜明的阶段,每一阶段都以前一阶段的理论和实践为基础,最终引爆了今天我们所见到的AI革命。一、 萌芽期(1950-2005&…...

乐高与众球星共同庆祝足球的魅力

乐高集团携手克里斯蒂亚诺罗纳尔多、基利安姆巴佩、莱昂内尔梅西和维尼修斯儒尼奥尔等足球明星,与世界各地的孩子和家庭一同庆祝足球的魅力——因为每个人都想参与其中!随着 2026 年国际足联世界杯日益临近,足球热潮空前高涨,球迷…...

C#怎么操作WPF样式和模板 C#如何用WPF Style和ControlTemplate自定义控件外观【控件】

Style负责统一设置控件属性值,ControlTemplate决定控件结构与视觉树;混淆二者是80%样式失效主因,如Style中Template不生效、Background被覆盖、Trigger导致控件消失等。WPF里Style和ControlTemplate到底该谁管什么Style负责统一设置控件的属性…...

如果你很懒,那这种一定很适合你:CSGO游戏搬砖,不需要玩游戏就能赚钱

最近好几个朋友问我:现在有什么靠谱的副业?不要太累,能稳定赚点钱就行。如果我不是一直在跑这些赚钱项目,这问题还真答不上来。市面上副业一大堆,能快速拿到结果,并且有稳定收益的还真不多。我第一反应就是…...

AI4S:战略赋能与产业突围,中科曙光的产业链优势解析

当AI技术从应用层向基础研究渗透,AI4S(人工智能驱动科学创新)正成为重塑科技发展逻辑、破解产业升级瓶颈的核心力量。它并非简单的技术叠加,而是以人工智能赋能基础科研,推动科研范式从“试错驱动”向“数据模型驱动”…...