当前位置: 首页 > news >正文

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程

直接上案例,先来跑,遇到什么解决什么

数据集Minist 数据集

做简单的任务 Minist 分类任务

总体代码(可以跑通)

from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
#     print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:outputs=model(xb)_,predicted=torch.max(outputs.data,1)total+=yb.size(0)corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))

1.首先我们从最终实现的fit 函数开始看,

 在fit h函数之前有一个get_model 函数 得到model和优化器

model, opt = get_model()

得到模型的优化器以后

需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数

fit(20, model, loss_func, opt, train_dl, valid_dl)

fit 函数

def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量

yb是64个图片对应的数字值

loss_batch(model, loss_func, xb, yb, opt)

我们看一下loss_batch 函数

loss 反向传播——更新优化器——优化器梯度归0

loss 是一个带有梯度的tensor    .item()返回的是loss 的值 len(xb )是为了求精度的时候算

输入是模型损失函数xb ,yb 和优化器

loss_func = F.cross_entropy 损失函数是交叉熵损失函数

将xb 经过model 得到输出后 和xb 求损失函数  model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10

再回到fit 函数的验证部分model.evl()

先来一个 

with torch.no_grad()

不去计算梯度

zip(*的意思是解压缩 分别得到losses 和nums)

loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失

相关文章:

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程 直接上案例,先来跑,遇到什么解决什么 数据集Minist 数据集 做简单的任务 Minist 分类任务 总体代码(可以跑通) from pathlib import …...

【数据结构】考研真题攻克与重点知识点剖析 - 第 4 篇:串

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…...

深入浅出 -- 系统架构之分布式多形态的存储型集群

一、多形态的存储型集群 在上阶段,我们简单聊了下集群的基本知识,以及快速过了一下逻辑处理型集群的内容,下面重点来看看存储型集群,毕竟这块才是重头戏,集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…...

STL —— list

博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解 list模拟实现的相关内容 1. list简介 列表(list)是C标准模板库(STL)中的一个容器,它是一个双向链表数据结构&#xff0c…...

申请SSL证书

有很多方法可以确保您的网站安全。添加SSL证书可针对恶意攻击提供额外且关键的保护层。 即使网站不接受交易,您仍然需要保护用户的登录详细信息、地址和其他个人信息。 没有SSL证书的网站使用HTTP(一种基于文本的协议),这意味着…...

深入浅出 -- 系统架构之负载均衡Nginx环境搭建

引入负载均衡技术可带来的收益: 系统的高可用:当某个节点宕机后可以迅速将流量转移至其他节点。系统的高性能:多台服务器共同对外提供服务,为整个系统提供了更高规模的吞吐。系统的拓展性:当业务再次出现增长或萎靡时…...

notepad++绿色版添加右键菜单

解压路径 D:\Green\notepad_v8.0_x64_绿色版 添加右键菜单.reg 新建nodepad添加右键菜单.reg文件 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Edit with &Notepad" "Icon""D:\\Green\\notepad_v8.0_x64_绿色版…...

7 个 iMessage 恢复应用程序/软件可轻松恢复文本

由于误操作、iOS 升级中断、越狱失败、设备损坏等原因,您可能会丢失 iPhone/iPad 上的 iMessages。意外删除很大程度上增加了这种可能性。更糟糕的是,这种情况经常发生在 iDevice 缺乏备份的情况下。 (iPhone消息消失还占用空间?&…...

DockerFile启动jar程序

1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件,并使用文本编辑器打开它。Dockerfile的内容如下: # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…...

基于R、Python的Copula变量相关性分析及AI大模型应用

在工程、水文和金融等各学科的研究中,总是会遇到很多变量,研究这些相互纠缠的变量间的相关关系是各学科的研究的重点。虽然皮尔逊相关、秩相关等相关系数提供了变量间相关关系的粗略结果,但这些系数都存在着无法克服的困难。例如,…...

鸿蒙组件学习_Tabs组件

说明 该组件从API Version 7 开始支持。 子组件 仅可包含子组件TabContent 参数 barPosition 设置Tabs的页签位置,默认值: BarPosition.StartStart vertical属性方法设置为true时,页签位于容器左侧;vertical属性方法设置为false时,页签位于容器顶部。End vertic…...

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…...

thinkphp6入门(21)-- 如何删除图片、文件

假设文件的位置在 /*** 删除文件* $file_name avatar/20240208/d71d108bc1086b498df5191f9f925db3.jpg*/ function deleteFile($file_name) {// 要删除的文件路径$file app()->getRootPath() . public/uploads/ . $file_name; $result [];if (is_file($file)) {if (unlin…...

虚拟内存知识详解

虚拟内存 单片机的 CPU 是直接操作内存的「物理地址」 在这种情况下,要想在内存中同时运行两个程序是不可能的 操作系统是如何解决这个问题呢? 关键的问题是这两个程序都引用了绝对物理地址,而这正是我们最需要避免的。 可以把进程所使用的…...

数据结构初阶:顺序表和链表

线性表 线性表 ( linear list ) 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串 ... 线性表在逻辑上是线性结构,也就说是连续的一条直线。但是在物理结构上并不一定是连续的, 线性…...

在flutter中添加video_player【视频播放插件】

添加插件依赖 dependencies:video_player: ^2.8.3插件的用途 在Flutter框架中,video_player 插件是一个专门用于播放视频的插件。它允许开发者在Flutter应用中嵌入视频播放器,并提供了一系列功能来控制和定制视频播放体验。这个插件对于需要在应用中展…...

golang微服务框架特性分析及选型

目录 一、微服务框架特性(10个)包括:Istio、go-zero、go-kit、go-kratos、go-micro、rpcx、kitex、goa、jupiter、dubbo-go、tarsgo 1、特性及使用场景2、比较 二、web框架特性(7个)包括:gin、fiber、beego…...

苹果cmsV10 MXProV4.5自适应PC手机影视站主题模板苹果cms模板mxone pro

演示站:http://a.88531.cn:8016 MXPro 模板主题(又名:mxonepro)是一款基于苹果 cms程序的一款全新的简洁好看 UI 的影视站模板类似于西瓜视频,不过同对比 MxoneV10 魔改模板来说功能没有那么多,也没有那么大气,但是比较且可视化功…...

GPU的了解

3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…...

鸿蒙实战开发-如何使用Stage模型卡片

介绍 本示例展示了Stage模型卡片提供方的创建与使用。 用到了卡片扩展模块接口,ohos.app.form.FormExtensionAbility 。 卡片信息和状态等相关类型和枚举接口,ohos.app.form.formInfo 。 卡片提供方相关接口的能力接口,ohos.app.form.for…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

C++使用 new 来创建动态数组

问题&#xff1a; 不能使用变量定义数组大小 原因&#xff1a; 这是因为数组在内存中是连续存储的&#xff0c;编译器需要在编译阶段就确定数组的大小&#xff0c;以便正确地分配内存空间。如果允许使用变量来定义数组的大小&#xff0c;那么编译器就无法在编译时确定数组的大…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的&#xff0c;可以通过集中管理和高效资源的分配&#xff0c;来支持多个独立的网站同时运行&#xff0c;让每一个网站都可以分配到独立的IP地址&#xff0c;避免出现IP关联的风险&#xff0c;用户还可以通过控制面板进行管理功…...

LangFlow技术架构分析

&#x1f527; LangFlow 的可视化技术栈 前端节点编辑器 底层框架&#xff1a;基于 &#xff08;一个现代化的 React 节点绘图库&#xff09; 功能&#xff1a; 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

WEB3全栈开发——面试专业技能点P7前端与链上集成

一、Next.js技术栈 ✅ 概念介绍 Next.js 是一个基于 React 的 服务端渲染&#xff08;SSR&#xff09;与静态网站生成&#xff08;SSG&#xff09; 框架&#xff0c;由 Vercel 开发。它简化了构建生产级 React 应用的过程&#xff0c;并内置了很多特性&#xff1a; ✅ 文件系…...

AI书签管理工具开发全记录(十八):书签导入导出

文章目录 AI书签管理工具开发全记录&#xff08;十八&#xff09;&#xff1a;书签导入导出1.前言 &#x1f4dd;2.书签结构分析 &#x1f4d6;3.书签示例 &#x1f4d1;4.书签文件结构定义描述 &#x1f523;4.1. ​整体文档结构​​4.2. ​核心元素类型​​4.3. ​层级关系4.…...

Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术点解析

Java求职者面试指南&#xff1a;Spring、Spring Boot、Spring MVC与MyBatis技术点解析 第一轮&#xff1a;基础概念问题 请解释Spring框架的核心容器是什么&#xff1f;它的作用是什么&#xff1f; 程序员JY回答&#xff1a;Spring框架的核心容器是IoC容器&#xff08;控制反转…...

Flask和Django,你怎么选?

Flask 和 Django 是 Python 两大最流行的 Web 框架&#xff0c;但它们的设计哲学、目标和适用场景有显著区别。以下是详细的对比&#xff1a; 核心区别&#xff1a;哲学与定位 Django: 定位: "全栈式" Web 框架。奉行"开箱即用"的理念。 哲学: "包含…...

Unity基础-Mathf相关

Unity基础-Mathf相关 一、Mathf数学工具 概述 Mathf是Unity中封装好用于数学计算的工具结构体&#xff0c;提供了丰富的数学计算方法&#xff0c;特别适用于游戏开发场景。它是Unity开发中最常用的数学工具之一&#xff0c;能够帮助我们处理各种数学计算和插值运算。 Mathf…...