当前位置: 首页 > news >正文

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程

直接上案例,先来跑,遇到什么解决什么

数据集Minist 数据集

做简单的任务 Minist 分类任务

总体代码(可以跑通)

from pathlib import Path
import requests
import pickle
import gzip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
import numpy as np
bs=64
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
from matplotlib import pyplot
import numpy as np
print(x_train.shape)# pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#获取训练数据和测试数据
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
#设置模型结构
class Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
# print(net)
# for name, parameter in net.named_parameters():
#     print(name, parameter,parameter.size())
#设置数据集和数据集加载器
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64 * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)loss_func = F.cross_entropy
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
#训练参数
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(20, model, loss_func, opt, train_dl, valid_dl)
corret=0
total=0
for xb,yb in valid_dl:outputs=model(xb)_,predicted=torch.max(outputs.data,1)total+=yb.size(0)corret+=(predicted==yb).sum().item()
print('准确率是:%d %%'%(100*corret/total))

1.首先我们从最终实现的fit 函数开始看,

 在fit h函数之前有一个get_model 函数 得到model和优化器

model, opt = get_model()

得到模型的优化器以后

需要把训练轮数 模型 损失函数 训练数据 测试数据传入fit 训练函数

fit(20, model, loss_func, opt, train_dl, valid_dl)

fit 函数

def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

xb是从dataloader 中取64个训练数据图片 也就是64*784维度,784代表28*28的手写数字图片的展平成一维向量

yb是64个图片对应的数字值

loss_batch(model, loss_func, xb, yb, opt)

我们看一下loss_batch 函数

loss 反向传播——更新优化器——优化器梯度归0

loss 是一个带有梯度的tensor    .item()返回的是loss 的值 len(xb )是为了求精度的时候算

输入是模型损失函数xb ,yb 和优化器

loss_func = F.cross_entropy 损失函数是交叉熵损失函数

将xb 经过model 得到输出后 和xb 求损失函数  model 是定义的一个简单的模型有两个隐藏层一个输出层 784——128——256——10

再回到fit 函数的验证部分model.evl()

先来一个 

with torch.no_grad()

不去计算梯度

zip(*的意思是解压缩 分别得到losses 和nums)

loss 和num 鲜橙 求每64个batch 的总loss 再将datalosder的所有batch 相加除以总数得到训练损失

相关文章:

神经网络分类和回归任务实战

学习方法:torch 边用边学,边查边学 真正用查的过程才是学习的过程 直接上案例,先来跑,遇到什么解决什么 数据集Minist 数据集 做简单的任务 Minist 分类任务 总体代码(可以跑通) from pathlib import …...

【数据结构】考研真题攻克与重点知识点剖析 - 第 4 篇:串

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…...

深入浅出 -- 系统架构之分布式多形态的存储型集群

一、多形态的存储型集群 在上阶段,我们简单聊了下集群的基本知识,以及快速过了一下逻辑处理型集群的内容,下面重点来看看存储型集群,毕竟这块才是重头戏,集群的形态在其中有着多种多样的变化。 逻辑处理型的应用&…...

STL —— list

博主首页: 有趣的中国人 专栏首页: C专栏 本篇文章主要讲解 list模拟实现的相关内容 1. list简介 列表(list)是C标准模板库(STL)中的一个容器,它是一个双向链表数据结构&#xff0c…...

申请SSL证书

有很多方法可以确保您的网站安全。添加SSL证书可针对恶意攻击提供额外且关键的保护层。 即使网站不接受交易,您仍然需要保护用户的登录详细信息、地址和其他个人信息。 没有SSL证书的网站使用HTTP(一种基于文本的协议),这意味着…...

深入浅出 -- 系统架构之负载均衡Nginx环境搭建

引入负载均衡技术可带来的收益: 系统的高可用:当某个节点宕机后可以迅速将流量转移至其他节点。系统的高性能:多台服务器共同对外提供服务,为整个系统提供了更高规模的吞吐。系统的拓展性:当业务再次出现增长或萎靡时…...

notepad++绿色版添加右键菜单

解压路径 D:\Green\notepad_v8.0_x64_绿色版 添加右键菜单.reg 新建nodepad添加右键菜单.reg文件 Windows Registry Editor Version 5.00[HKEY_CLASSES_ROOT\*\shell\NotePad] "Edit with &Notepad" "Icon""D:\\Green\\notepad_v8.0_x64_绿色版…...

7 个 iMessage 恢复应用程序/软件可轻松恢复文本

由于误操作、iOS 升级中断、越狱失败、设备损坏等原因,您可能会丢失 iPhone/iPad 上的 iMessages。意外删除很大程度上增加了这种可能性。更糟糕的是,这种情况经常发生在 iDevice 缺乏备份的情况下。 (iPhone消息消失还占用空间?&…...

DockerFile启动jar程序

1.创建Dockerfile 在项目的根目录下创建一个名为Dockerfile的文件,并使用文本编辑器打开它。Dockerfile的内容如下: # 基础镜像 FROM openjdk:8-jre # 创建目录 RUN mkdir -p /usr/app/ # 设置工作目录 WORKDIR /usr/app # 将JAR文件复制到容器中,注:…...

基于R、Python的Copula变量相关性分析及AI大模型应用

在工程、水文和金融等各学科的研究中,总是会遇到很多变量,研究这些相互纠缠的变量间的相关关系是各学科的研究的重点。虽然皮尔逊相关、秩相关等相关系数提供了变量间相关关系的粗略结果,但这些系数都存在着无法克服的困难。例如,…...

鸿蒙组件学习_Tabs组件

说明 该组件从API Version 7 开始支持。 子组件 仅可包含子组件TabContent 参数 barPosition 设置Tabs的页签位置,默认值: BarPosition.StartStart vertical属性方法设置为true时,页签位于容器左侧;vertical属性方法设置为false时,页签位于容器顶部。End vertic…...

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…...

thinkphp6入门(21)-- 如何删除图片、文件

假设文件的位置在 /*** 删除文件* $file_name avatar/20240208/d71d108bc1086b498df5191f9f925db3.jpg*/ function deleteFile($file_name) {// 要删除的文件路径$file app()->getRootPath() . public/uploads/ . $file_name; $result [];if (is_file($file)) {if (unlin…...

虚拟内存知识详解

虚拟内存 单片机的 CPU 是直接操作内存的「物理地址」 在这种情况下,要想在内存中同时运行两个程序是不可能的 操作系统是如何解决这个问题呢? 关键的问题是这两个程序都引用了绝对物理地址,而这正是我们最需要避免的。 可以把进程所使用的…...

数据结构初阶:顺序表和链表

线性表 线性表 ( linear list ) 是 n 个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使 用的数据结构,常见的线性表:顺序表、链表、栈、队列、字符串 ... 线性表在逻辑上是线性结构,也就说是连续的一条直线。但是在物理结构上并不一定是连续的, 线性…...

在flutter中添加video_player【视频播放插件】

添加插件依赖 dependencies:video_player: ^2.8.3插件的用途 在Flutter框架中,video_player 插件是一个专门用于播放视频的插件。它允许开发者在Flutter应用中嵌入视频播放器,并提供了一系列功能来控制和定制视频播放体验。这个插件对于需要在应用中展…...

golang微服务框架特性分析及选型

目录 一、微服务框架特性(10个)包括:Istio、go-zero、go-kit、go-kratos、go-micro、rpcx、kitex、goa、jupiter、dubbo-go、tarsgo 1、特性及使用场景2、比较 二、web框架特性(7个)包括:gin、fiber、beego…...

苹果cmsV10 MXProV4.5自适应PC手机影视站主题模板苹果cms模板mxone pro

演示站:http://a.88531.cn:8016 MXPro 模板主题(又名:mxonepro)是一款基于苹果 cms程序的一款全新的简洁好看 UI 的影视站模板类似于西瓜视频,不过同对比 MxoneV10 魔改模板来说功能没有那么多,也没有那么大气,但是比较且可视化功…...

GPU的了解

3D动画揭秘显卡的GPU是如何工作的_哔哩哔哩_bilibili 位于显卡中。 与CPU区别: 100名小学生和1位数学博士 做100道非常简单的算术题,小朋友一个人一道题,比博士快。 做1道非常复杂的数学问题,只有博士可以做出来。 CPU主要用于快…...

鸿蒙实战开发-如何使用Stage模型卡片

介绍 本示例展示了Stage模型卡片提供方的创建与使用。 用到了卡片扩展模块接口,ohos.app.form.FormExtensionAbility 。 卡片信息和状态等相关类型和枚举接口,ohos.app.form.formInfo 。 卡片提供方相关接口的能力接口,ohos.app.form.for…...

盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来

一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

遍历 Map 类型集合的方法汇总

1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...

1.3 VSCode安装与环境配置

进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角,以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向,距离坐标原点x个像素;第二个是y坐标,表示当前位置为垂直方向,距离坐标原点y个像素。 坐标体系-像素 …...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具,可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件,也不需要在线上传文件,保护您的隐私。 工具截图 主要特点 🚀 快速转换:本地转换,无需等待上…...

Go语言多线程问题

打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...