神经网络分类和回归任务实战
学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程
直接上案例,先来跑,遇到什么解决什么
数据集Minist 数据集
做简单的任务 Minist 分类任务
总体代码(可以跑通)
from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
# print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:outputs=model(xb)_,predicted=torch.max(outputs.data,1)total+=yb.size(0)corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))
1.首先我们从最终实现的fit 函数开始看,
在fit h函数之前有一个get_model 函数 得到model和优化器
model, opt = get_model()
得到模型的优化器以后
需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数
fit(20, model, loss_func, opt, train_dl, valid_dl)
fit 函数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量
yb是64个图片对应的数字值
loss_batch(model, loss_func, xb, yb, opt)
我们看一下loss_batch 函数
loss 反向传播——更新优化器——优化器梯度归0
loss 是一个带有梯度的tensor .item()返回的是loss 的值 len(xb )是为了求精度的时候算
输入是模型损失函数xb ,yb 和优化器
loss_func = F.cross_entropy 损失函数是交叉熵损失函数
将xb 经过model 得到输出后 和xb 求损失函数 model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10
再回到fit 函数的验证部分model.evl()
先来一个
with torch.no_grad()
不去计算梯度
zip(*的意思是解压缩 分别得到losses 和nums)
loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失
相关文章:

神经网络分类和回归任务实战
学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程 直接上案例,先来跑,遇到什么解决什么 数据集Minist 数据集 做简单的任务 Minist 分类任务 总体代码(可以跑通) from pathlib import …...

【数据结构】考研真题攻克与重点知识点剖析 - 第 4 篇:串
前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…...

深入浅出 -- 系统架构之分布式多形态的存储型集群
一、多形态的存储型集群 在上阶段,我们简单聊了下集群的基本知识,以及快速过了一下逻辑处理型集群的内容,下面重点来看看存储型集群,毕竟这块才是重头戏,集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…...

STL —— list
博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解 list模拟实现的相关内容 1. list简介 列表(list)是C标准模板库(STL)中的一个容器,它是一个双向链表数据结构,…...

申请SSL证书
有很多方法可以确保您的网站安全。添加SSL证书可针对恶意攻击提供额外且关键的保护层。 即使网站不接受交易,您仍然需要保护用户的登录详细信息、地址和其他个人信息。 没有SSL证书的网站使用HTTP(一种基于文本的协议),这意味着…...

深入浅出 -- 系统架构之负载均衡Nginx环境搭建
引入负载均衡技术可带来的收益: 系统的高可用:当某个节点宕机后可以迅速将流量转移至其他节点。系统的高性能:多台服务器共同对外提供服务,为整个系统提供了更高规模的吞吐。系统的拓展性:当业务再次出现增长或萎靡时…...
notepad++绿色版添加右键菜单
解压路径 D:\Green\notepad_v8.0_x64_绿色版 添加右键菜单.reg 新建nodepad添加右键菜单.reg文件 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Edit with &Notepad" "Icon""D:\\Green\\notepad_v8.0_x64_绿色版…...

7 个 iMessage 恢复应用程序/软件可轻松恢复文本
由于误操作、iOS 升级中断、越狱失败、设备损坏等原因,您可能会丢失 iPhone/iPad 上的 iMessages。意外删除很大程度上增加了这种可能性。更糟糕的是,这种情况经常发生在 iDevice 缺乏备份的情况下。 (iPhone消息消失还占用空间?&…...
DockerFile启动jar程序
1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件,并使用文本编辑器打开它。Dockerfile的内容如下: # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…...

基于R、Python的Copula变量相关性分析及AI大模型应用
在工程、水文和金融等各学科的研究中,总是会遇到很多变量,研究这些相互纠缠的变量间的相关关系是各学科的研究的重点。虽然皮尔逊相关、秩相关等相关系数提供了变量间相关关系的粗略结果,但这些系数都存在着无法克服的困难。例如,…...
鸿蒙组件学习_Tabs组件
说明 该组件从API Version 7 开始支持。 子组件 仅可包含子组件TabContent 参数 barPosition 设置Tabs的页签位置,默认值: BarPosition.StartStart vertical属性方法设置为true时,页签位于容器左侧;vertical属性方法设置为false时,页签位于容器顶部。End vertic…...
【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略
【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…...

thinkphp6入门(21)-- 如何删除图片、文件
假设文件的位置在 /*** 删除文件* $file_name avatar/20240208/d71d108bc1086b498df5191f9f925db3.jpg*/ function deleteFile($file_name) {// 要删除的文件路径$file app()->getRootPath() . public/uploads/ . $file_name; $result [];if (is_file($file)) {if (unlin…...
虚拟内存知识详解
虚拟内存 单片机的 CPU 是直接操作内存的「物理地址」 在这种情况下,要想在内存中同时运行两个程序是不可能的 操作系统是如何解决这个问题呢? 关键的问题是这两个程序都引用了绝对物理地址,而这正是我们最需要避免的。 可以把进程所使用的…...

数据结构初阶:顺序表和链表
线性表 线性表 ( linear list ) 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串 ... 线性表在逻辑上是线性结构,也就说是连续的一条直线。但是在物理结构上并不一定是连续的, 线性…...

在flutter中添加video_player【视频播放插件】
添加插件依赖 dependencies:video_player: ^2.8.3插件的用途 在Flutter框架中,video_player 插件是一个专门用于播放视频的插件。它允许开发者在Flutter应用中嵌入视频播放器,并提供了一系列功能来控制和定制视频播放体验。这个插件对于需要在应用中展…...
golang微服务框架特性分析及选型
目录 一、微服务框架特性(10个)包括:Istio、go-zero、go-kit、go-kratos、go-micro、rpcx、kitex、goa、jupiter、dubbo-go、tarsgo 1、特性及使用场景2、比较 二、web框架特性(7个)包括:gin、fiber、beego…...

苹果cmsV10 MXProV4.5自适应PC手机影视站主题模板苹果cms模板mxone pro
演示站:http://a.88531.cn:8016 MXPro 模板主题(又名:mxonepro)是一款基于苹果 cms程序的一款全新的简洁好看 UI 的影视站模板类似于西瓜视频,不过同对比 MxoneV10 魔改模板来说功能没有那么多,也没有那么大气,但是比较且可视化功…...
GPU的了解
3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…...

鸿蒙实战开发-如何使用Stage模型卡片
介绍 本示例展示了Stage模型卡片提供方的创建与使用。 用到了卡片扩展模块接口,ohos.app.form.FormExtensionAbility 。 卡片信息和状态等相关类型和枚举接口,ohos.app.form.formInfo 。 卡片提供方相关接口的能力接口,ohos.app.form.for…...

Linux 文件类型,目录与路径,文件与目录管理
文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望
文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...

无法与IP建立连接,未能下载VSCode服务器
如题,在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈,发现是VSCode版本自动更新惹的祸!!! 在VSCode的帮助->关于这里发现前几天VSCode自动更新了,我的版本号变成了1.100.3 才导致了远程连接出…...

12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)
安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...
NPOI操作EXCEL文件 ——CAD C# 二次开发
缺点:dll.版本容易加载错误。CAD加载插件时,没有加载所有类库。插件运行过程中用到某个类库,会从CAD的安装目录找,找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库,就用插件程序加载进…...

【Linux】Linux安装并配置RabbitMQ
目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...

基于江科大stm32屏幕驱动,实现OLED多级菜单(动画效果),结构体链表实现(独创源码)
引言 在嵌入式系统中,用户界面的设计往往直接影响到用户体验。本文将以STM32微控制器和OLED显示屏为例,介绍如何实现一个多级菜单系统。该系统支持用户通过按键导航菜单,执行相应操作,并提供平滑的滚动动画效果。 本文设计了一个…...