当前位置: 首页 > article >正文

PyTorch深度学习框架之多分类交叉熵实现图像分类

目录一、自定义小CNN实现手机分类1、代码示例2、代码解析一、自定义小CNN实现手机分类1、代码示例适合苹果/华为/小米 3分类手机识别你可以直接改类别数适配你的任务importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasoptimfromtorch.utils.dataimportDataset,DataLoaderfromtorchvisionimporttransformsfromPILimportImageimportos# --------------------------# 第一步定义自定义小CNN 就是我们说的3-4层卷积# 结构输入(3, 224, 224) → 卷积1 → ReLU → 池化 → 卷积2 → ReLU → 池化 → 卷积3 → ReLU → 池化 → 全连接1 → ReLU → 全连接输出# --------------------------classSmallPhoneCNN(nn.Module):def__init__(self,num_classes3): num_classes: 你要分的手机类别数默认3类苹果/华为/小米 super(SmallPhoneCNN,self).__init__()# 1. 卷积层部分3层卷积符合你说的3-4层参数量极小# 输入是RGB图片3通道所以第一层in_channels3self.conv1nn.Conv2d(in_channels3,# 输入RGB三通道out_channels16,# 提取16种特征小任务足够用kernel_size3,# 3×3卷积核最常用padding1# 填充保证卷积后大小不变)self.pool1nn.MaxPool2d(kernel_size2,stride2)# 池化把尺寸缩小一半self.conv2nn.Conv2d(16,32,kernel_size3,padding1)# 输入16通道输出32种特征self.pool2nn.MaxPool2d(2,2)self.conv3nn.Conv2d(32,64,kernel_size3,padding1)# 输入32通道输出64种特征self.pool3nn.MaxPool2d(2,2)# 2. 全连接层部分把卷积提取的特征转成分类输出# 计算一下输入224×224经过3次池化每次缩小一半224 → 112 → 56 → 28# 所以最后特征大小是 64通道 × 28 × 28 64*28*28 50176self.fc1nn.Linear(64*28*28,128)# 卷积特征转128维隐藏向量self.fc2nn.Linear(128,num_classes)# 最后输出128维转成类别数就是我们要的结果# 3. 按照规则做初始化ReLU用Kaiming初始化self._initialize_weights()def_initialize_weights(self):# 所有卷积和全连接都用Kaiming初始化符合我们之前说的规则forminself.modules():ifisinstance(m,nn.Conv2d):nn.init.kaiming_normal_(m.weight,modefan_out,nonlinearityrelu)ifm.biasisnotNone:nn.init.zeros_(m.bias)elifisinstance(m,nn.Linear):nn.init.kaiming_normal_(m.weight)nn.init.zeros_(m.bias)# 前向传播定义数据怎么走defforward(self,x):# x形状[batch_size, 3, 224, 224] → 一批图片3通道大小224×224xself.pool1(F.relu(self.conv1(x)))# 第一层卷积激活池化 → 输出[batch,16,112,112]xself.pool2(F.relu(self.conv2(x)))# 第二层 → 输出[batch,32,56,56]xself.pool3(F.relu(self.conv3(x)))# 第三层 → 输出[batch,64,28,28]# 把四维特征拉成一维给全连接层[batch,64,28,28] → [batch, 64*28*28]xx.flatten(1)xF.relu(self.fc1(x))# 全连接第一层激活xself.fc2(x)# 最后一层输出logits不用加softmax训练时候交叉熵会处理returnx# --------------------------# 第二步自定义数据集加载你的手机图片# 数据存放要求和之前一样每个类别一个文件夹# ./phone_data/train/# ├── apple/ (苹果手机图片)# ├── huawei/ (华为手机图片)# └── xiaomi/ (小米手机图片)# ./phone_data/test/ 同理放测试集# --------------------------classPhoneDataset(Dataset):def__init__(self,data_root,transformNone):self.data_rootdata_root self.transformtransform# 读取所有类别生成映射self.class_namessorted(os.listdir(data_root))self.class_to_idx{cls:ifori,clsinenumerate(self.class_names)}# 收集所有图片标签self.img_paths[]forcls_nameinself.class_names:cls_diros.path.join(data_root,cls_name)forimg_nameinos.listdir(cls_dir):self.img_paths.append((os.path.join(cls_dir,img_name),self.class_to_idx[cls_name]))def__len__(self):returnlen(self.img_paths)def__getitem__(self,idx):img_path,labelself.img_paths[idx]imgImage.open(img_path).convert(RGB)ifself.transform:imgself.transform(img)returnimg,label# --------------------------# 第三步训练配置训练流程# --------------------------if__name____main__:# 超参数NUM_CLASSES3# 改成你自己的类别数BATCH_SIZE8EPOCHS15LR1e-3# 数据预处理train_transformtransforms.Compose([transforms.Resize((224,224)),# 直接resize到224小模型不用复杂增强transforms.RandomHorizontalFlip(),# 随机翻转增加数据多样性transforms.ToTensor(),transforms.Normalize(mean[0.485,0.456,0.406],std[0.229,0.224,0.225])])test_transformtransforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize(mean[0.485,0.456,0.406],std[0.229,0.224,0.225])])# 加载数据集train_datasetPhoneDataset(./phone_data/train,transformtrain_transform)test_datasetPhoneDataset(./phone_data/test,transformtest_transform)train_loaderDataLoader(train_dataset,batch_sizeBATCH_SIZE,shuffleTrue)test_loaderDataLoader(test_dataset,batch_sizeBATCH_SIZE,shuffleFalse)class_namestrain_dataset.class_namesprint(f分类类别{class_names}总训练样本{len(train_dataset)})# 初始化模型、损失、优化器modelSmallPhoneCNN(num_classesNUM_CLASSES)devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelmodel.to(device)# 多分类交叉熵和之前讲的完全一样不用改criterionnn.CrossEntropyLoss()optimizeroptim.Adam(model.parameters(),lrLR)# 打印模型参数量看看有多小total_paramssum(p.numel()forpinmodel.parameters())print(f模型总参数量{total_params/1000:.1f}K ≈{total_params/1000000:.2f}M)# 输出大概6.5M比MobileNet还小CPU都能轻松跑# --------------------------# 开始训练# --------------------------forepochinrange(EPOCHS):model.train()train_loss0.0forinputs,labelsintrain_loader:inputs,labelsinputs.to(device),labels.to(device)# 前向算损失outputsmodel(inputs)losscriterion(outputs,labels)# 反向更新optimizer.zero_grad()loss.backward()optimizer.step()train_lossloss.item()*inputs.size(0)# 计算平均损失train_losstrain_loss/len(train_dataset)# 测试集评估准确率model.eval()correct0withtorch.no_grad():forinputs,labelsintest_loader:inputs,labelsinputs.to(device),labels.to(device)outputsmodel(inputs)_,predstorch.max(outputs,1)correcttorch.sum(predslabels.data)acccorrect.double()/len(test_dataset)print(fEpoch [{epoch1}/{EPOCHS}] 训练损失{train_loss:.4f}测试准确率{acc:.4f})# --------------------------# 第四步单张图片预测测试# --------------------------defpredict_single_image(img_path):# 1. model.eval()PyTorch的nn.Module自带方法切换模型到推理模式不用自己写model.eval()# 2. Image.open()PIL库自带方法打开图片不用自己写# 3. .convert(RGB)PIL自带把图片转成三通道RGB处理透明图/灰度图自带的imgImage.open(img_path).convert(RGB)# 4. test_transform我们自己定义的数据预处理就是几个torchvision自带变换的组合所有变换都是自带的我们只需要组合不用自己实现# 5. .unsqueeze(0)PyTorch张量自带方法给张量加一个batch维度自带的# 6. .to(device)PyTorch张量自带方法把张量放到GPU/CPU自带的img_tensortest_transform(img).unsqueeze(0).to(device)# 7. torch.no_grad()PyTorch自带上下文管理器关闭梯度计算自带的withtorch.no_grad():outputsmodel(img_tensor)# model本身我们已经定义好了调用也是自带的# 8. F.softmaxtorch.nn.functional自带函数自带的不需要实现probsF.softmax(outputs,dim1)# 9. torch.maxPyTorch自带函数找最大值和索引自带的top_prob,top_idxtorch.max(probs,dim1)# 10. top_idx[0].item()PyTorch张量自带方法把张量里的单个值转成Python数字自带的pred_clsclass_names[top_idx[0].item()]confidencetop_prob[0].item()print(f\n预测结果输入图片 {img_path})print(f分类结果{pred_cls}置信度{confidence:.4f})returnpred_cls,confidence# 替换成你自己的测试图片predict_single_image(./test_apple.jpg)2、代码解析QSmallPhoneCNN卷积层参数是怎么来的A待完善。。。。。。

相关文章:

PyTorch深度学习框架之多分类交叉熵实现图像分类

目录:一、自定义小CNN实现手机分类1、代码示例2、代码解析一、自定义小CNN实现手机分类 1、代码示例 适合苹果/华为/小米 3分类手机识别,你可以直接改类别数适配你的任务: import torch import torch.nn as nn import torch.nn.functional…...

终极指南:如何使用 Deepin Boot Maker 快速制作 Linux 启动盘

终极指南:如何使用 Deepin Boot Maker 快速制作 Linux 启动盘 【免费下载链接】deepin-boot-maker 项目地址: https://gitcode.com/gh_mirrors/de/deepin-boot-maker Deepin Boot Maker 是一款由 Linux Deepin 团队开发的开源启动盘制作工具,它让…...

告别云端依赖:Qwen3-VL-8B本地图文对话工具快速上手教程

告别云端依赖:Qwen3-VL-8B本地图文对话工具快速上手教程 1. 为什么选择本地部署多模态模型? 在当今AI应用蓬勃发展的时代,越来越多的企业和开发者开始关注数据隐私和安全性。云端API虽然方便,但存在以下痛点: 数据安…...

解决QQ音乐加密格式转换难题的开源方案:QMCDecode让音频文件自由管理成为可能

解决QQ音乐加密格式转换难题的开源方案:QMCDecode让音频文件自由管理成为可能 【免费下载链接】QMCDecode QQ音乐QMC格式转换为普通格式(qmcflac转flac,qmc0,qmc3转mp3, mflac,mflac0等转flac),仅支持macOS,可自动识别到QQ音乐下载…...

智能图像识别自动点击:解放双手的安卓自动化神器

智能图像识别自动点击:解放双手的安卓自动化神器 【免费下载链接】Smart-AutoClicker An open-source auto clicker on images for Android 项目地址: https://gitcode.com/gh_mirrors/smar/Smart-AutoClicker 你是否曾遇到这样的困境:游戏中需要…...

5个步骤打造企业级网络净化与全设备防护方案

5个步骤打造企业级网络净化与全设备防护方案 【免费下载链接】AdGuardHomeRules 高达百万级规则!由我原创&整理的 AdGuardHomeRules ADH广告拦截过滤规则!打造全网最强最全规则集 项目地址: https://gitcode.com/gh_mirrors/ad/AdGuardHomeRules …...

webpack优化:Vue配置compression-webpack-plugin实现gzip压缩

需求实现 1.安装依赖 npm i -D compression-webpack-plugin6.1.12.修改vue .config.js配置 const CompressionPlugin require(compression-webpack-plugin) // gzip 相关 const isGZIP process.env.VUE_APP_GZIP ONmodule.exports {configureWebpack(config) {if (isGZ…...

源码之家_最新建站源码_开源项目_成品源码一键部署

在互联网技术飞速发展的今天,网站建设已成为企业、个人展示形象、开展业务的重要窗口。然而,从零开始搭建一个功能完善、界面美观的网站,往往需要投入大量的时间和精力。对于开发者而言,寻找优质、可靠的源码资源,成为…...

腾讯HY-MT1.5翻译模型应用案例:多语言文档翻译实战

腾讯HY-MT1.5翻译模型应用案例:多语言文档翻译实战 1. 模型概述与核心能力 1.1 模型架构与版本 腾讯开源的HY-MT1.5翻译模型包含两个版本: HY-MT1.5-1.8B:18亿参数版本,专为边缘计算和实时翻译场景优化HY-MT1.5-7B&#xff1a…...

CYBER-VISION智能助盲系统部署指南:Dify平台保姆级教学

CYBER-VISION智能助盲系统部署指南:Dify平台保姆级教学 1. 项目背景与核心价值 CYBER-VISION智能助盲系统是一款基于YOLO分割算法的高精度目标识别工具,专为视障人群设计。系统通过实时解构视觉信号,将周围环境转化为可理解的导航信息&…...

SAM 3科研可视化:分割结果嵌入Jupyter Notebook交互式分析

SAM 3科研可视化:分割结果嵌入Jupyter Notebook交互式分析 1. 引言:当科研遇上智能分割 想象一下这样的场景:你正在分析一批生物医学图像,需要从复杂的细胞图像中精确分离出特定的细胞结构。传统方法需要手动标注,耗…...

NEURAL MASK 惊艳效果案例:城市景观照片的4K超分辨率重建

NEURAL MASK 惊艳效果案例:城市景观照片的4K超分辨率重建 每次翻看手机相册,是不是总有些照片让你觉得可惜?明明当时光线、构图都挺好,可放大一看,细节糊成一团,远处的招牌看不清,建筑的纹理也…...

通道分割并行处理改进YOLOv26双路径特征提取与计算效率双重优化

通道分割并行处理改进YOLOv26双路径特征提取与计算效率双重优化 引言 在目标检测领域,特征提取的效率和质量直接影响模型的性能表现。传统的卷积神经网络通常采用串行处理方式,所有通道共享相同的卷积核参数,这种设计虽然简单高效&#xff…...

云边协同 智启未来 | 阿里云 × ZStack 云边一体解决方案正式落地

随着数字化转型的不断深入,企业对于云计算的需求已从"集中上云"逐步演进为"云边协同"。在智慧城市、工业互联网、智慧交通、能源电力等行业场景中,数据的实时处理、低延迟响应以及本地化合规需求日益迫切。单一的中心化云架构已难以…...

像素时装锻造坊实战教程:用Enchantment功能将文字描述转为像素咒语技巧

像素时装锻造坊实战教程:用Enchantment功能将文字描述转为像素咒语技巧 1. 像素时装锻造坊简介 像素时装锻造坊是一款基于Stable Diffusion与Anything-v5的图像生成工具,它将AI图像生成与复古日系RPG游戏界面完美结合。不同于传统AI工具的单调界面&…...

3秒完整保存:颠覆传统的Full Page Screen Capture网页截图新方案

3秒完整保存:颠覆传统的Full Page Screen Capture网页截图新方案 【免费下载链接】full-page-screen-capture-chrome-extension One-click full page screen captures in Google Chrome 项目地址: https://gitcode.com/gh_mirrors/fu/full-page-screen-capture-ch…...

VSCode插件开发:集成Phi-4-mini-reasoning实现智能代码补全与解释

VSCode插件开发:集成Phi-4-mini-reasoning实现智能代码补全与解释 1. 为什么需要更智能的代码补全 传统的代码补全工具如Codex主要基于模式匹配和统计概率,虽然能快速给出建议,但缺乏真正的理解能力。在实际开发中,我们经常遇到…...

计算机组成原理视角:解析GTE-Base-ZH在GPU上的计算与存储

计算机组成原理视角:解析GTE-Base-ZH在GPU上的计算与存储 最近在折腾一些文本嵌入模型,发现大家讨论模型效果的多,但聊它背后在硬件上怎么“跑”起来的少。这就像开车只关心能跑多快,却不看发动机是怎么工作的。今天,…...

隧道液氮速冻机哪家企业值得信赖

隧道液氮速冻机行业分析:成都华能低温设备制造有限公司的卓越表现一、行业痛点分析在隧道液氮速冻机领域,存在着一些技术挑战。首先,速冻速度的提升面临瓶颈。传统的速冻方式难以满足现代食品加工等行业对于快速冻结以保证产品品质的要求。据…...

WarcraftHelper完整指南:3步解决魔兽争霸3在现代电脑上的兼容性问题

WarcraftHelper完整指南:3步解决魔兽争霸3在现代电脑上的兼容性问题 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper 还在为经典游戏《魔兽…...

火绒安全软件6.0 深度评测 | 安静、安全、纯粹的“反PUA型“杀毒软件

🛡️ 火绒安全软件6.0 深度评测 一、 软件简介 定义:Windows终端安全软件,成立于2012年,以“干净”著称。定位:只做安全本质(不做浏览器、输入法、导航),不靠广告赚钱(…...

Wan2.2-I2V-A14B与MATLAB联合仿真:为科学可视化生成示意图

Wan2.2-I2V-A14B与MATLAB联合仿真:为科学可视化生成示意图 1. 科研可视化的新选择 在科研和工程领域,数据可视化一直是成果展示的关键环节。传统方法往往需要研究人员手动绘制示意图,既耗时又难以保证一致性。最近我们尝试了一种新方法&…...

如何为Jellyfin添加豆瓣插件:一键获取中文元数据和评分的完整指南

如何为Jellyfin添加豆瓣插件:一键获取中文元数据和评分的完整指南 【免费下载链接】jellyfin-plugin-douban Douban metadata provider for Jellyfin 项目地址: https://gitcode.com/gh_mirrors/je/jellyfin-plugin-douban 还在为Jellyfin媒体库缺少中文信息…...

Science Bulletin-2026 | 首套中国40年城市土地利用数据集

数据介绍 Fig. 1. Study areas for time-series urban land use mapping in China. Spatial distribution of urban area density (defined as the ratio of built-up area to the total administrative area) across China and six representative subregions: (a) Xinjiang, …...

BetterNCM Installer完整指南:三步打造个性化网易云音乐工作站

BetterNCM Installer完整指南:三步打造个性化网易云音乐工作站 【免费下载链接】BetterNCM-Installer 一键安装 Better 系软件 项目地址: https://gitcode.com/gh_mirrors/be/BetterNCM-Installer 还在为网易云音乐的功能限制感到困扰吗?BetterNC…...

3种方法实现微信聊天记录完整备份:WeChatExporter的高效实用指南

3种方法实现微信聊天记录完整备份:WeChatExporter的高效实用指南 【免费下载链接】WeChatExporter 一个可以快速导出、查看你的微信聊天记录的工具 项目地址: https://gitcode.com/gh_mirrors/wec/WeChatExporter 在数字时代,微信聊天记录承载着我…...

2001-2024年我国农作物分布栅格数据(小麦、玉米、水稻、甘蔗等)

1 数据介绍 中国农作物分布栅格数据集(2001-2024) 数据简介 本数据集由Yangyang Fu团队开发,提供2001-2024年中国28个省份30米分辨率的农作物分布栅格数据,涵盖单季稻、双季稻、冬小麦、玉米等主要作物类型及其轮作模式。 数…...

5分钟解锁中文版Figma:设计师亲手翻译的完整汉化方案

5分钟解锁中文版Figma:设计师亲手翻译的完整汉化方案 【免费下载链接】figmaCN 中文 Figma 插件,设计师人工翻译校验 项目地址: https://gitcode.com/gh_mirrors/fi/figmaCN 还在为Figma的英文界面而烦恼吗?FigmaCN为你带来完美解决方…...

手把手教你用GrsAi的Webhook和轮询,搞定GPT Image 1.5的异步图片生成任务

实战指南:基于GrsAi构建高可靠异步图像生成系统 当你的应用需要处理大量图像生成请求时,同步调用API往往会遇到超时、连接不稳定等问题。我曾在一个电商项目中使用同步调用,结果在促销高峰期系统频繁崩溃——直到改用异步架构才彻底解决问题。…...

Intv_AI_MK11助力后端开发:构建基于大模型的智能API服务

Intv_AI_MK11助力后端开发:构建基于大模型的智能API服务 1. 智能API服务的时代机遇 最近跟几个做后端开发的朋友聊天,发现大家都在讨论同一个问题:如何把大模型能力快速集成到现有系统中。传统做法要么调用第三方API(贵且慢&…...