当前位置: 首页 > news >正文

李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块

import random
import torch
import matplotlib_inline
from matplotlib import pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets
import sys
from collections import OrderedDict# ---------------------------------------------------------------------------------------------
# 图表展示
def use_svg_display():# 用矢量图表示matplotlib_inline.backend_inline.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsize# ---------------------------------------------------------------------------------------------
# 读取数据
# 获取总的样本数量,然后打乱顺序,用batch-size获取每一部分索引去索引对应样本中的数据,使用yield返回
'''
函数详解:
torch.linspace(start, end, steps, dtype) → Tensor  从start开始到end结束,生成steps个数据点,数据类型为dtype
torch.index_select(input, dim, index)   索引张量中的子集
**input:需要进行索引操作的输入张量dim:张量维度  0,1index:索引号,是张量类型
**
yield: 使用yield的函数返回迭代器对象,每次使用时会保存变量信息,使用next()或者使用for可以循环访问迭代器中的内容
'''def data_iter(batch_size, features, labels):num_examples = len(features)  # features   nxmindices = list(range(num_examples))  # 借助range生成索引序列random.shuffle(indices)  # 把list列表中的值打乱顺序for i in range(0, num_examples, batch_size):j = torch.LongTensor(indices[i:min(i + batch_size, num_examples)])  # 这里的i是对标乱序表中的下标索引号yield features.index_select(0, j), labels.index_select(0, j)  # 0维度,有1000个样本,j就是他们的下标# ---------------------------------------------------------------------------------------------# 定义模型
def linreg(X, w, b):return torch.mm(X, w) + b  # 传进来的参数和样本特征都符合矩阵形式 w,b都是列矩阵  X:1000x2  w:2x1  b:1x1# 这里使用了广播# ---------------------------------------------------------------------------------------------# 定义损失函数
def square_loss(y_hat, y):# 保证y_hat和y同型,pytorch中的MSELoss没有除以2的操作return (y_hat - y.view(y_hat.size())) ** 2 / 2# 这里的得到的也是一个小批量的样本的损失张量# ---------------------------------------------------------------------------------------------
# 定义优化算法
# 这里使用的是sgd算法,使用小批量梯度和(参数求导后的和:梯度会自动累加,不用自己加和梯度)除以小批量样本个数来求小批量平均值
def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size  # 这里更改param时使用的是param.data,这样就不会影响反向梯度# 这里的param指的是w1,w2,b# 这里应该是小批量中的每个loss运行完,得到小批量每个样本的梯度然后pytorch自动进行了梯度累加,之后一个小批量得到一个累加和后的
# 梯度w1,w2,b
# ---------------------------------------------------------------------------------------------'''
FashionMNIST 数据集
'''# ----------------------------------------------------------将数值标签转换成文本标签
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal','shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# -----------------------------------------------------在一行里画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):use_svg_display()# 这里的_表示忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))  # 设置一行 len(images)个数量,每个figsize大小的画布# figs 返回的是一个画布对象,这个对象有imshow,set_tittle,axes_get_xasis().set_visible,# axes.get_yaxis().set_visible()这几种函数调用方式,用来给figs里面添加图像for f, img, lbl, in zip(figs, images, labels):  # 这个画布对象循环往里面添加图像信息f.imshow(img.view((28, 28)).numpy())  # img承接图像信息,将tensor转化为numpy  这里参数为数组元素f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.savefig("路径")# ----------------------------------------------------------------获取并读取FashionMNIST数据集函数,返回小批量train,test
def load_data_fashion_mnist(batch_size):mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())'''上面的mnist_train,mnist_test都是torch.utils.data.Dataset的子类,所以可以使用len()获取数据集的大小训练集和测试集中的每个类别的图像数分别是6000,1000,两个数据集分别有10个类别'''# mnist是torch.utils.data.dataset的子类,因此可以将其传入torch.utils.data.DataLoader来创建一个DataLoader实例来读取数据# 在实践中,数据读取一般是训练的性能瓶颈,特别是模型较简单或者计算硬件性能比较高的时候# DataLoader一个很有用的功能就是允许多进程来加速读取  使用num_works来设置4个进程读取数据if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=num_workers)return train_iter, test_iter# -------------------------------------------------------------查看mnist前10个图像和标签
def check_mnist():mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())X, y = [], []for i in range(10):X.append(mnist_train[i][0])  # 循环获取图像张量矩阵y.append(mnist_train[i][1])  # 循环获取图像对应数值标签show_fashion_mnist(X, get_fashion_mnist_labels(y))# feature, label = mnist_train[0]# print(feature.shape, label)  CxHxW# feature对应高和宽均为28像素的图像,因为使用了transforms.ToTensor(),所以每个像素的数值对应于【0.0,1.0】的32位浮点数# C 是通道数,RGB,灰色图像,通道数为1,H,W分别为高,宽# mnist_train[0] 是一个元祖,它包含两部分,图像数据结构和图像标签值,图像的数据结构是1x28x28结构,是一个浮点数矩阵,代表一个图像# -------------------------------------------------------------------------评价模型net在数据集data_iter上的准确率
def evaluate_accuracy(test_iter, net):acc_sum, n, x = 0.0, 0, 0.0for X, y in test_iter:  # 返回一个批量的数据元组迭代对象acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()  # 将net模型的预测y与标签y进行了准确率比较n += y.shape[0]  # 累加获得样本个数x = acc_sum / nreturn x# -------------------------------------------------------------------------训练模型函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):for epochs in range(num_epochs):  # 循环周期train_l_sum, train_acc_sum, n = 0.0, 0.0, 0  # 预先定义 训练损失,训练精度,批量个数for X, y in train_iter:  # 批量更新y_hat = net(X)l = loss(y_hat, y).sum()  # 损失计算# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:  # 权重存在并且权重的梯度存在for param in params:param.grad.data.zero_()l.backward()  # 反向传播# 梯度更新操作if optimizer is None:sgd(params, lr, batch_size)  # 调用sgd进行梯度下降操作else:optimizer.step()  # softmax回归的简洁实现将要用到train_l_sum += l.item()  # 损失累加train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()  # (y_hat.argmax(dim=1) == y)# 取出y_hat每一行中最大的概率索引和y比较,结果为tensor,元素值为0/1n += y.shape[0]  # 计算一个批量中标签的个数test_acc = evaluate_accuracy(test_iter, net)  # 一个循环之后进行测试集的准确度计算print(f'epoch %d,loss %.4f,train_acc %.3f,test_acc %.3f'% (epochs + 1, train_l_sum / n, train_acc_sum / n, test_acc))# x = torch.tensor([[0.1, 0.4, 0.2], [1, 0.06, 0.5]])
# print((x.argmax(dim=1)==torch.tensor([[1,1]])).float())# -------------------------------------------------------------------------x的形状转换功能函数
class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()  # 初始化函数,自动调用forward函数def forward(self, x):  # x shape: (batch,*,*,....)return x.view(x.shape[0], -1)  # 转换成(batch_size,特征数)形状# 这样就方便定义模型
net = torch.nn.Sequential(# FlattenLayer()# torch.nn.Linear(num_inputs,num_outputs)OrderedDict([('flatten', FlattenLayer()),('linear', torch.nn.Linear(2, 3))])
)'''
-------------------------------------------------------------------作图函数
'''def semilogy(x_vals, y_vals, xlabel, ylabel, label, x2_vals=None, y2_vals=None, legend=None):plt.xlabel(xlabel)plt.ylabel(ylabel)plt.semilogy(x_vals, y_vals)  # y轴使用对数尺度if x2_vals and y2_vals:plt.semilogy(x2_vals, y2_vals, linestyle=':')plt.legend(legend)plt.savefig("路径/多项式" + label + "模拟.png")

相关文章:

李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块 import random import torch import matplotlib_inline from matplotlib import pyplot as plt import torchvision import torchvision.transforms as transforms import torchvision.datasets import sys from collections import OrderedDict# --------…...

什么是OSPF?为什么需要OSPF?OSPF基础概念

什么是OSPF? 开放式最短路径优先OSPF(Open Shortest Path First)是IETF组织开发的一个基于链路状态的内部网关协议(Interior Gateway Protocol)。 目前针对IPv4协议使用的是OSPF Version 2(RFC2328&#x…...

Java多线程并发篇----第二十六篇

系列文章目录 文章目录 系列文章目录前言一、什么是 Executors 框架?二、什么是阻塞队列?阻塞队列的实现原理是什么?如何使用阻塞队列来实现生产者-消费者模型?三、什么是 Callable 和 Future?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分…...

list下

文章目录 注意:const迭代器怎么写?运用场合? inserterase析构函数赋值和拷贝构造区别?拷贝构造不能写那个swap,为什么?拷贝构造代码 面试问题什么是迭代器失效?vector、list的区别? 完整代码 注…...

【Linux】进程间通信——system V 共享内存、消息队列、信号量

需要云服务器等云产品来学习Linux的同学可以移步/–>腾讯云<–/官网&#xff0c;轻量型云服务器低至112元/年&#xff0c;优惠多多。&#xff08;联系我有折扣哦&#xff09; 文章目录 写在前面1. 共享内存1.1 共享内存的概念1.2 共享内存的原理1.3 共享内存的使用1.3.1 …...

网络卡问题排查手段

问题 对后端来说&#xff0c;网络卡了问题&#xff0c;本身很难去排查&#xff0c;因为是 App 通过互联网连接服务 总结下&#xff0c;以往经验&#xff0c;网络卡&#xff0c;通常会有以下情况造成&#xff1a; 某地区网络问题某地区某运营商问题后端服务超载前端网络模块 …...

20240119-子数组最小值之和

题目要求 给定一个整数数组 arr&#xff0c;求 min(b) 的总和&#xff0c;其中 b 的范围涵盖 arr 的每个&#xff08;连续&#xff09;子数组。由于答案可能很大&#xff0c;因此返回答案模数 Example 1: Input: arr [3,1,2,4] Output: 17 Explanation: Subarrays are [3]…...

c# 释放所有嵌入资源, 到某个本地文件夹

版本号 .net 8 代码 using System.Reflection;namespace Demo;internal class Program {static void Main(string[] args){// 获取当前 执行exe 的目录 / 当前命令行所在的目录 var currentDir Directory.GetCurrentDirectory();Console.WriteLine(currentDir);Extract…...

Unity SnapScrollRect 滚动 匹配 列表 整页

展示效果 原理: 当停止滑动时 判断Contet的horizontalNormalizedPosition 与子Item的缓存值 相减,并得到最小值&#xff0c;然后将Content horizontalNormalizedPosition滚动过去 使用方式&#xff1a; 直接将脚本挂到ScrollRect上 注意&#xff1a;在创建Content子物体时…...

网络命令ping和telnet

1. 请解释ping和telnet的工作原理。 ping和telnet是两种常用的网络工具&#xff0c;其工作原理分别如下&#xff1a; ping&#xff1a; 目的&#xff1a;ping主要用于检查网络是否通畅以及测量网络连接速度。工作原理&#xff1a;ping是基于ICMP&#xff08;Internet Control …...

ros2学习笔记-CLI工具,记录命令对应操作。

目录 环境变量turtlesim和rqt以初始状态打开rqt node启动节点查看节点列表查看节点更多信息命令行参数 --ros-args topic话题列表话题类型话题列表&#xff0c;附加话题类型根据类型查找话题名查看话题发布的数据查看话题的详细信息查看类型的详细信息给话题发布消息&#xff0…...

自然语言处理的发展

自然语言处理的发展大致经历了四个阶段&#xff1a;萌芽期、快速发展期、低谷的发展期和复苏融合期。 萌芽期&#xff08;1956年以前&#xff09;&#xff1a;这个阶段可以看作自然语言处理的基础研究阶段。人类文明经过了几千年的发展&#xff0c;积累了大量的数学、语言学和…...

flink operator 拉取阿里云私有镜像(其他私有类似)

创建 k8s secret kubectl --namespace flink create secret docker-registry aliyun-docker-registry --docker-serverregistry.cn-shenzhen.aliyuncs.com --docker-usernameops_acr1060896234 --docker-passwordpasswd --docker-emailDOCKER_EMAIL注意命名空间指定你使用的 我…...

C语言算法赛——蓝桥杯(省赛试题)

一、十四届C/C程序设计C组试题 十四届程序C组试题A#include <stdio.h> int main() {long long sum 0;int n 20230408;int i 0;// 累加从1到n的所有整数for (i 1; i < n; i){sum i;}// 输出结果printf("%lld\n", sum);return 0; }//十四届程序C组试题B…...

【文本到上下文 #2】:NLP 的数据预处理步骤

一、说明 欢迎阅读此文&#xff0c;NLP 爱好者&#xff01;当我们继续探索自然语言处理 (NLP) 的广阔前景时&#xff0c;我们已经在最初的博客中探讨了它的历史、应用和挑战。今天&#xff0c;我们更深入地探讨 NLP 的核心——数据预处理的复杂世界。 这篇文章是我们的“完整 N…...

Minio文件分片上传实现

资源准备 MacM1Pro 安装Parallels19.1.0请参考 https://blog.csdn.net/qq_41594280/article/details/135420241 MacM1Pro Parallels安装CentOS7.9请参考 https://blog.csdn.net/qq_41594280/article/details/135420461 部署Minio和整合SpringBoot请参考 https://blog.csdn.net/…...

C语言总结十一:自定义类型:结构体、枚举、联合(共用体)

本篇博客详细介绍C语言最后的三种自定义类型&#xff0c;它们分别有着各自的特点和应用场景&#xff0c;重点在于理解这三种自定义类型的声明方式和使用&#xff0c;以及各自的特点&#xff0c;最后重点掌握该章节常考的考点&#xff0c;如&#xff1a;结构体内存对齐问题&…...

解决Spring Boot应用打包后文件访问问题

在Spring Boot项目的开发过程中&#xff0c;一个常见的挑战是如何有效地访问和操作资源文件。这一挑战尤其显著当应用从IDE环境&#xff08;如IntelliJ IDEA&#xff09;迁移到被打包成JAR文件后的生产环境。开发者经常遇到的问题是&#xff0c;在IDE中运行正常的代码&#xff…...

循环神经网络的变体模型-LSTM、GRU

一.LSTM&#xff08;长短时记忆网络&#xff09; 1.1基本介绍 长短时记忆网络&#xff08;Long Short-Term Memory&#xff0c;LSTM&#xff09;是一种深度学习模型&#xff0c;属于循环神经网络&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;的一种变体。…...

视频图像的color range简介

介绍 研究FFmpeg发现&#xff0c;在avcodec.h中有关于color的解释&#xff0c;主要有四个属性&#xff0c;primaries、transfer、space和range。 color primaries&#xff1a; 基于RGB空间对应的绝对颜色XYZ的变换&#xff0c;决定了最终三原色RGB分别是什么颜色&#xff1b;…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中&#xff0c;我们会遇到使用 java 调用 dll文件 的情况&#xff0c;此时大概率出现UnsatisfiedLinkError链接错误&#xff0c;原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用&#xff0c;结果 dll 未实现 JNI 协…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

unix/linux,sudo,其发展历程详细时间线、由来、历史背景

sudo 的诞生和演化,本身就是一部 Unix/Linux 系统管理哲学变迁的微缩史。来,让我们拨开时间的迷雾,一同探寻 sudo 那波澜壮阔(也颇为实用主义)的发展历程。 历史背景:su的时代与困境 ( 20 世纪 70 年代 - 80 年代初) 在 sudo 出现之前,Unix 系统管理员和需要特权操作的…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

&#x1f680; C extern 关键字深度解析&#xff1a;跨文件编程的终极指南 &#x1f4c5; 更新时间&#xff1a;2025年6月5日 &#x1f3f7;️ 标签&#xff1a;C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言&#x1f525;一、extern 是什么&#xff1f;&…...

多种风格导航菜单 HTML 实现(附源码)

下面我将为您展示 6 种不同风格的导航菜单实现&#xff0c;每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...