当前位置: 首页 > article >正文

PyTorch实战:从零构建ResNet50模型(CIFAR10训练+测试+ONNX转换)

1. ResNet50模型基础认知第一次接触ResNet50时我被它的残差连接设计惊艳到了。传统神经网络随着层数增加会出现梯度消失问题而ResNet通过跨层直连通道让信息能够无损传递到更深层。这就好比在高速公路上设置应急车道即使主路拥堵车辆仍能通过应急道快速通行。ResNet50名称中的50代表50层深度实际包含49个卷积层和1个全连接层其核心结构是下图所示的残差块。每个残差块包含1x1卷积降维3x3卷积特征提取1x1卷积升维跨层连接identity shortcutclass Bottleneck(nn.Module): expansion 4 def __init__(self, in_channels, channels, stride1): super().__init__() self.conv1 nn.Conv2d(in_channels, channels, kernel_size1) self.bn1 nn.BatchNorm2d(channels) self.conv2 nn.Conv2d(channels, channels, kernel_size3, stridestride, padding1) self.bn2 nn.BatchNorm2d(channels) self.conv3 nn.Conv2d(channels, channels*self.expansion, kernel_size1) self.bn3 nn.BatchNorm2d(channels*self.expansion) self.shortcut nn.Sequential() if stride ! 1 or in_channels ! channels*self.expansion: self.shortcut nn.Sequential( nn.Conv2d(in_channels, channels*self.expansion, kernel_size1, stridestride), nn.BatchNorm2d(channels*self.expansion) ) def forward(self, x): out F.relu(self.bn1(self.conv1(x))) out F.relu(self.bn2(self.conv2(out))) out self.bn3(self.conv3(out)) out self.shortcut(x) return F.relu(out)2. PyTorch环境搭建推荐使用conda创建独立的Python环境避免包版本冲突。这是我验证过的稳定版本组合conda create -n resnet python3.8 conda activate resnet pip install torch1.12.1 torchvision0.13.1 pip install numpy pandas matplotlib tqdm关键工具说明torch.nn神经网络层实现torch.optim优化算法如SGD、Adamtorchvision.transforms图像预处理torch.utils.data数据加载与批处理验证GPU是否可用import torch print(torch.cuda.is_available()) # 输出True表示GPU可用 device torch.device(cuda if torch.cuda.is_available() else cpu)3. CIFAR10数据处理实战CIFAR10包含6万张32x32彩色图片5万训练1万测试共10个类别。处理流程如下3.1 数据集加载与增强from torchvision import datasets, transforms # 定义增强策略 train_transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.Resize(224), # ResNet原始输入尺寸 transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) test_transform transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) # 加载数据集 train_set datasets.CIFAR10(root./data, trainTrue, downloadTrue, transformtrain_transform) test_set datasets.CIFAR10(root./data, trainFalse, downloadTrue, transformtest_transform) # 创建数据加载器 train_loader torch.utils.data.DataLoader(train_set, batch_size64, shuffleTrue, num_workers4) test_loader torch.utils.data.DataLoader(test_set, batch_size64, shuffleFalse, num_workers4)3.2 数据可视化检查import matplotlib.pyplot as plt classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck) def imshow(img): img img * 0.5 0.5 # 反归一化 npimg img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show() # 获取随机批次 dataiter iter(train_loader) images, labels next(dataiter) # 显示图像 imshow(torchvision.utils.make_grid(images[:4])) print( .join(f{classes[labels[j]]:5s} for j in range(4)))4. 完整ResNet50实现4.1 网络结构搭建class ResNet50(nn.Module): def __init__(self, num_classes10): super().__init__() self.in_channels 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3) self.bn1 nn.BatchNorm2d(64) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 残差块组 self.layer1 self._make_layer(64, 3, stride1) self.layer2 self._make_layer(128, 4, stride2) self.layer3 self._make_layer(256, 6, stride2) self.layer4 self._make_layer(512, 3, stride2) self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * Bottleneck.expansion, num_classes) def _make_layer(self, channels, num_blocks, stride): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(Bottleneck(self.in_channels, channels, stride)) self.in_channels channels * Bottleneck.expansion return nn.Sequential(*layers) def forward(self, x): x F.relu(self.bn1(self.conv1(x))) x self.maxpool(x) x self.layer1(x) x self.layer2(x) x self.layer3(x) x self.layer4(x) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x4.2 模型初始化技巧def initialize_weights(model): for m in model.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) model ResNet50().to(device) initialize_weights(model)5. 训练与评估流程5.1 训练配置criterion nn.CrossEntropyLoss() optimizer torch.optim.SGD(model.parameters(), lr0.1, momentum0.9, weight_decay5e-4) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size30, gamma0.1)5.2 训练循环def train(epoch): model.train() running_loss 0.0 for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets inputs.to(device), targets.to(device) optimizer.zero_grad() outputs model(inputs) loss criterion(outputs, targets) loss.backward() optimizer.step() running_loss loss.item() if batch_idx % 100 99: print(fEpoch: {epoch}, Batch: {batch_idx1}, Loss: {running_loss/100:.3f}) running_loss 0.0 def test(): model.eval() correct 0 total 0 with torch.no_grad(): for inputs, targets in test_loader: inputs, targets inputs.to(device), targets.to(device) outputs model(inputs) _, predicted outputs.max(1) total targets.size(0) correct predicted.eq(targets).sum().item() acc 100. * correct / total print(fTest Accuracy: {acc:.2f}%) return acc for epoch in range(100): train(epoch) test() scheduler.step()6. 模型转换与部署6.1 PyTorch转ONNXdummy_input torch.randn(1, 3, 224, 224).to(device) torch.onnx.export(model, dummy_input, resnet50.onnx, input_names[input], output_names[output], dynamic_axes{input: {0: batch}, output: {0: batch}})6.2 ONNX模型验证import onnxruntime as ort ort_session ort.InferenceSession(resnet50.onnx) outputs ort_session.run(None, {input: dummy_input.cpu().numpy()}) # 对比原始模型输出 with torch.no_grad(): torch_output model(dummy_input) print(Output difference:, np.max(np.abs(outputs[0] - torch_output.cpu().numpy())))6.3 ONNX模型推理示例def preprocess_image(image_path): image Image.open(image_path).convert(RGB) transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) return transform(image).unsqueeze(0).numpy() def predict_onnx(image_path): input_data preprocess_image(image_path) outputs ort_session.run(None, {input: input_data}) pred np.argmax(outputs[0]) return classes[pred] print(predict_onnx(test_cat.jpg)) # 输出预测类别7. 性能优化技巧混合精度训练scaler torch.cuda.amp.GradScaler() for inputs, targets in train_loader: inputs, targets inputs.to(device), targets.to(device) with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()梯度累积小批量场景accumulation_steps 4 optimizer.zero_grad() for i, (inputs, targets) in enumerate(train_loader): inputs, targets inputs.to(device), targets.to(device) outputs model(inputs) loss criterion(outputs, targets) / accumulation_steps loss.backward() if (i1) % accumulation_steps 0: optimizer.step() optimizer.zero_grad()学习率预热from torch.optim.lr_scheduler import LambdaLR warmup_epochs 5 scheduler LambdaLR(optimizer, lr_lambdalambda epoch: (epoch1)/warmup_epochs if epoch warmup_epochs else 0.1**(epoch//30))

相关文章:

PyTorch实战:从零构建ResNet50模型(CIFAR10训练+测试+ONNX转换)

1. ResNet50模型基础认知 第一次接触ResNet50时,我被它的"残差连接"设计惊艳到了。传统神经网络随着层数增加会出现梯度消失问题,而ResNet通过跨层直连通道,让信息能够无损传递到更深层。这就好比在高速公路上设置应急车道&#xf…...

从浮点到定点:手把手教你用MATLAB自定义函数实现加减乘除(避坑溢出与精度损失)

从浮点到定点:手把手教你用MATLAB自定义函数实现加减乘除(避坑溢出与精度损失) 当算法需要从实验室环境迁移到嵌入式设备时,浮点运算的硬件开销常常成为瓶颈。这时定点数运算就像一把手术刀——精准控制每个比特的用途&#xff0c…...

AlertDialog高斯模糊进阶指南:Android12新特性与兼容方案对比

AlertDialog高斯模糊进阶指南:Android12新特性与兼容方案对比 在移动应用设计中,视觉层次的营造往往决定了用户体验的优劣。当用户与AlertDialog交互时,背景的高斯模糊效果能够有效聚焦注意力,同时保持界面连贯性。Android 12引入…...

告别序列‘拉直’的暴力美学:手把手复现MaIR,体验保持图像局部与连续性的Mamba新玩法

告别序列“拉直”的暴力美学:手把手复现MaIR,体验保持图像局部与连续性的Mamba新玩法 在计算机视觉领域,图像修复任务(如去噪、超分、去模糊)一直是研究热点。传统方法往往将2D图像“拉直”为1D序列进行处理&#xff0…...

vLLM-v0.17.1应用场景:智能硬件语音助手离线LLM推理部署

vLLM-v0.17.1应用场景:智能硬件语音助手离线LLM推理部署 1. 技术背景与需求分析 智能硬件语音助手正在经历从云端依赖向本地化处理的转变。传统方案面临三大痛点: 网络延迟问题:云端API调用导致响应速度受限隐私安全顾虑:用户对…...

避开这3个坑!MIPI走线设计如何减少对GSM信号的干扰(含阻抗匹配计算)

避开这3个坑!MIPI走线设计如何减少对GSM信号的干扰(含阻抗匹配计算) 在消费电子硬件设计中,MIPI接口与射频信号的共存问题一直是工程师面临的棘手挑战。特别是当设备需要同时支持高清显示和GSM通信功能时,MIPI信号对GS…...

从SolidWorks到Gazebo:手把手教你用SW2URDF插件为ROS2 Humble机械臂建模(含ROS2适配避坑指南)

从SolidWorks到Gazebo:ROS2 Humble机械臂建模全流程实战 1. 工业设计与机器人仿真的桥梁搭建 当机械工程师第一次接触机器人仿真时,往往会面临一个关键挑战:如何将精心设计的SolidWorks模型转化为可在Gazebo中运行的仿真模型?这个…...

OpenH264:开源H.264编解码库的技术实现与工程实践

OpenH264:开源H.264编解码库的技术实现与工程实践 【免费下载链接】openh264 Open Source H.264 Codec 项目地址: https://gitcode.com/gh_mirrors/op/openh264 OpenH264作为Cisco维护的开源H.264编解码库,在实时视频通信、流媒体传输和嵌入式设…...

bert-base-chinese新手教程:从零开始学习中文预训练模型部署与使用

bert-base-chinese新手教程:从零开始学习中文预训练模型部署与使用 1. 认识bert-base-chinese模型 1.1 什么是BERT模型 BERT(Bidirectional Encoder Representations from Transformers)是Google在2018年发布的预训练语言模型。它通过大规…...

基于智能体(Agent)的自动化图像工作流:Wan2.2-I2V-A14B与任务编排

基于智能体(Agent)的自动化图像工作流:Wan2.2-I2V-A14B与任务编排 1. 引言:当图像生成遇上智能体 想象一下这样的场景:你需要为电商平台制作一组节日主题的广告图,包含特定风格的背景、商品展示和人物互动…...

Qwen3-Reranker-0.6B效果展示:中英术语对照表构建中的跨语言排序

Qwen3-Reranker-0.6B效果展示:中英术语对照表构建中的跨语言排序 1. 跨语言术语排序的技术挑战 在全球化信息时代,构建准确的中英术语对照表已成为跨语言交流、技术文档翻译和国际合作的重要基础。传统方法往往面临几个核心痛点: 语义鸿沟…...

Qwen3.5-4B-Claude-Opus实战案例:用推理链输出提升技术沟通准确性

Qwen3.5-4B-Claude-Opus实战案例:用推理链输出提升技术沟通准确性 1. 模型介绍与核心能力 Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled-GGUF是一个基于Qwen3.5-4B的推理蒸馏模型,专门针对结构化分析、分步骤回答以及代码与逻辑类问题的处理能力进…...

单片机通用按键处理模块设计与实现

单片机通用按键处理模块设计与实现1. 项目概述1.1 模块功能特性本按键处理模块为单片机系统提供了一套完整的按键事件处理解决方案,具有以下核心功能:基础按键检测:支持按下(PRESS)和释放(RELEASE)事件检测高级触发模式:长按触发(…...

构建大规模数据导入系统:技术选型与工程实践

在现代数据密集型应用中,将海量数据高效、可靠地导入目标存储系统是一项基础但极具挑战的任务。表面上看,“写入数据库”只是一个简单的操作;然而,当数据规模达到TB级、业务逻辑涉及合并去重、系统架构包含多个存储引擎时&#xf…...

3分钟掌握Balena Etcher:安全可靠的跨平台镜像烧录工具

3分钟掌握Balena Etcher:安全可靠的跨平台镜像烧录工具 【免费下载链接】etcher Flash OS images to SD cards & USB drives, safely and easily. 项目地址: https://gitcode.com/GitHub_Trending/et/etcher Balena Etcher是一款专为简化操作系统镜像部署…...

Kali Linux安装失败?5个常见报错解决方案(虚拟机专用版)

Kali Linux虚拟机安装报错实战指南:5个高频问题深度解析 当你兴致勃勃地在VMware里安装Kali Linux准备大展身手时,突然弹出的报错信息就像一盆冷水浇下来。别急着重装——90%的安装问题都有现成解决方案。本文将聚焦虚拟机环境下最棘手的5类安装报错&…...

Linux服务器GPU环境配置避坑指南:从Nvidia驱动到PyTorch Lightning一站式搞定

Linux服务器GPU环境配置避坑指南:从Nvidia驱动到PyTorch Lightning一站式搞定 当你第一次在Linux服务器上配置GPU环境时,可能会遇到各种令人抓狂的问题:驱动安装失败、CUDA版本不兼容、PyTorch无法识别GPU...这些问题足以让任何一个开发者崩溃…...

Win11Debloat终极指南:5分钟让你的Windows系统焕然一新

Win11Debloat终极指南:5分钟让你的Windows系统焕然一新 【免费下载链接】Win11Debloat 一个简单的PowerShell脚本,用于从Windows中移除预装的无用软件,禁用遥测,从Windows搜索中移除Bing,以及执行各种其他更改以简化和…...

Shield CLI:MySQL 插件 vs phpMyAdmin:轻量 Web 数据库管理工具对比

phpMyAdmin 是 MySQL Web 管理的事实标准,1998 年发布至今,功能覆盖面极广。但在"查个数据、改个表、看看关系"这类日常场景下,它的部署成本和界面复杂度显得有些过重。Shield CLI MySQL 插件是一个 7MB 的单二进制 Web 客户端&…...

3步颠覆性解决方案:零成本条码生成技术让企业彻底告别付费依赖

3步颠覆性解决方案:零成本条码生成技术让企业彻底告别付费依赖 【免费下载链接】librebarcode Libre Barcode: barcode fonts for various barcode standards. 项目地址: https://gitcode.com/gh_mirrors/li/librebarcode Libre Barcode开源字体库通过字体化…...

深度解析PDFMathTranslate:揭秘AI如何实现毫秒级学术文档翻译与精准排版保留

深度解析PDFMathTranslate:揭秘AI如何实现毫秒级学术文档翻译与精准排版保留 【免费下载链接】PDFMathTranslate PDF scientific paper translation with preserved formats - 基于 AI 完整保留排版的 PDF 文档全文双语翻译,支持 Google/DeepL/Ollama/Op…...

CasRel模型LaTeX学术论文辅助工具:自动提取相关工作和贡献

CasRel模型LaTeX学术论文辅助工具:自动提取相关工作和贡献 每次打开一篇新的学术论文,尤其是那些动辄几十页的综述或顶会文章,你是不是也有点头大?密密麻麻的文字里,最关键的信息——“别人做了什么”、“他们有什么不…...

EVA-01场景应用:电商商品分析、文档信息提取,真实工作流分享

EVA-01场景应用:电商商品分析、文档信息提取,真实工作流分享 1. 从科幻到现实:EVA-01的商业价值 在电商运营和文档处理的日常工作中,我们常常面临这样的挑战:海量商品图片需要人工标注关键信息,繁杂的合同…...

LFM2.5-1.2B-Thinking-GGUF基础教程:单页Web界面交互逻辑与后处理机制

LFM2.5-1.2B-Thinking-GGUF基础教程:单页Web界面交互逻辑与后处理机制 1. 模型与平台介绍 LFM2.5-1.2B-Thinking-GGUF是Liquid AI推出的轻量级文本生成模型,专为低资源环境优化设计。这个镜像采用内置GGUF模型文件和llama.cpp运行时,提供了…...

8255A工作方式0实战:手把手教你用汇编语言驱动八路抢答器LED与数码管

8255A工作方式0实战:从零构建八路抢答器驱动框架 记得第一次在实验室见到8255A芯片时,那块黑色的DIP封装器件看起来平平无奇,直到它让八颗LED随着我的汇编指令跳起"灯光芭蕾"。本文将带你深入这个经典可编程并行接口芯片的实战应用…...

保姆级教程:在Windows 11上为PyTorch配置CUDA 12.x和cuDNN(含环境变量疑难杂症排查)

Windows 11深度学习环境配置全攻略:从CUDA安装到PyTorch GPU加速实战 每次打开PyCharm准备大展身手时,看到那个令人心碎的False——torch.cuda.is_available()的输出结果,是不是感觉整个深度学习梦想都被泼了冷水?别担心&#xf…...

20吨燃气蒸汽锅炉实力厂家/支持上门安装调试

燃气蒸汽锅炉,认准源头实力厂家,不仅能买到品质过硬的设备,更能享受到省心便捷的上门安装调试服务,免去自行安装的繁琐与隐患,让设备快速投入平稳运行。我们作为深耕锅炉制造行业的实力厂家,具备正规生产资…...

K230目标检测实战:手把手教你用Labelme标注数据并一键转成VOC格式(附避坑指南)

K230目标检测实战:高效数据标注与VOC格式转换全攻略 当你第一次接触K230开发板进行目标检测项目时,数据准备往往是最大的拦路虎。特别是从原始图片到符合AI_Cube要求的VOC格式数据集,这个过程充满了各种"坑"。本文将分享一套经过实…...

半导体放电管TSS选型避坑指南:从RS485到CAN接口的实战经验分享

半导体放电管TSS选型避坑指南:从RS485到CAN接口的实战经验分享 在工业通信设备的电路保护设计中,浪涌防护是一个不可忽视的关键环节。作为一名长期奋战在一线的硬件工程师,我深知半导体放电管(TSS)选型过程中的种种陷阱…...

EVE舰船配置神器Pyfa全攻略:从新手到专家的实战指南

EVE舰船配置神器Pyfa全攻略:从新手到专家的实战指南 【免费下载链接】Pyfa Python fitting assistant, cross-platform fitting tool for EVE Online 项目地址: https://gitcode.com/gh_mirrors/py/Pyfa 在EVE Online的浩瀚宇宙中,每一位舰长都需…...