当前位置: 首页 > news >正文

传知代码-从零开始构建你的第一个神经网络

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

从零开始构建你的第一个神经网络

在本教程中,我们将使用PyTorch框架从零开始构建一个简单的卷积神经网络(CNN),用于图片二分类任务。CNN 是一种深度学习模型,特别适用于图像数据。本文将详细介绍如何一步步搭建一个包含4-5层的CNN,并训练它来分类图像。

1.准备工作

在开始之前,需要一些必要的库,本教程中使用的库的版本如下:

packageversion
python3.8.0
torch1.13.0
scikit-learn1.1.3
numpy1.20.1
matplotlib3.3.4

本软件包均是在anaconda虚拟环境中安装的。如果还没该虚拟环境还需要创建一下才可以。准备工作完成后就可以开始正式的神经网络搭建了。

2.数据集准备

在本教程中我们使用1元100元人民币两个类别。其数据集存放格式如下图所示:
在这里插入图片描述

一共分为三个部分train、test和val,即训练验证和测试。“1”和“100”文件夹下面就是数据集图片。下面就来对数据集进行处理:

  • 首先导入必要的包和进行数据的预处理
#导入包
import torch
import torchvision
import torchvision.transforms as transforms
#数据预处理
transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
  • 加载数据集
train_data_path = "02-01-data-RMB_data/train" # 数据集存放的相对路径
trainset = datasets.ImageFolder(root=train_data_path, transform=transform)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True)val_data_path = "02-01-data-RMB_data/val"
valset = datasets.ImageFolder(root=val_data_path, transform=transform)
valloader = DataLoader(valset, batch_size=4, shuffle=False)classes = trainset.classes # 获取类别名:这里为“1”和“100”

我们可以使用torchvision.datasets.ImageFolder来加载这些数据。ImageFolder会根据目录名自动将图像归类。

3.定义卷积神经网络结构并实例化模型

接下来,我们将定义一个简单的卷积神经网络。该网络将包含两个卷积层(每层后接池化层)和两个全连接层

# 定义卷积神经网络
class SimpleCNN(nn.Module):def __init__(self, num_classes):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1) #彩色图片输入通道为3:RGB;输出为16self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 3, padding=1)self.fc1 = nn.Linear(32 * 16 * 16, 120)self.fc2 = nn.Linear(120, num_classes)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 16 * 16) #展平操作成为1维x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
num_classes = len(classes)
net = SimpleCNN(num_classes).to(device)

上述网络结构中nn.Conv2d(3, 16, 3, padding=1) 表明彩色图片输入通道为3:RGB;输出为16,这里16可以自己定义,但是要与下一卷积层的输入通道数相同。16后面的3为卷积核大小即为3*3卷积。

4.训练和验证模型

4.1.配置训练参数

在开始训练之前,我们需要定义损失函数和优化器。这里我们将使用交叉熵损失函数随机梯度下降SGD)优化器。
这里不同的损失函数和不同的优化器也会影响模型最终的效果。比如可以使用AdamAdamW等优化器。

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4.2 模型训练与验证

# 存储训练损失和验证损失
train_losses = []
val_losses = []# 训练模型并记录损失和准确率
for epoch in range(10):net.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data# 将输入和标签移动到GPUinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()train_loss = running_loss / len(trainloader)train_accuracy = 100 * correct / totaltrain_losses.append(train_loss)# 验证模型net.eval()val_loss = 0.0correct = 0total = 0with torch.no_grad():for data in valloader:images, labels = data# 将输入和标签移动到GPUinputs, labels = inputs.to(device), labels.to(device)outputs = net(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_loss = val_loss / len(valloader)val_accuracy = 100 * correct / totalval_losses.append(val_loss)print(f'Epoch [{epoch + 1}/10], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')print('Finished Training')
#保存模型在本地
model_save_path = "simple_cnn.pth"
torch.save(net.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

在这里我们将训练模型10个epoch,并在每个epoch结束后输出损失准确率,可以看到每一个epoch上的结果。
训练过程如下图:
在这里插入图片描述

5.性能评估与模型测试

5.1 性能评估

绘制loss曲线

# 绘制损失曲线
plt.figure()
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()

查看结果
在这里插入图片描述

绘制混淆矩阵

# 混淆矩阵
net.eval()
all_preds = []
all_labels = []
with torch.no_grad():for data in valloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())# 生成混淆矩阵
cm = confusion_matrix(all_labels, all_preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)# 绘制混淆矩阵
plt.figure()
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.show()

在这里插入图片描述

在val文件夹中1100的数量都是19,所以做到了100%的准确率。

5.2 模型测试

在第四步中我们将训练好的模型进行了保存,那么这里我们将模型加载出来对test数据集进行测试,看看模型效果如何。这里我们需要加载训练好的模型进行测试。

# 加载模型权重
model_path = "simple_cnn.pth"
net.load_state_dict(torch.load(model_path))# 将模型移动到设备上
net.to(device)# 设置模型为评估模式
net.eval()# 加载测试集
test_data_path = "02-01-data-RMB_data/test"
testset = datasets.ImageFolder(root=test_data_path, transform=transform)
testloader = DataLoader(testset, batch_size=4, shuffle=False)# 预测并输出结果
all_preds = []
all_labels = []
with torch.no_grad():for data in testloader:images, labels = data# 将数据移动到GPUimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())# 你可以在这里处理预测结果,比如输出预测值或计算准确率
correct = sum(1 for x, y in zip(all_preds, all_labels) if x == y)
total = len(all_labels)
accuracy = correct / total * 100print(f"Test Accuracy: {accuracy:.2f}%")

完整代码和数据集在附件中,可以自行查看

6.总结

在本教程中,我们从零开始构建了一个简单的卷积神经网络,并用它来完成图片的二分类任务。我们介绍了如何定义网络结构,如何训练和评估模型,并讨论了进一步优化模型的方法。

接下来,你可以尝试扩展网络的深度或宽度,使用更大的数据集,或者使用迁移学习等技术来提升模型的性能。祝你学习愉快!

源码下载

相关文章:

传知代码-从零开始构建你的第一个神经网络

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 从零开始构建你的第一个神经网络 在本教程中,我们将使用PyTorch框架从零开始构建一个简单的卷积神经网络(CNN),用于图片二分类任务。CNN 是一种深度学习模型&#…...

大厂面试真题:SpringBoot的核心注解

其实理解一个注解就行了@SpringBootApplication,我们的启动类其实就加了这一个 但是这么答也不行,因为面试官要的答案肯定不止这一个 我们打开SpringBootApplication的源码,会发现上面加了一堆的注解 相对而言比较重要是下面三个…...

Java设计模式—面向对象设计原则(五) ----->迪米特法则(DP) (完整详解,附有代码+案例)

文章目录 3.5 迪米特法则(DP)3.5.1 概述3.5.2 案例 3.5 迪米特法则(DP) 迪米特法则:Demeter Principle,简称DP 3.5.1 概述 只和你的直接朋友交谈,不跟“陌生人”说话(Talk only to your immediate friends and not to stranger…...

docker多阶段镜像制作,比如nginx镜像,编译+制作

镜像制作, nginx的源码包 把nginx源码拷贝到容器内 编译要用到gcc make , 以及扩展工具 pcre openssl # "pcre" perl compatibal regulaer expression 刚开始,可以两个终端, 一个手工操作(编译安装、拷贝、环境变量等)&#xf…...

大语言模型量化方法GPTQ、GGUF、AWQ详细原理

大语言模型量化的目的是减少模型的计算资源需求和存储占用,同时尽量保持模型的性能。以下是几种常见的量化方法的原理; 1. GPTQ (Gradient-based Post-training Quantization) GPTQ 是一种基于梯度的后训练量化方法,主要目的是在减少浮点计…...

《 C++ 修炼全景指南:十 》自平衡的艺术:深入了解 AVL 树的核心原理与实现

摘要 本文深入探讨了 AVL 树(自平衡二叉搜索树)的概念、特点以及实现细节。我们首先介绍了 AVL 树的基本原理,并详细分析了其四种旋转操作,包括左旋、右旋、左右双旋和右左双旋,阐述了它们在保持树平衡中的重要作用。…...

SAP 特别总账标识[SGL]

1. 特别总账标识(SGL)概述 1.1 定义与目的 特别总账标识(Special General Ledger, SGL)在SAP系统中用于区分客户或供应商的不同业务类型,以便将特定的业务交易记录到非标准的总账科目中。 定义:SGL是一个用于标记特殊业务类型的…...

认知杂谈77《简单:通往高手的技巧》

内容摘要:          在信息爆炸、关系复杂的时代,简单是复杂背后的真谛。简单如“112”,是智慧的朴素呈现。简单有强大力量,像清泉般纯净,如“我爱你”简单却有力,基础财务知识也体现其在理财中的作…...

《SmartX ELF 虚拟化核心功能集》发布,详解 80+ 功能特性和 6 例金融实践

《SmartX ELF 虚拟化核心功能集》电子书现已发布!本书详细介绍了 SmartX ELF 虚拟化及云平台核心功能,包含虚机服务、容器服务、网络服务、存储服务、运维管理、工具服务、数据保护等各个方面。 即刻下载电子书,了解如何利用基于 SmartX ELF …...

9月23日

思维导图 作业 统计家目录下.c文件的个数 #!/bin/bashnum0for file in ~/*.c; doif [ -f "$file" ]; then((num))fi doneecho "家目录下.c文件的个数: $num"...

如何使用Jinja定义dbt宏

dbt宏在dbt框架内的工作方式与传统编程中的函数类似。它允许用户将特定的、通常是重复的SQL逻辑封装到可调用的命名单元中,就像在其他编程语言中用函数来避免重复代码一样;dbt宏定义特定业务的SQL逻辑,然后在dbt项目中需要的地方调用该宏函数…...

深入理解 JavaScript 三大作用域:全局作用域、函数作用域、块级作用域

一. 作用域 对于多数编程语言,最基本的功能就是能够存储变量当中的值、并且允许我们对这个变量的值进行访问和修改。那么有了变量之后,应该把它放在哪里、程序如何找到它们?是否需要提前约定好一套存储变量、访问变量的规则?答案…...

【门牌制作 / A】

题目 代码 #include <bits/stdc.h> using namespace std; int main() {int cnt 0;for (int i 1; i < 2020; i){string s;s to_string(i);cnt count(s.begin(), s.end(), 2);}cout << cnt; }...

Git+Jenkins 基本使用(Basic Usage of Git+Jenkins)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…...

智谱清言:智能语音交互的引领者,解锁高效沟通新体验

哪个编程工具让你的工作效率翻倍&#xff1f; 在日益繁忙的工作环境中&#xff0c;选择合适的编程工具已成为提升开发者工作效率的关键。不同的工具能够帮助我们简化代码编写、自动化任务、提升调试速度&#xff0c;甚至让团队协作更加顺畅。那么&#xff0c;哪款编程工具让你…...

前端组件库

vant2现在的地址 Vant 2 - Mobile UI Components built on Vue...

后端常用的mybatis-plus方法以及配合querywapper使用

目录 一、插入数据 save方法 二、删除操作 removeById方法 三、更新操作 updateById方法 四、查询操作 selectById方法 五、条件构造器QueryWrapper的更多用法 1.比较操作符 2.逻辑操作符 3.模糊查询 4.空值判断 一、插入数据 save方法 save(T entity):向数据库中插入…...

【设计模式】万字详解:深入掌握五大基础行为模式

作者&#xff1a;后端小肥肠 &#x1f347; 我写过的文章中的相关代码放到了gitee&#xff0c;地址&#xff1a;xfc-fdw-cloud: 公共解决方案 &#x1f34a; 有疑问可私信或评论区联系我。 &#x1f951; 创作不易未经允许严禁转载。 姊妹篇&#xff1a; 【设计模式】&#xf…...

C++ 9.19

练习&#xff1a;要求在堆区申请5个double类型的空间&#xff0c;用于存储5名学生的成绩。请自行封装函数完成 1> 空间的申请 2> 学生成绩的录入 3> 学生成绩的输出 4> 学生成绩进行降序排序 5> 释放申请的空间 主程序中用于测试上述函数 #include<ios…...

[Unity Demo]从零开始制作空洞骑士Hollow Knight第五集:再制作更多的敌人

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、制作敌人另个爬虫Crawler 1.公式化导入制作另个爬虫Crawler素材2.制作另个爬虫Crawler的Crawler.cs状态机3.制作敌人另个爬虫Crawler的playmaker状态机二、…...

喜马拉雅音频下载器:三分钟学会下载付费专辑的完整方案

喜马拉雅音频下载器&#xff1a;三分钟学会下载付费专辑的完整方案 【免费下载链接】xmly-downloader-qt5 喜马拉雅FM专辑下载器. 支持VIP与付费专辑. 使用GoQt5编写(Not Qt Binding). 项目地址: https://gitcode.com/gh_mirrors/xm/xmly-downloader-qt5 你是否遇到过这…...

告别丢帧!用CANoe 12+和VN5610A搞定CSM ECAT模块高速采集(附100kHz采样率避坑要点)

突破100kHz采样率瓶颈&#xff1a;CANoe 12与VN5610A高速数据采集全攻略 在汽车电子测试领域&#xff0c;高速数据采集一直是工程师面临的重大挑战。当采样率超过100kHz时&#xff0c;传统配置方式往往会出现数据丢帧、时间戳错乱等问题。本文将深入解析CANoe 12与VN5610A硬件组…...

惠普OMEN笔记本终极性能控制:OmenSuperHub 5分钟完全指南

惠普OMEN笔记本终极性能控制&#xff1a;OmenSuperHub 5分钟完全指南 【免费下载链接】OmenSuperHub 使用 WMI BIOS控制性能和风扇速度&#xff0c;自动解除DB功耗限制。 项目地址: https://gitcode.com/gh_mirrors/om/OmenSuperHub 想要彻底释放惠普OMEN游戏本的性能潜…...

Agent解析复杂PDF表格时效果极差,如何自动化处理?

斯坦福大学教授、AI领域顶尖学者吴恩达近日明确表示&#xff1a;不会有AI就业末日。在他看来&#xff0c;AI会影响岗位、改变技能要求、也会替代一部分任务&#xff0c;但将其描绘成大规模失业灾难&#xff0c;“是在制造不必要的恐惧&#xff0c;也是不负责任的”。与其担忧被…...

智能门锁语音方案:WTVXXX-32N芯片一体化设计与低功耗实现

1. 项目概述&#xff1a;当智能门锁遇上“会说话”的芯片最近在做一个智能门锁的后板方案整合项目&#xff0c;客户提了个挺有意思的需求&#xff1a;他们希望门锁在完成每一次开锁、上锁、或者遇到异常情况时&#xff0c;不仅能通过手机APP推送通知&#xff0c;还能在现场给用…...

Kilim Actor模型实践:构建高并发消息传递系统的终极指南 [特殊字符]

Kilim Actor模型实践&#xff1a;构建高并发消息传递系统的终极指南 &#x1f680; 【免费下载链接】kilim Lightweight threads for Java, with message passing, nio, http and scheduling support. 项目地址: https://gitcode.com/gh_mirrors/ki/kilim Kilim是一个强…...

ncmdump终极指南:5分钟解锁网易云音乐NCM加密文件

ncmdump终极指南&#xff1a;5分钟解锁网易云音乐NCM加密文件 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾在网易云音乐下载了心爱的歌曲&#xff0c;却发现只能在特定客户端播放&#xff1f;当你想在车载音响、智能音箱…...

YOLOv8实时目标检测与自适应控制技术在游戏辅助系统中的应用研究

YOLOv8实时目标检测与自适应控制技术在游戏辅助系统中的应用研究 【免费下载链接】RookieAI_yolov8 基于yolov8实现的AI自瞄项目 AI self-aiming project based on yolov8 项目地址: https://gitcode.com/gh_mirrors/ro/RookieAI_yolov8 技术挑战剖析&#xff1a;实时游…...

用 Articraft 制作可动 3D 资产

如果你想做一个“能开合的台灯、能转动的风扇、能拉开的抽屉柜”&#xff0c;传统 3D 工作流通常意味着&#xff1a;建模、拆分部件、定义关节、反复调试、再导出到下游系统。 问题是&#xff0c;这类“可动对象”并不只是静态几何体&#xff0c;它们还需要语义化部件、合理结构…...

个人开发者如何通过TaoToken以更低成本体验多种主流大模型

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 个人开发者如何通过TaoToken以更低成本体验多种主流大模型 对于预算有限的个人开发者和学生而言&#xff0c;直接接入和使用多个主…...