当前位置: 首页 > news >正文

《深度学习实战》第2集-补充:卷积神经网络(CNN)与图像分类 实战代码解析和改进

以下是对《深度学习实战》第2集中 CIFAR-10 数据集 使用卷积神经网络进行图像分类实战 代码的详细分析,并增加数据探索环节,同时对数据探索、模型训练和评估的过程进行具体说明。所有代码都附上了运行结果配图,方便对比。


《深度学习实战》第2集 补充:数据探索与分析

在深度学习项目中,数据探索(Exploratory Data Analysis, EDA)是至关重要的一步。通过数据探索,我们可以了解数据集的基本特性、分布情况以及潜在问题,从而为后续的模型设计和优化提供指导。

1. 数据探索的目标

  • 了解 CIFAR-10 数据集的类别分布。
  • 可视化样本图像,观察其特征。
  • 分析数据预处理的效果。

2. 数据探索实现

2.1 类别分布分析

CIFAR-10 数据集包含 10 个类别,每个类别的样本数量应均匀分布。我们可以通过以下代码统计类别分布:

import matplotlib.pyplot as plt# 统计类别分布
train_labels = [label for _, label in train_dataset]
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label_counts = {class_names[i]: train_labels.count(i) for i in range(10)}# 可视化类别分布
plt.bar(label_counts.keys(), label_counts.values())
plt.title("Class Distribution in CIFAR-10")
plt.xlabel("Class")
plt.ylabel("Number of Samples")
plt.xticks(rotation=45)
plt.show()

代码运行结果输出:

在这里插入图片描述

结果分析

  • 如果类别分布均匀,说明数据集没有类别不平衡问题。
  • 在 CIFAR-10 中,每个类别有 5,000 张训练图像,分布均衡。
2.2 样本可视化

为了直观了解数据集中的图像特征,我们可以随机抽取一些样本并可视化:

import numpy as np# 可视化样本图像
def imshow(img):img = img / 2 + 0.5  # 反归一化npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.axis('off')plt.show()# 获取一批数据
dataiter = iter(train_loader)
images, labels = next(dataiter)# 显示图像
imshow(torchvision.utils.make_grid(images[:16]))  # 显示前 16 张图像
print("Labels:", [class_names[label] for label in labels[:16]])

代码运行结果输出:

在这里插入图片描述

Labels: ['bird', 'ship', 'automobile', 'ship', 'cat', 'truck', 'airplane', 'bird', 'airplane', 'frog', 'ship', 'bird', 'automobile', 'bird', 'automobile', 'truck']

结果分析

  • 图像大小为 32x32,分辨率较低,但足以捕捉基本特征。
  • 不同类别的图像具有明显的视觉差异(如飞机与汽车、猫与狗等),这有助于模型学习区分不同类别。
2.3 数据预处理效果

数据预处理包括调整大小、归一化等操作。我们可以通过打印预处理后的图像张量来验证其效果:

print("Preprocessed Image Shape:", images.shape)  # 输出形状
print("Preprocessed Image Values:", images[0].min().item(), images[0].max().item())  # 输出归一化范围

代码运行输出结果:

Preprocessed Image Shape: torch.Size([64, 3, 224, 224])
Preprocessed Image Values: -0.929411768913269 1.0

结果分析

  • 预处理后图像被调整为 224x224(ResNet 输入要求),并归一化到 [-1, 1] 范围。
  • 这些操作确保了输入数据的一致性和模型的稳定性。

原代码分析与改进

1. 数据加载与预处理

代码中使用 torchvision.transforms 对数据进行了标准化和尺寸调整。以下是关键步骤的解释:

  • Resize(224):将图像从原始的 32x32 调整为 ResNet 的输入尺寸 224x224。
  • Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):将像素值归一化到 [-1, 1] 范围,以加速收敛。

改进建议

  • 添加数据增强(Data Augmentation),如随机裁剪、水平翻转等,以提高模型的泛化能力:
    transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.Resize(224),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    

改进后代码替换原代码,运行后输出:

Epoch 1, Loss: 0.6051
Epoch 2, Loss: 0.3872
Epoch 3, Loss: 0.3126
Epoch 4, Loss: 0.2649
Epoch 5, Loss: 0.2316
Test Accuracy: 0.9022
可以看到损失系数不同,但精确率最终结果差不多。

2. 模型训练

代码中使用了预训练的 ResNet-18 模型,并修改了最后一层以适应 CIFAR-10 的 10 个类别。以下是训练过程的关键点:

2.1 模型结构
  • ResNet-18 是一个轻量级的 CNN 架构,包含 18 层卷积网络。
  • 修改全连接层(model.fc)以输出 10 个类别的概率。
2.2 训练过程
  • 使用 Adam 优化器,学习率为 0.001。
  • 损失函数为交叉熵损失(nn.CrossEntropyLoss),适用于多分类任务。
  • 每个 epoch 后打印平均损失,便于监控训练进度。

改进建议

  • 增加学习率调度器(Learning Rate Scheduler),例如余弦退火或 StepLR,以动态调整学习率。
  • 保存最佳模型权重,避免过拟合。

3. 模型评估

代码中通过测试集计算了模型的准确率。以下是评估过程的关键点:

3.1 测试过程
  • 将模型切换为评估模式(model.eval()),关闭 Dropout 和 BatchNorm 的随机性。
  • 使用 torch.no_grad() 禁用梯度计算,减少内存消耗。
3.2 结果分析

假设测试准确率为 75%,说明模型在 CIFAR-10 上表现良好,但仍有一定的提升空间。

改进建议

  • 计算混淆矩阵(Confusion Matrix),分析模型在不同类别上的表现:
    from sklearn.metrics import confusion_matrix
    import seaborn as snsall_preds, all_labels = [], []
    with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, preds = torch.max(outputs, 1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())cm = confusion_matrix(all_labels, all_preds)
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()
    

代码运行输出结果:
在这里插入图片描述

  • 分析上图混淆矩阵可以发现模型在某些类别(如“猫”与“狗”)上容易混淆,Cat 和 Dog 的矩阵交汇数值相对偏高,从而指导进一步优化。

总结

通过增加数据探索环节,我们深入了解了 CIFAR-10 数据集的特性,并验证了数据预处理的有效性。在模型训练和评估过程中,我们分析了代码的实现细节,并提出了改进建议,包括数据增强、学习率调度器和混淆矩阵分析。这些改进可以帮助模型更好地适应数据集,并提升性能。

希望这些内容能为你提供更全面的理解!如果你有任何问题或想法,欢迎在评论区留言讨论。

相关文章:

《深度学习实战》第2集-补充:卷积神经网络(CNN)与图像分类 实战代码解析和改进

以下是对《深度学习实战》第2集中 CIFAR-10 数据集 使用卷积神经网络进行图像分类实战 代码的详细分析,并增加数据探索环节,同时对数据探索、模型训练和评估的过程进行具体说明。所有代码都附上了运行结果配图,方便对比。 《深度学习实战》第…...

nodejs:express + js-mdict 作为后端,vue 3 + vite 作为前端,在线查询英汉词典

向 doubao.com/chat/ 提问: node.js js-mdict 作为后端,vue 3 vite 作为前端,编写在线查询英汉词典 后端部分(express js-mdict ) 1. 项目结构 首先,创建一个项目目录,结构如下&#xff1…...

《深度剖析Linux 系统 Shell 核心用法与原理_666》

1. 管道符的用法 查找当前目录下所有txt文件并统计行数 # 使用管道符将ls命令的结果传递给wc命令进行行数统计 ls *.txt | wc -l 在/etc目录下查找包含"network"的文件并统计数量 # 使用find命令查找文件,并通过grep查找包含特定字符串的文件&#xf…...

索提诺比率(Sortino Ratio):更精准的风险调整收益指标(中英双语)

索提诺比率(Sortino Ratio):更精准的风险调整收益指标 📉📊 📌 什么是索提诺比率? 在投资分析中,我们通常使用 夏普比率(Sharpe Ratio) 来衡量风险调整后的…...

minio作为K8S后端存储

docker部署minio mkdir -p /minio/datadocker run -d \-p 9000:9000 \-p 9001:9001 \--name minio \-v /minio/data:/data \-e "MINIO_ROOT_USERjbk" \-e "MINIO_ROOT_PASSWORDjbjbjb123" \quay.io/minio/minio server /data --console-address ":90…...

一周学会Flask3 Python Web开发-Jinja2模板访问对象

锋哥原创的Flask3 Python Web开发 Flask3视频教程: 2025版 Flask3 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili 如果渲染模板传的是对象,如果如何来访问呢? 我们看下下面示例: 定义一个Student类 cla…...

RAGS评测后的数据 如何利用influxdb和grafan 进行数据汇总查看

RAGS(通常指相关性、准确性、语法、流畅性)评测后的数据能借助 InfluxDB 存储,再利用 Grafana 进行可视化展示,实现从四个维度查看数据,并详细呈现每个问题对应的这四个指标情况。以下是详细步骤: 1. 环境准备 InfluxDB 安装与配置 依据自身操作系统,从 InfluxDB 官网下…...

第25周JavaSpringboot实战-电商项目 4.商品分类管理

商品分类模块开发笔记 模块功能概述 实现分类数据的 增删改查 功能核心难点: 分类的父子级目录结构递归实现多级分类查找列表展示顺序控制(从父级向子级递归) 接口说明 后台接口 1. 添加分类 请求地址: /admin/category/add 请求方法: …...

C语言--正序、逆序输出为奇数的位。

题目&#xff1a; 采用正序和逆序分别输出为奇数的位。例如输入12345&#xff0c;正序输出135&#xff0c;逆序输出531 代码&#xff1a; #include <stdio.h>void printOddDigits(int num) {int res 0;int divider 10;while (num / divider > 10) {divider * 10;…...

C#快速调用DeepSeek接口,winform接入DeepSeek查询资料 C#零门槛接入DeepSeek C#接入DeepSeek源代码下载

下载地址<------完整源码 在数字化转型加速的背景下&#xff0c;企业应用系统对智能服务的需求日益增长。DeepSeek作为先进的人工智能服务平台&#xff0c;其自然语言处理、图像识别等核心能力可显著提升业务系统的智能化水平。传统开发模式下&#xff0c;C#开发者需要耗费大…...

H13-821 V3.0 HCIP 华为云服务架构题库

华为云上哪个服务能够用于收集来自主机和云服务的日志数据&#xff0c;并通过海量日志数据的分析与处理帮助开发或运维人员进行问题定位和分析&#xff1f; A&#xff1a;云监控服务 B&#xff1a;云日志服务 C&#xff1a;云审计服务 D&#xff1a;对象存储服务 答案&#xff…...

Linux主机用户登陆安全配置

Linux主机用户登陆安全配置 在Linux主机上进行用户登录安全配置是一个重要的安全措施&#xff0c;可以防止未经授权的访问。以下是如何创建用户hbu、赋予其sudo权限&#xff0c;以及禁止root用户SSH登录&#xff0c;以及通过ssh key管理主机用户登陆。 创建用户hbu 使用具有…...

提升数据洞察力:五款报表软件助力企业智能决策

概述 随着数据量的激增和企业对决策支持需求的提升&#xff0c;报表软件已经成为现代企业管理中不可或缺的工具。这些软件能够帮助企业高效处理数据、生成报告&#xff0c;并将数据可视化&#xff0c;从而推动更智能的决策过程。 1. 山海鲸报表 概述&#xff1a; 山海鲸报表…...

Linux | man 手册使用详解

注&#xff1a;本文为 “Linux man 手册” 相关文章合辑。 略作重排。 man 手册常用命令 1. 查看和搜索手册页 查看特定软件包的手册页&#xff0c;并使用 grep 命令过滤出包含特定关键字的行&#xff1a; man <package> | grep <keyword>在整个系统的手册页中…...

安全见闻4

今天学了Windows操作系统和驱动程序的相关知识 Windows注册表 注册表是windows系统中具有层次结构的核心数据库 储存的数据对windows 和Windows上运行的应用程序和服务至关重要。注册表时帮助windows控制硬件、软件、用户环境和windows界面的一套数据文件。 打开注册表编辑器…...

项目实战--网页五子棋(匹配模块)(4)

上期我们完成了游戏大厅的前端部分内容&#xff0c;今天我们实现后端部分内容 1. 维护在线用户 在用户登录成功后&#xff0c;我们可以维护好用户的websocket会话&#xff0c;把用户表示为在线状态&#xff0c;方便获取到用户的websocket会话 package org.ting.j20250110_g…...

P8716 [蓝桥杯 2020 省 AB2] 回文日期

1 题目说明 2 题目分析 暴力不会超时&#xff0c;O(n)的时间复杂度&#xff0c; < 1 0 8 <10^8 <108。分析见代码&#xff1a; #include<iostream> #include<string> using namespace std;int m[13]{0,31,28,31,30,31,30,31,31,30,31,30,31};// 判断日期…...

如何在视频中提取关键帧?

在视频处理中&#xff0c;提取关键帧是一项常见的任务。下面将介绍如何基于FFmpeg和Python&#xff0c;结合OpenCV库来实现从视频中提取关键帧的功能。 实现思路 使用FFmpeg获取视频的关键帧时间戳&#xff1a;FFmpeg是一个强大的视频处理工具&#xff0c;可以通过命令行获取…...

為什麼使用不限量動態住宅IP採集數據?

在瞭解“不限量動態住宅IP數據採集”之前&#xff0c;我們需要先搞清楚什麼是“動態住宅IP”。簡單來說&#xff0c;動態IP是一種會定期變化的IP地址&#xff0c;通常由互聯網服務提供商&#xff08;ISP&#xff09;分配給家庭用戶。與固定IP&#xff08;靜態IP&#xff09;不同…...

Go语言中使用viper绑定结构体和yaml文件信息时,标签的使用

在Go中使用Viper将YAML配置绑定到结构体时&#xff0c;主要依赖 mapstructure 标签&#xff08;而非 json 或 yaml 标签&#xff09;实现字段名映射。 --- ### 1. **基础绑定方法** 使用 viper.Unmarshal(&config) 或 viper.UnmarshalKey("key", &subConfi…...

应用升级/灾备测试时使用guarantee 闪回点迅速回退

1.场景 应用要升级,当升级失败时,数据库回退到升级前. 要测试系统,测试完成后,数据库要回退到测试前。 相对于RMAN恢复需要很长时间&#xff0c; 数据库闪回只需要几分钟。 2.技术实现 数据库设置 2个db_recovery参数 创建guarantee闪回点&#xff0c;不需要开启数据库闪回。…...

VB.net复制Ntag213卡写入UID

本示例使用的发卡器&#xff1a;https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

DAY 47

三、通道注意力 3.1 通道注意力的定义 # 新增&#xff1a;通道注意力模块&#xff08;SE模块&#xff09; class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...

测试markdown--肇兴

day1&#xff1a; 1、去程&#xff1a;7:04 --11:32高铁 高铁右转上售票大厅2楼&#xff0c;穿过候车厅下一楼&#xff0c;上大巴车 &#xffe5;10/人 **2、到达&#xff1a;**12点多到达寨子&#xff0c;买门票&#xff0c;美团/抖音&#xff1a;&#xffe5;78人 3、中饭&a…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

现代密码学 | 椭圆曲线密码学—附py代码

Elliptic Curve Cryptography 椭圆曲线密码学&#xff08;ECC&#xff09;是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础&#xff0c;例如椭圆曲线数字签…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...