神经网络分类和回归任务实战
学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程
直接上案例,先来跑,遇到什么解决什么
数据集Minist 数据集
做简单的任务 Minist 分类任务
总体代码(可以跑通)
from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
# print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:outputs=model(xb)_,predicted=torch.max(outputs.data,1)total+=yb.size(0)corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))
1.首先我们从最终实现的fit 函数开始看,
在fit h函数之前有一个get_model 函数 得到model和优化器
model, opt = get_model()

得到模型的优化器以后
需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数
fit(20, model, loss_func, opt, train_dl, valid_dl)
fit 函数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量
yb是64个图片对应的数字值
loss_batch(model, loss_func, xb, yb, opt)

我们看一下loss_batch 函数
loss 反向传播——更新优化器——优化器梯度归0
loss 是一个带有梯度的tensor .item()返回的是loss 的值 len(xb )是为了求精度的时候算
输入是模型损失函数xb ,yb 和优化器
loss_func = F.cross_entropy 损失函数是交叉熵损失函数
将xb 经过model 得到输出后 和xb 求损失函数 model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10

再回到fit 函数的验证部分model.evl()
先来一个
with torch.no_grad()
不去计算梯度
zip(*的意思是解压缩 分别得到losses 和nums)
loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失
相关文章:
神经网络分类和回归任务实战
学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程 直接上案例,先来跑,遇到什么解决什么 数据集Minist 数据集 做简单的任务 Minist 分类任务 总体代码(可以跑通) from pathlib import …...
【数据结构】考研真题攻克与重点知识点剖析 - 第 4 篇:串
前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…...
深入浅出 -- 系统架构之分布式多形态的存储型集群
一、多形态的存储型集群 在上阶段,我们简单聊了下集群的基本知识,以及快速过了一下逻辑处理型集群的内容,下面重点来看看存储型集群,毕竟这块才是重头戏,集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…...
STL —— list
博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解 list模拟实现的相关内容 1. list简介 列表(list)是C标准模板库(STL)中的一个容器,它是一个双向链表数据结构,…...
申请SSL证书
有很多方法可以确保您的网站安全。添加SSL证书可针对恶意攻击提供额外且关键的保护层。 即使网站不接受交易,您仍然需要保护用户的登录详细信息、地址和其他个人信息。 没有SSL证书的网站使用HTTP(一种基于文本的协议),这意味着…...
深入浅出 -- 系统架构之负载均衡Nginx环境搭建
引入负载均衡技术可带来的收益: 系统的高可用:当某个节点宕机后可以迅速将流量转移至其他节点。系统的高性能:多台服务器共同对外提供服务,为整个系统提供了更高规模的吞吐。系统的拓展性:当业务再次出现增长或萎靡时…...
notepad++绿色版添加右键菜单
解压路径 D:\Green\notepad_v8.0_x64_绿色版 添加右键菜单.reg 新建nodepad添加右键菜单.reg文件 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Edit with &Notepad" "Icon""D:\\Green\\notepad_v8.0_x64_绿色版…...
7 个 iMessage 恢复应用程序/软件可轻松恢复文本
由于误操作、iOS 升级中断、越狱失败、设备损坏等原因,您可能会丢失 iPhone/iPad 上的 iMessages。意外删除很大程度上增加了这种可能性。更糟糕的是,这种情况经常发生在 iDevice 缺乏备份的情况下。 (iPhone消息消失还占用空间?&…...
DockerFile启动jar程序
1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件,并使用文本编辑器打开它。Dockerfile的内容如下: # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…...
基于R、Python的Copula变量相关性分析及AI大模型应用
在工程、水文和金融等各学科的研究中,总是会遇到很多变量,研究这些相互纠缠的变量间的相关关系是各学科的研究的重点。虽然皮尔逊相关、秩相关等相关系数提供了变量间相关关系的粗略结果,但这些系数都存在着无法克服的困难。例如,…...
鸿蒙组件学习_Tabs组件
说明 该组件从API Version 7 开始支持。 子组件 仅可包含子组件TabContent 参数 barPosition 设置Tabs的页签位置,默认值: BarPosition.StartStart vertical属性方法设置为true时,页签位于容器左侧;vertical属性方法设置为false时,页签位于容器顶部。End vertic…...
【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略
【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…...
thinkphp6入门(21)-- 如何删除图片、文件
假设文件的位置在 /*** 删除文件* $file_name avatar/20240208/d71d108bc1086b498df5191f9f925db3.jpg*/ function deleteFile($file_name) {// 要删除的文件路径$file app()->getRootPath() . public/uploads/ . $file_name; $result [];if (is_file($file)) {if (unlin…...
虚拟内存知识详解
虚拟内存 单片机的 CPU 是直接操作内存的「物理地址」 在这种情况下,要想在内存中同时运行两个程序是不可能的 操作系统是如何解决这个问题呢? 关键的问题是这两个程序都引用了绝对物理地址,而这正是我们最需要避免的。 可以把进程所使用的…...
数据结构初阶:顺序表和链表
线性表 线性表 ( linear list ) 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串 ... 线性表在逻辑上是线性结构,也就说是连续的一条直线。但是在物理结构上并不一定是连续的, 线性…...
在flutter中添加video_player【视频播放插件】
添加插件依赖 dependencies:video_player: ^2.8.3插件的用途 在Flutter框架中,video_player 插件是一个专门用于播放视频的插件。它允许开发者在Flutter应用中嵌入视频播放器,并提供了一系列功能来控制和定制视频播放体验。这个插件对于需要在应用中展…...
golang微服务框架特性分析及选型
目录 一、微服务框架特性(10个)包括:Istio、go-zero、go-kit、go-kratos、go-micro、rpcx、kitex、goa、jupiter、dubbo-go、tarsgo 1、特性及使用场景2、比较 二、web框架特性(7个)包括:gin、fiber、beego…...
苹果cmsV10 MXProV4.5自适应PC手机影视站主题模板苹果cms模板mxone pro
演示站:http://a.88531.cn:8016 MXPro 模板主题(又名:mxonepro)是一款基于苹果 cms程序的一款全新的简洁好看 UI 的影视站模板类似于西瓜视频,不过同对比 MxoneV10 魔改模板来说功能没有那么多,也没有那么大气,但是比较且可视化功…...
GPU的了解
3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…...
鸿蒙实战开发-如何使用Stage模型卡片
介绍 本示例展示了Stage模型卡片提供方的创建与使用。 用到了卡片扩展模块接口,ohos.app.form.FormExtensionAbility 。 卡片信息和状态等相关类型和枚举接口,ohos.app.form.formInfo 。 卡片提供方相关接口的能力接口,ohos.app.form.for…...
Linux链表操作全解析
Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
Docker 运行 Kafka 带 SASL 认证教程
Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明:server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...
Leetcode 3577. Count the Number of Computer Unlocking Permutations
Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...
全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...
Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理
引言 Bitmap(位图)是Android应用内存占用的“头号杀手”。一张1080P(1920x1080)的图片以ARGB_8888格式加载时,内存占用高达8MB(192010804字节)。据统计,超过60%的应用OOM崩溃与Bitm…...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?
在大数据处理领域,Hive 作为 Hadoop 生态中重要的数据仓库工具,其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式,很多开发者常常陷入选择困境。本文将从底…...
【生成模型】视频生成论文调研
工作清单 上游应用方向:控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...
PAN/FPN
import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...
