当前位置: 首页 > article >正文

【深度学习-pytorch篇】1. Pytorch矩阵操作与DataSet创建

Pytorch矩阵操作与DataSet创建

1. Python 环境配置

1.1 安装 Anaconda

推荐使用 Anaconda 来管理 Python 环境,访问官网下载安装:

https://www.anaconda.com/download/success

1.2 安装 PyTorch

请根据自己的系统平台(Windows/Linux/macOS)和 CUDA 版本访问以下链接获取安装命令:

https://pytorch.org/get-started/locally/

检查 CUDA 工具链是否安装成功:

!nvcc --version

测试 GPU 加速是否可用:

import torch
print(torch.cuda.is_available())  # True 表示当前设备支持 GPU 运算

2. 矩阵操作复习

2.1 创建随机整数矩阵

import torch
import numpy as np# 使用 PyTorch 生成 5x5 的随机整数矩阵,取值范围在 [0, 5)
# torch.randint(lower_bound, upper_bound, size)
A = torch.randint(0, 5, (5, 5))
print("Torch 矩阵 A:\n", A)# 使用 NumPy 生成相同尺寸的随机整数矩阵
# np.random.randint(lower_bound, upper_bound, size)
A_np = np.random.randint(0, 5, (5, 5))
print("NumPy 矩阵 A_np:\n", A_np)# 访问第 4 行第 3 列元素(从0开始)
print("A[3][2] =", A[3][2])

2.2 基本矩阵运算

# 构造两个 5x6 的整数矩阵
A = torch.randint(0, 5, (5, 6))
B = torch.randint(0, 5, (5, 6))A_np = np.random.randint(0, 5, (5, 6))
B_np = np.random.randint(0, 5, (5, 6))
加法:C = A + B
C = A + B  # 按元素相加
print("Torch 加法:\n", C)C_np = A_np + B_np
print("NumPy 加法:\n", C_np)
减法:C = A - B
C = A - B  # 按元素相减
print("Torch 减法:\n", C)C_np = A_np - B_np
print("NumPy 减法:\n", C_np)
元素级乘法(Hadamard积)
C = A * B  # 对应元素相乘
print("Torch 元素级乘法:\n", C)C_np = A_np * B_np
print("NumPy 元素级乘法:\n", C_np)
矩阵乘法(内积)
# 对 B 进行转置后乘法,得到结果为 [5, 5]
C = torch.matmul(A, B.T)
print("Torch 矩阵乘法 matmul:\n", C)C = A @ B.T  # 简写形式
print("Torch 矩阵乘法 @:\n", C)C = A.mm(B.T)  # 旧API
print("Torch 矩阵乘法 mm:\n", C)C_np = np.matmul(A_np, B_np.T)
print("NumPy 矩阵乘法:\n", C_np)

3. 生成二维高斯分布数据集

import matplotlib.pyplot as plttorch.manual_seed(42)  # 设置随机种子以确保结果一致num_samples = 500  # 每类样本数量
dim = 2  # 每个样本的维度# 类别 0:中心在 (-2, -2),协方差为单位阵
mean_0 = torch.tensor([-2.0, -2.0])
cov_0 = torch.eye(dim)# 类别 1:中心在 (2, 2),协方差为单位阵
mean_1 = torch.tensor([2.0, 2.0])
cov_1 = torch.eye(dim)# 从多元高斯分布中采样
class_0 = torch.distributions.MultivariateNormal(mean_0, cov_0).sample((num_samples,))
class_1 = torch.distributions.MultivariateNormal(mean_1, cov_1).sample((num_samples,))# 创建标签
labels_0 = torch.zeros(num_samples, dtype=torch.long)
labels_1 = torch.ones(num_samples, dtype=torch.long)# 合并数据与标签
X = torch.cat([class_0, class_1], dim=0)
y = torch.cat([labels_0, labels_1], dim=0)# 打乱样本顺序
indices = torch.randperm(X.size(0))
X, y = X[indices], y[indices]# 可视化
plt.figure(figsize=(6, 6))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm", edgecolors="k", alpha=0.7)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("二维高斯分布数据集")
plt.show()
  • torch.eye用法

    import torchI = torch.eye(3)
    print(I)
    
    tensor([[1., 0., 0.],[0., 1., 0.],[0., 0., 1.]])
    

4. 自定义 Dataset 类

from torch.utils.data import Datasetclass GaussianDataset(Dataset):def __init__(self, X, y):"""初始化数据集参数:- X (Tensor): 特征数据,形状为 [样本数, 特征维度]- y (Tensor): 标签数据,形状为 [样本数]"""self.X = Xself.y = ydef __len__(self):"""返回数据集总长度"""return len(self.X)def __getitem__(self, index):"""获取第 index 个样本参数:- index (int): 样本索引返回:- tuple: (样本特征, 样本标签)"""return self.X[index], self.y[index]

5. 使用 DataLoader 批次读取数据

from torch.utils.data import DataLoader# 实例化数据集
dataset = GaussianDataset(X, y)# 创建 Dataloader:按 batch_size 划分,每轮训练前打乱顺序
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

6. 查看与可视化数据批次

6.1 打印一个批次的数据

for batch_idx, batch in enumerate(dataloader):data, label = batchprint("Batch 索引:", batch_idx)print("数据形状:", data.shape)print("标签形状:", label.shape)print("第一个样本:", data[0])print("对应标签:", label[0])break  # 只看一个批次

6.2 可视化多个批次

for batch_idx, batch in enumerate(dataloader):data, label = batchplt.figure(figsize=(6, 6))plt.scatter(data[:, 0], data[:, 1], c=label, cmap="coolwarm", edgecolors="k", alpha=0.7)plt.xlabel("Feature 1")plt.ylabel("Feature 2")plt.title(f"第 {batch_idx+1} 个批次")plt.show()if batch_idx == 2:  # 展示前 3 个批次即可break

7. 总结函数功能与参数

7.1 torch.randint

  • 功能:生成随机整数张量
  • 参数:
    • low:最小值(含)
    • high:最大值(不含)
    • size:输出张量形状

7.2 torch.distributions.MultivariateNormal

  • 功能:定义多元正态分布并采样
  • 参数:
    • mean:均值向量
    • covariance_matrix:协方差矩阵
  • 方法:
    • .sample((n,)):采样 n 个样本

7.3 DataLoader

  • 功能:按批次加载数据,支持打乱与多线程
  • 参数:
    • dataset:继承自 Dataset 的对象
    • batch_size:每个批次的样本数量
    • shuffle:是否打乱数据顺序

7.4 plt.scatter

  • 功能:绘制二维散点图
  • 参数:
    • x, y:二维数据坐标
    • c:颜色标签
    • cmap:颜色映射
    • edgecolors:边框颜色
    • alpha:透明度

8. 矩阵乘法、Hadamard 积、点乘算法解释

8.1 点乘(Dot Product)

定义
点乘是两个向量之间的操作,将两个一维向量压缩成一个标量(数值)。

设有两个长度为 n n n 的向量:
a = [ a 1 , a 2 , … , a n ] , b = [ b 1 , b 2 , … , b n ] \mathbf{a} = [a_1, a_2, \dots, a_n], \quad \mathbf{b} = [b_1, b_2, \dots, b_n] a=[a1,a2,,an],b=[b1,b2,,bn]

它们的点积(dot product)定义为:
a ⋅ b = ∑ i = 1 n a i ⋅ b i \mathbf{a} \cdot \mathbf{b} = \sum_{i=1}^{n} a_i \cdot b_i ab=i=1naibi

举例

import torcha = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])dot = torch.dot(a, b)
print(dot)  # 输出:32.0 = 1×4 + 2×5 + 3×6

8.2 矩阵乘法(Matrix Multiplication)

定义
矩阵乘法是两个二维矩阵之间的乘积,满足以下维度条件:

A ∈ R m × n A \in \mathbb{R}^{m \times n} ARm×n B ∈ R n × k B \in \mathbb{R}^{n \times k} BRn×k
则它们的乘积为 C ∈ R m × k C \in \mathbb{R}^{m \times k} CRm×k,其计算方式为:

C i , j = ∑ l = 1 n A i , l ⋅ B l , j C_{i,j} = \sum_{l=1}^{n} A_{i,l} \cdot B_{l,j} Ci,j=l=1nAi,lBl,j

即:第 i i i 行 × 第 j j j 列。

算法解释

# 手动实现矩阵乘法(简化版本)
import torchA = torch.tensor([[1., 2.],[3., 4.]])  # 2x2B = torch.tensor([[5., 6.],[7., 8.]])  # 2x2# 目标结果矩阵
C = torch.zeros(2, 2)for i in range(2):         # 遍历 A 的行for j in range(2):     # 遍历 B 的列for k in range(2): # 对应维度求和C[i][j] += A[i][k] * B[k][j]print(C)
# 输出:
# tensor([[19., 22.], 19=5+14 22 = 6+16
#         [43., 50.]]) 43 = 15+28  50= 18+32

等效 PyTorch 写法

C = torch.matmul(A, B)

8.3 区别总结

对比项点乘(dot)矩阵乘法(matmul)
操作对象一维向量 a , b \mathbf{a}, \mathbf{b} a,b二维矩阵 A , B A, B A,B
输出标量(单个数字)矩阵(新的二维张量)
函数torch.dot(a, b)torch.matmul(A, B) or A @ B
应用场景计算相似度、投影、内积神经网络权重变换、特征转换、线性层

# 举例说明区别
a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])
print("点乘结果:", torch.dot(a, b))  # 输出:32.0A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print("矩阵乘法结果:\n", torch.matmul(A, B))
8.4. Hadamard 积

若有两个相同形状的矩阵:

A = [ a 11 a 12 a 21 a 22 ] , B = [ b 11 b 12 b 21 b 22 ] A = \begin{bmatrix} a_{11} & a_{12} \\\\ a_{21} & a_{22} \end{bmatrix}, \quad B = \begin{bmatrix} b_{11} & b_{12} \\\\ b_{21} & b_{22} \end{bmatrix} A= a11a21a12a22 ,B= b11b21b12b22

它们的 Hadamard 积 定义为:

C = A ⊙ B = [ a 11 ⋅ b 11 a 12 ⋅ b 12 a 21 ⋅ b 21 a 22 ⋅ b 22 ] C = A \odot B = \begin{bmatrix} a_{11} \cdot b_{11} & a_{12} \cdot b_{12} \\\\ a_{21} \cdot b_{21} & a_{22} \cdot b_{22} \end{bmatrix} C=AB= a11b11a21b21a12b12a22b22

其中 ⊙ \odot 表示逐元素相乘,即对应位置元素分别相乘,不进行矩阵乘法的求和操作。

import torch
A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[10., 20.], [30., 40.]])C = A * B  # Hadamard 积
print(C)
tensor([[10., 40.],[90., 160.]])

相关文章:

【深度学习-pytorch篇】1. Pytorch矩阵操作与DataSet创建

Pytorch矩阵操作与DataSet创建 1. Python 环境配置 1.1 安装 Anaconda 推荐使用 Anaconda 来管理 Python 环境,访问官网下载安装: https://www.anaconda.com/download/success 1.2 安装 PyTorch 请根据自己的系统平台(Windows/Linux/ma…...

游戏引擎学习第310天:利用网格划分完成排序加速优化

回顾并为今天的内容做个铺垫 昨天我们完成了一个用于排序的空间划分系统,但还没有机会真正利用它。昨天的工作刚好在结束时才完成,所以今天我们打算正式使用这个空间划分来加速排序。 现在我们在渲染代码中,可以看到在代码底部隐藏着一个“…...

数据结构 - 树的遍历

一、二叉树的遍历 对于二叉树,常用的遍历方式包括:先序遍历、中序遍历、后序遍历和层次遍历 。 1、先序遍历(PreOrder) 先序遍历的操作过程如下: 若二叉树为空,则什么也不做;否则&#xff0…...

时序模型介绍

一.整体介绍 1.单变量 vs 多变量时序数据 单变量就是只根据时间预测,多变量还要考虑用户 2.为什么不能用机器学习预测: a.时间不是影响标签的关键因素 b.时间与标签之间的联系过于弱/过于复杂,因此时序模型依赖于时间与时间的相关性来进行预…...

Java面试实战:从Spring到大数据的全栈挑战

Java面试实战:从Spring到大数据的全栈挑战 在某家知名互联网大厂,严肃的面试官正在面试一位名叫谢飞机的程序员。谢飞机以其搞笑的回答和对Java技术栈的独特见解而闻名。 第一轮:Spring与微服务的探索 面试官:“请你谈谈Spring…...

解决idea与springboot版本问题

遇到以下问题: 1、springboot3.2.0与jdk1.8 提示这个包org.springframework.web.bind.annotation不存在,但是pom已经引入了spring-boot-starter-web 2、Error:Cannot determine path to tools.jar library for 17 (D:/jdk17) 3、Error:(3, 28) java: …...

【第4章 图像与视频】4.4 离屏 canvas

文章目录 前言为什么要使用 offscreenCanvas为什么要使用 OffscreenCanvas如何使用 OffscreenCanvas第一种使用方式第二种使用方式 计算时长超过多长时间适合用Web Worker 前言 在 Canvas 开发中,我们经常需要处理复杂的图形和动画,这些操作可能会影响页…...

[AXI]如何验证AXI5原子操作

如何验证 AXI5 原子操作 摘要:在 UVM (Universal Verification Methodology) 验证环境中,验证 AXI5 协议的原子操作 (Atomic Operations) 是一项重要的任务,特别是在验证支持高并发和数据一致性的 SoC (System on Chip) 设计时。AXI5 引入了原…...

尚硅谷redis7 74-85 redis集群分片之集群是什么

74 redis集群分片之集群是什么 如果主机宕机,那么写操作就被暂时中断,后面就要由哨兵进行投票和选举。那么一瞬间若有大量的数据修改,由于写操作中断就会导致数据流失。 由于数据量过大,单个Master复制集难以承担,因此需要对多个复制集进行…...

Android获取设备信息

使用java: List<TableMessage> dataListnew ArrayList<TableMessage>();//获取设备信息Hashtable<String,String> ht MyDeviceInfo.getDeviceAllInfo2(LoginActivity.this);for (Map.Entry<String, String> entry : ht.entrySet()) {String key entry…...

WPF的基础控件:布局控件(StackPanel DockPanel)

布局控件&#xff08;StackPanel & DockPanel&#xff09; 1 StackPanel的Orientation属性2 DockPanel的LastChildFill3 嵌套布局示例4 性能优化建议5 常见问题排查 在WPF开发中&#xff0c;布局控件是构建用户界面的基石。StackPanel和DockPanel作为两种最基础的布局容器&…...

apache的commons-pool2原理与使用详解

Apache Commons Pool2 是一个高效的对象池化框架&#xff0c;通过复用昂贵资源&#xff08;如数据库连接、线程、网络连接&#xff09;优化系统性能。 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击…...

打印Yolo预训练模型的所有类别及对应的id

有时候我们可能只需要用yolo模型检测个别类别&#xff0c;并显示&#xff0c;这就需要知道id&#xff0c;以下代码可打印出 from ultralytics import YOLO# 加载模型 model YOLO(yolo11x.pt)# 打印所有类别名称及其对应的ID print(model.names) {0: person, 1: bicycle, 2: c…...

语法糖介绍(C++ Python)

语法糖&#xff08;Syntactic Sugar&#xff09;是编程语言中为了提升代码可读性和简洁性而设计的语法结构。它不改变语言的功能&#xff0c;但能让代码更易写和理解。以下是 C 和 Python 中常见的语法糖示例&#xff1a; C 中的常见语法糖 范围 for 循环&#xff08;Range-bas…...

事务详解及面试常考知识点整理

事务详解及面试常考知识点整理 1. 什么是事务&#xff1f; **事务&#xff08;Transaction&#xff09;**是将多条 SQL 语句打包执行的操作单元&#xff0c;具有“一气呵成”的特性。就好比你要完成“把大象放进冰箱”这件事&#xff0c;一共分三步&#xff1a; 打开冰箱门把…...

设计模式26——解释器模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用&#xff0c;主要是下面的UML图可以起到大作用&#xff0c;在你学习过一遍以后可能会遗忘&#xff0c;忘记了不要紧&#xff0c;只要看一眼UML图就能想起来了。同时也请大家多多指教。 解释器模式&#xff08;Interp…...

在MDK中自动部署LVGL,在stm32f407ZGT6移植LVGL-8.3,运行demo,显示label

在MDK中自动部署LVGL&#xff0c;在stm32f407ZGT6移植LVGL-8.3 一、硬件平台二、实现功能三、移植步骤1、下载LVGL-8.42、MDK中安装LVGL-8.43、配置RTE4、配置头文件 lv_conf_cmsis.h5、配置lv_port_disp_template 四、添加心跳相关文件1、在STM32CubeMX中配置TIM7的参数2、使能…...

ArcGIS 与 HEC-RAS 协同:流域水文分析与洪水模拟全流程

技术点目录 洪水淹没危险性评价方法及技术介绍基于ArcGIS的水文分析基于HecRAS淹没模拟的洪水危险性评价洪水风险评价综合案例分析应用了解更多 —————————————————————————————————————————————————— 前言综述 洪水危险性及…...

树莓派设置静态ip 永久有效 我的需要设置三个 一个摄像头的 两个设备的

通过 systemd-networkd 配置 此方法适用于较新的Raspberry Pi OS版本&#xff0c;支持同时绑定多个IP地址到同一网卡&#xff0c;且配置清晰稳定。 1.禁用DHCP客户端对eth0的管理:编辑/etc/dhcpcd.conf文件&#xff0c;添加以下内容以忽略eth0接口的自动分配 sudo nano /etc…...

多模态大语言模型arxiv论文略读(九十九)

PartGLEE: A Foundation Model for Recognizing and Parsing Any Objects ➡️ 论文标题&#xff1a;PartGLEE: A Foundation Model for Recognizing and Parsing Any Objects ➡️ 论文作者&#xff1a;Junyi Li, Junfeng Wu, Weizhi Zhao, Song Bai, Xiang Bai ➡️ 研究机构…...

Fine-tuning:微调技术,训练方式,LLaMA-Factory,ms-swift

1&#xff0c;微调技术 特征Full-tuningFreeze-tuningLoRAQLoRA训练参数量全部少量极少极少显存需求高低很低最低模型性能最佳中等较好接近 LoRA模型修改方式无变化局部冻结插入模块量化插入模块多任务共享不便较便非常适合非常适合适合超大模型微调❌✅✅✅&#xff08;最优&…...

vscode连接的linux服务器,上传项目至github

问题 已将项目整个文件夹拷贝到克隆下来的文件夹中&#xff0c;并添加了所有文件&#xff0c;并修改了commit -m&#xff0c;使用git push -u origin main提交的时候会出现vscode请求登录github&#xff0c;确定之后需要等待很久&#xff0c;也无果 原因 由于 远程服务器无法…...

XCTF-web-mfw

发现了git 使用GitHack下载一下源文件&#xff0c;找到了php源代码 <?phpif (isset($_GET[page])) {$page $_GET[page]; } else {$page "home"; }$file "templates/" . $page . ".php";// I heard .. is dangerous! assert("strpos…...

indel_snp_ssr_primer

indel标记使用 1.得到vcf文件 2.提取指定区域vcf文件并压缩构建索引 bcftools view -r <CHROM>:<START>-<END> input.vcf -o output.vcf bgzip -c all.filtered.indel.vcf > all.filtered.indel.vcf.gz tabix -p vcf all.filtered.indel.vcf.gz3.准备参…...

图论核心:深度搜索DFS 与广度搜索BFS

一、深度优先搜索&#xff08;DFS&#xff09;&#xff1a;一条路走到黑的探索哲学 1. 算法核心思想 DFS&#xff08;Depth-First Search&#xff09;遵循 “深度优先” 原则&#xff0c;从起始节点出发&#xff0c;尽可能深入地访问每个分支&#xff0c;直到无法继续时回溯&a…...

Java 调用 HTTP 和 HTTPS 的方式详解

文章目录 1. HTTP 和 HTTPS 基础知识1.1 什么是 HTTP/HTTPS&#xff1f;1.2 HTTP 请求与响应结构1.3 常见的 HTTP 方法1.4 常见的 HTTP 状态码 2. Java 原生 HTTP 客户端2.1 使用 URLConnection 和 HttpURLConnection2.1.1 基本 GET 请求2.1.2 基本 POST 请求2.1.3 处理 HTTPS …...

Redis--基础知识点--28--慢查询相关

1 慢查询的原因 1.1 非命令数据相关原因 1.1.1 网络延迟 原因&#xff1a;客户端与 Redis 服务器之间的网络延迟可能导致客户端感知到的响应时间变长。 解决方案&#xff1a;优化网络环境 排查&#xff1a; 1.1.2 CPU 竞争 原因&#xff1a;Redis 是单线程的&#xff0c…...

目标检测:YOLO 模型详解

目录 一、YOLO&#xff08;You Only Look Once&#xff09;模型讲解 YOLOv1 YOLOv2 (YOLO9000) YOLOv3 YOLOv4 YOLOv5 YOLOv6 YOLOv7 YOLOv8 YOLOv9 YOLOv10 YOLOv11 YOLOv12 其他变体&#xff1a;PP-YOLO 二、YOLO 模型的 Backbone&#xff1a;Focus 结构 三、…...

HDFS存储原理与MapReduce计算模型

HDFS存储原理 1. 架构设计 主从架构&#xff1a;包含一个NameNode&#xff08;主节点&#xff09;和多个DataNode&#xff08;从节点&#xff09;。 NameNode&#xff1a;管理元数据&#xff08;文件目录结构、文件块映射、块位置信息&#xff09;&#xff0c;不存储实际数据…...

电机控制选 STM32 还是 DSP?技术选型背后的现实博弈

现在搞电机控制&#xff0c;圈里人都门儿清 —— 主流方案早就被 STM32 这些 Cortex-M 单片机给拿捏了。可要是撞上系统里的老甲方&#xff0c;技术认知还停留在诺基亚砸核桃的年代&#xff0c;非揪着 DSP 不放&#xff0c;咱也只能赔笑脸&#xff1a;“您老说的对&#xff0c;…...