当前位置: 首页 > news >正文

李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块

import random
import torch
import matplotlib_inline
from matplotlib import pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets
import sys
from collections import OrderedDict# ---------------------------------------------------------------------------------------------
# 图表展示
def use_svg_display():# 用矢量图表示matplotlib_inline.backend_inline.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsize# ---------------------------------------------------------------------------------------------
# 读取数据
# 获取总的样本数量,然后打乱顺序,用batch-size获取每一部分索引去索引对应样本中的数据,使用yield返回
'''
函数详解:
torch.linspace(start, end, steps, dtype) → Tensor  从start开始到end结束,生成steps个数据点,数据类型为dtype
torch.index_select(input, dim, index)   索引张量中的子集
**input:需要进行索引操作的输入张量dim:张量维度  0,1index:索引号,是张量类型
**
yield: 使用yield的函数返回迭代器对象,每次使用时会保存变量信息,使用next()或者使用for可以循环访问迭代器中的内容
'''def data_iter(batch_size, features, labels):num_examples = len(features)  # features   nxmindices = list(range(num_examples))  # 借助range生成索引序列random.shuffle(indices)  # 把list列表中的值打乱顺序for i in range(0, num_examples, batch_size):j = torch.LongTensor(indices[i:min(i + batch_size, num_examples)])  # 这里的i是对标乱序表中的下标索引号yield features.index_select(0, j), labels.index_select(0, j)  # 0维度,有1000个样本,j就是他们的下标# ---------------------------------------------------------------------------------------------# 定义模型
def linreg(X, w, b):return torch.mm(X, w) + b  # 传进来的参数和样本特征都符合矩阵形式 w,b都是列矩阵  X:1000x2  w:2x1  b:1x1# 这里使用了广播# ---------------------------------------------------------------------------------------------# 定义损失函数
def square_loss(y_hat, y):# 保证y_hat和y同型,pytorch中的MSELoss没有除以2的操作return (y_hat - y.view(y_hat.size())) ** 2 / 2# 这里的得到的也是一个小批量的样本的损失张量# ---------------------------------------------------------------------------------------------
# 定义优化算法
# 这里使用的是sgd算法,使用小批量梯度和(参数求导后的和:梯度会自动累加,不用自己加和梯度)除以小批量样本个数来求小批量平均值
def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size  # 这里更改param时使用的是param.data,这样就不会影响反向梯度# 这里的param指的是w1,w2,b# 这里应该是小批量中的每个loss运行完,得到小批量每个样本的梯度然后pytorch自动进行了梯度累加,之后一个小批量得到一个累加和后的
# 梯度w1,w2,b
# ---------------------------------------------------------------------------------------------'''
FashionMNIST 数据集
'''# ----------------------------------------------------------将数值标签转换成文本标签
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal','shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# -----------------------------------------------------在一行里画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):use_svg_display()# 这里的_表示忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))  # 设置一行 len(images)个数量,每个figsize大小的画布# figs 返回的是一个画布对象,这个对象有imshow,set_tittle,axes_get_xasis().set_visible,# axes.get_yaxis().set_visible()这几种函数调用方式,用来给figs里面添加图像for f, img, lbl, in zip(figs, images, labels):  # 这个画布对象循环往里面添加图像信息f.imshow(img.view((28, 28)).numpy())  # img承接图像信息,将tensor转化为numpy  这里参数为数组元素f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.savefig("路径")# ----------------------------------------------------------------获取并读取FashionMNIST数据集函数,返回小批量train,test
def load_data_fashion_mnist(batch_size):mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())'''上面的mnist_train,mnist_test都是torch.utils.data.Dataset的子类,所以可以使用len()获取数据集的大小训练集和测试集中的每个类别的图像数分别是6000,1000,两个数据集分别有10个类别'''# mnist是torch.utils.data.dataset的子类,因此可以将其传入torch.utils.data.DataLoader来创建一个DataLoader实例来读取数据# 在实践中,数据读取一般是训练的性能瓶颈,特别是模型较简单或者计算硬件性能比较高的时候# DataLoader一个很有用的功能就是允许多进程来加速读取  使用num_works来设置4个进程读取数据if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=num_workers)return train_iter, test_iter# -------------------------------------------------------------查看mnist前10个图像和标签
def check_mnist():mnist_train = torchvision.datasets.FashionMNIST(root='路径',train=True, download=True, transform=transforms.ToTensor())mnist_test = torchvision.datasets.FashionMNIST(root='路径',train=False, download=True, transform=transforms.ToTensor())X, y = [], []for i in range(10):X.append(mnist_train[i][0])  # 循环获取图像张量矩阵y.append(mnist_train[i][1])  # 循环获取图像对应数值标签show_fashion_mnist(X, get_fashion_mnist_labels(y))# feature, label = mnist_train[0]# print(feature.shape, label)  CxHxW# feature对应高和宽均为28像素的图像,因为使用了transforms.ToTensor(),所以每个像素的数值对应于【0.0,1.0】的32位浮点数# C 是通道数,RGB,灰色图像,通道数为1,H,W分别为高,宽# mnist_train[0] 是一个元祖,它包含两部分,图像数据结构和图像标签值,图像的数据结构是1x28x28结构,是一个浮点数矩阵,代表一个图像# -------------------------------------------------------------------------评价模型net在数据集data_iter上的准确率
def evaluate_accuracy(test_iter, net):acc_sum, n, x = 0.0, 0, 0.0for X, y in test_iter:  # 返回一个批量的数据元组迭代对象acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()  # 将net模型的预测y与标签y进行了准确率比较n += y.shape[0]  # 累加获得样本个数x = acc_sum / nreturn x# -------------------------------------------------------------------------训练模型函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):for epochs in range(num_epochs):  # 循环周期train_l_sum, train_acc_sum, n = 0.0, 0.0, 0  # 预先定义 训练损失,训练精度,批量个数for X, y in train_iter:  # 批量更新y_hat = net(X)l = loss(y_hat, y).sum()  # 损失计算# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:  # 权重存在并且权重的梯度存在for param in params:param.grad.data.zero_()l.backward()  # 反向传播# 梯度更新操作if optimizer is None:sgd(params, lr, batch_size)  # 调用sgd进行梯度下降操作else:optimizer.step()  # softmax回归的简洁实现将要用到train_l_sum += l.item()  # 损失累加train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()  # (y_hat.argmax(dim=1) == y)# 取出y_hat每一行中最大的概率索引和y比较,结果为tensor,元素值为0/1n += y.shape[0]  # 计算一个批量中标签的个数test_acc = evaluate_accuracy(test_iter, net)  # 一个循环之后进行测试集的准确度计算print(f'epoch %d,loss %.4f,train_acc %.3f,test_acc %.3f'% (epochs + 1, train_l_sum / n, train_acc_sum / n, test_acc))# x = torch.tensor([[0.1, 0.4, 0.2], [1, 0.06, 0.5]])
# print((x.argmax(dim=1)==torch.tensor([[1,1]])).float())# -------------------------------------------------------------------------x的形状转换功能函数
class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()  # 初始化函数,自动调用forward函数def forward(self, x):  # x shape: (batch,*,*,....)return x.view(x.shape[0], -1)  # 转换成(batch_size,特征数)形状# 这样就方便定义模型
net = torch.nn.Sequential(# FlattenLayer()# torch.nn.Linear(num_inputs,num_outputs)OrderedDict([('flatten', FlattenLayer()),('linear', torch.nn.Linear(2, 3))])
)'''
-------------------------------------------------------------------作图函数
'''def semilogy(x_vals, y_vals, xlabel, ylabel, label, x2_vals=None, y2_vals=None, legend=None):plt.xlabel(xlabel)plt.ylabel(ylabel)plt.semilogy(x_vals, y_vals)  # y轴使用对数尺度if x2_vals and y2_vals:plt.semilogy(x2_vals, y2_vals, linestyle=':')plt.legend(legend)plt.savefig("路径/多项式" + label + "模拟.png")

相关文章:

李沐深度学习-d2lzh_pytorch模块实现

d2lzh_pytorch 模块 import random import torch import matplotlib_inline from matplotlib import pyplot as plt import torchvision import torchvision.transforms as transforms import torchvision.datasets import sys from collections import OrderedDict# --------…...

什么是OSPF?为什么需要OSPF?OSPF基础概念

什么是OSPF? 开放式最短路径优先OSPF(Open Shortest Path First)是IETF组织开发的一个基于链路状态的内部网关协议(Interior Gateway Protocol)。 目前针对IPv4协议使用的是OSPF Version 2(RFC2328&#x…...

Java多线程并发篇----第二十六篇

系列文章目录 文章目录 系列文章目录前言一、什么是 Executors 框架?二、什么是阻塞队列?阻塞队列的实现原理是什么?如何使用阻塞队列来实现生产者-消费者模型?三、什么是 Callable 和 Future?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分…...

list下

文章目录 注意:const迭代器怎么写?运用场合? inserterase析构函数赋值和拷贝构造区别?拷贝构造不能写那个swap,为什么?拷贝构造代码 面试问题什么是迭代器失效?vector、list的区别? 完整代码 注…...

【Linux】进程间通信——system V 共享内存、消息队列、信号量

需要云服务器等云产品来学习Linux的同学可以移步/–>腾讯云<–/官网&#xff0c;轻量型云服务器低至112元/年&#xff0c;优惠多多。&#xff08;联系我有折扣哦&#xff09; 文章目录 写在前面1. 共享内存1.1 共享内存的概念1.2 共享内存的原理1.3 共享内存的使用1.3.1 …...

网络卡问题排查手段

问题 对后端来说&#xff0c;网络卡了问题&#xff0c;本身很难去排查&#xff0c;因为是 App 通过互联网连接服务 总结下&#xff0c;以往经验&#xff0c;网络卡&#xff0c;通常会有以下情况造成&#xff1a; 某地区网络问题某地区某运营商问题后端服务超载前端网络模块 …...

20240119-子数组最小值之和

题目要求 给定一个整数数组 arr&#xff0c;求 min(b) 的总和&#xff0c;其中 b 的范围涵盖 arr 的每个&#xff08;连续&#xff09;子数组。由于答案可能很大&#xff0c;因此返回答案模数 Example 1: Input: arr [3,1,2,4] Output: 17 Explanation: Subarrays are [3]…...

c# 释放所有嵌入资源, 到某个本地文件夹

版本号 .net 8 代码 using System.Reflection;namespace Demo;internal class Program {static void Main(string[] args){// 获取当前 执行exe 的目录 / 当前命令行所在的目录 var currentDir Directory.GetCurrentDirectory();Console.WriteLine(currentDir);Extract…...

Unity SnapScrollRect 滚动 匹配 列表 整页

展示效果 原理: 当停止滑动时 判断Contet的horizontalNormalizedPosition 与子Item的缓存值 相减,并得到最小值&#xff0c;然后将Content horizontalNormalizedPosition滚动过去 使用方式&#xff1a; 直接将脚本挂到ScrollRect上 注意&#xff1a;在创建Content子物体时…...

网络命令ping和telnet

1. 请解释ping和telnet的工作原理。 ping和telnet是两种常用的网络工具&#xff0c;其工作原理分别如下&#xff1a; ping&#xff1a; 目的&#xff1a;ping主要用于检查网络是否通畅以及测量网络连接速度。工作原理&#xff1a;ping是基于ICMP&#xff08;Internet Control …...

ros2学习笔记-CLI工具,记录命令对应操作。

目录 环境变量turtlesim和rqt以初始状态打开rqt node启动节点查看节点列表查看节点更多信息命令行参数 --ros-args topic话题列表话题类型话题列表&#xff0c;附加话题类型根据类型查找话题名查看话题发布的数据查看话题的详细信息查看类型的详细信息给话题发布消息&#xff0…...

自然语言处理的发展

自然语言处理的发展大致经历了四个阶段&#xff1a;萌芽期、快速发展期、低谷的发展期和复苏融合期。 萌芽期&#xff08;1956年以前&#xff09;&#xff1a;这个阶段可以看作自然语言处理的基础研究阶段。人类文明经过了几千年的发展&#xff0c;积累了大量的数学、语言学和…...

flink operator 拉取阿里云私有镜像(其他私有类似)

创建 k8s secret kubectl --namespace flink create secret docker-registry aliyun-docker-registry --docker-serverregistry.cn-shenzhen.aliyuncs.com --docker-usernameops_acr1060896234 --docker-passwordpasswd --docker-emailDOCKER_EMAIL注意命名空间指定你使用的 我…...

C语言算法赛——蓝桥杯(省赛试题)

一、十四届C/C程序设计C组试题 十四届程序C组试题A#include <stdio.h> int main() {long long sum 0;int n 20230408;int i 0;// 累加从1到n的所有整数for (i 1; i < n; i){sum i;}// 输出结果printf("%lld\n", sum);return 0; }//十四届程序C组试题B…...

【文本到上下文 #2】:NLP 的数据预处理步骤

一、说明 欢迎阅读此文&#xff0c;NLP 爱好者&#xff01;当我们继续探索自然语言处理 (NLP) 的广阔前景时&#xff0c;我们已经在最初的博客中探讨了它的历史、应用和挑战。今天&#xff0c;我们更深入地探讨 NLP 的核心——数据预处理的复杂世界。 这篇文章是我们的“完整 N…...

Minio文件分片上传实现

资源准备 MacM1Pro 安装Parallels19.1.0请参考 https://blog.csdn.net/qq_41594280/article/details/135420241 MacM1Pro Parallels安装CentOS7.9请参考 https://blog.csdn.net/qq_41594280/article/details/135420461 部署Minio和整合SpringBoot请参考 https://blog.csdn.net/…...

C语言总结十一:自定义类型:结构体、枚举、联合(共用体)

本篇博客详细介绍C语言最后的三种自定义类型&#xff0c;它们分别有着各自的特点和应用场景&#xff0c;重点在于理解这三种自定义类型的声明方式和使用&#xff0c;以及各自的特点&#xff0c;最后重点掌握该章节常考的考点&#xff0c;如&#xff1a;结构体内存对齐问题&…...

解决Spring Boot应用打包后文件访问问题

在Spring Boot项目的开发过程中&#xff0c;一个常见的挑战是如何有效地访问和操作资源文件。这一挑战尤其显著当应用从IDE环境&#xff08;如IntelliJ IDEA&#xff09;迁移到被打包成JAR文件后的生产环境。开发者经常遇到的问题是&#xff0c;在IDE中运行正常的代码&#xff…...

循环神经网络的变体模型-LSTM、GRU

一.LSTM&#xff08;长短时记忆网络&#xff09; 1.1基本介绍 长短时记忆网络&#xff08;Long Short-Term Memory&#xff0c;LSTM&#xff09;是一种深度学习模型&#xff0c;属于循环神经网络&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;的一种变体。…...

视频图像的color range简介

介绍 研究FFmpeg发现&#xff0c;在avcodec.h中有关于color的解释&#xff0c;主要有四个属性&#xff0c;primaries、transfer、space和range。 color primaries&#xff1a; 基于RGB空间对应的绝对颜色XYZ的变换&#xff0c;决定了最终三原色RGB分别是什么颜色&#xff1b;…...

day52 ResNet18 CBAM

在深度学习的旅程中&#xff0c;我们不断探索如何提升模型的性能。今天&#xff0c;我将分享我在 ResNet18 模型中插入 CBAM&#xff08;Convolutional Block Attention Module&#xff09;模块&#xff0c;并采用分阶段微调策略的实践过程。通过这个过程&#xff0c;我不仅提升…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接&#xff1a;3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯&#xff0c;要想要能够将所有的电脑解锁&#x…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

【git】把本地更改提交远程新分支feature_g

创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...

Caliper 配置文件解析:config.yaml

Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具&#xff0c;可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板&#xff0c;允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板&#xff0c;并通…...

AI,如何重构理解、匹配与决策?

AI 时代&#xff0c;我们如何理解消费&#xff1f; 作者&#xff5c;王彬 封面&#xff5c;Unplash 人们通过信息理解世界。 曾几何时&#xff0c;PC 与移动互联网重塑了人们的购物路径&#xff1a;信息变得唾手可得&#xff0c;商品决策变得高度依赖内容。 但 AI 时代的来…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...