当前位置: 首页 > news >正文

如何指定多块GPU卡进行训练-数据并行

训练代码:

train.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F# 假设我们有一个简单的文本数据集
class TextDataset(Dataset):def __init__(self, texts, labels, vocab):self.texts = textsself.labels = labelsself.vocab = vocabdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]# 将文本转换为索引text_indices = [self.vocab.get(word, self.vocab['<UNK>']) for word in text.split()]return torch.tensor(text_indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)# 定义一个简单的LSTM分类器
class LSTMClassifier(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):super(LSTMClassifier, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):embedded = self.embedding(x)_, (hidden, _) = self.lstm(embedded)output = self.fc(hidden[-1])return output# 构建词汇表
vocab = {'<PAD>': 0, '<UNK>': 1, 'I': 2, 'love': 3, 'this': 4, 'movie': 5, 'is': 6, 'terrible': 7}
vocab_size = len(vocab)# 示例数据
texts = ["I love this movie", "This movie is terrible"]
labels = [1, 0]  # 1表示正面情感,0表示负面情感# 创建数据集和数据加载器
dataset = TextDataset(texts, labels, vocab)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: (torch.nn.utils.rnn.pad_sequence([item[0] for item in x], batch_first=True), torch.stack([item[1] for item in x])))# 实例化模型
embedding_dim = 50
hidden_dim = 50
output_dim = 2
model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)# 使用DataParallel包装模型
model = nn.DataParallel(model)# 将模型移动到GPU
model = model.cuda()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练步骤
model.train()
for epoch in range(10):  # 训练10个epochfor inputs, labels in dataloader:inputs, labels = inputs.cuda(), labels.cuda()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f"Epoch {epoch+1}, Loss: {loss.item()}")print("训练完成")# 测试模型
model.eval()
test_texts = ["I love this movie", "This movie is terrible"]
test_dataset = TextDataset(test_texts, [1, 0], vocab)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=lambda x: (torch.nn.utils.rnn.pad_sequence([item[0] for item in x], batch_first=True), torch.stack([item[1] for item in x])))with torch.no_grad():for inputs, labels in test_dataloader:inputs, labels = inputs.cuda(), labels.cuda()outputs = model(inputs)predictions = torch.argmax(F.softmax(outputs, dim=1), dim=1)print(f"Predictions: {predictions.cpu().numpy()}, Labels: {labels.cpu().numpy()}")

执行命令:

  • export CUDA_VISIBLE_DEVICES=0,2
  • python train.py

GPU监控

训练前
在这里插入图片描述
训练中
在这里插入图片描述
Epoch 1, Loss: 0.7198400497436523
Epoch 2, Loss: 0.6889444589614868
Epoch 3, Loss: 0.6591541767120361
Epoch 4, Loss: 0.630306601524353
Epoch 5, Loss: 0.6022476553916931
Epoch 6, Loss: 0.5748419761657715
Epoch 7, Loss: 0.5479871034622192
Epoch 8, Loss: 0.5216072201728821
Epoch 9, Loss: 0.4956483840942383
Epoch 10, Loss: 0.47007784247398376
训练完成
Predictions: [1 0], Labels: [1 0]

结论

export CUDA_VISIBLE_DEVICES=0,2与nn.DataParallel(model)结合的方法是正确的

为什么需要指定 CUDA_VISIBLE_DEVICES

  • 在多GPU系统中,默认情况下,PyTorch 会尝试使用所有可用的GPU进行训练。
  • 通过设置 CUDA_VISIBLE_DEVICES 环境变量,用于控制哪些GPU对当前进程可见,PyTorch 只会使用这些可见的GPU进行训练。
  • 通过设置环境变量,你可以在不修改代码的情况下控制使用的GPU。这使得代码更加简洁和通用,不需要在代码中硬编码GPU的选择逻辑。
    总的来说:通过设置 CUDA_VISIBLE_DEVICES 环境变量,你可以灵活地控制哪些GPU对当前进程可见,从而避免资源冲突、简化代码并更好地管理多GPU资源。这是使用 torch.nn.DataParallel 进行多GPU训练时的一种常见做法。

nn.DataParallel原理是什么

nn.DataParallel 是 PyTorch 中用于多 GPU 并行计算的一个模块。它的主要原理是将输入数据分割成多个子集,并将这些子集分配到不同的 GPU 上进行并行计算。具体来说,nn.DataParallel 的工作流程如下:

  • 模型复制:首先,nn.DataParallel 会将模型复制到每个 GPU 上。这意味着每个 GPU 都会有一份完整的模型副本。
  • 数据分割:输入数据会被分割成多个子集,每个子集会被分配到一个 GPU 上。通常,这个分割是按批次(batch)维度进行的。
  • 并行计算:每个 GPU 使用其本地的模型副本对分配到的子集进行前向传播和后向传播计算。
  • 梯度汇总:在所有 GPU 上完成计算后,nn.DataParallel 会将每个 GPU 计算得到的梯度汇总到主 GPU 上(通常是 GPU 0)。
  • 参数更新:主 GPU 汇总梯度后,使用这些梯度更新模型参数。更新后的参数会同步到所有 GPU 上的模型副本。

相关文章:

如何指定多块GPU卡进行训练-数据并行

训练代码&#xff1a; train.py import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset import torch.nn.functional as F# 假设我们有一个简单的文本数据集 class TextDataset(Dataset):def __init__(self, te…...

RK3568笔记三十三: helloworld 驱动测试

若该文为原创文章&#xff0c;转载请注明原文出处。 报着学习态度&#xff0c;接下来学习驱动是如何使用的&#xff0c;从简单的helloworld驱动学习起。 开始编写第一个驱动程序—helloworld 驱动。 一、环境 1、开发板&#xff1a;正点原子的ATK-DLRK3568 2、系统&#xf…...

【智能制造-14】机器视觉软件

CCD相机和COMS相机? CCD&#xff08;Charge-Coupled Device&#xff09;相机和CMOS&#xff08;Complementary Metal-Oxide-Semiconductor&#xff09;相机是两种常见的数字图像传感器技术&#xff0c;用于捕捉和处理图像。 CCD相机&#xff1a; CCD相机使用一种称为CCD的光电…...

MVC分页

public ActionResult Index(int ? page){IPagedList<EF.ACCOUNT> userPagedList;using (EF.eMISENT content new EF.eMISENT()){第几页int pageNumber page ?? 1;每页数据条数&#xff0c;这个可以放在配置文件中int pageSize 10;//var infoslist.C660List.OrderBy(…...

webGL可用的14种3D文件格式,但要具体问题具体分析。

hello&#xff0c;我威斯数据&#xff0c;你在网上看到的各种炫酷的3d交互效果&#xff0c;背后都必须有三维文件支撑&#xff0c;就好比你网页的时候&#xff0c;得有设计稿源文件一样。WebGL是一种基于OpenGL ES 2.0标准的3D图形库&#xff0c;可以在网页上实现硬件加速的3D图…...

HybridCLR原理中的重点总结

序言 该文章以一个新手的身份&#xff0c;讲一下自己学习的经过&#xff0c;大家更快的学习HrbirdCLR。 我之前的两个Unity项目中&#xff0c;都使用到了热更新功能&#xff0c;而热更新的技术栈都是用的HybridCLR。 第一个项目本身虽然已经集成好了热更逻辑&#xff08;使用…...

昇思学习打卡-14-ResNet50迁移学习

文章目录 数据集可视化预训练模型的使用部分实现 推理 迁移学习&#xff1a;在一个很大的数据集上训练得到一个预训练模型&#xff0c;然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务中。本章学习使用的是前面学过的ResNet50&#xff0c;使用迁移学…...

软件开发面试题C#,.NET知识点(续)

1.C#中的封装是什么&#xff0c;以及它的重要性。 封装&#xff08;Encapsulation&#xff09; 是面向对象编程&#xff08;OOP&#xff09;的一个基本概念。它指的是将对象的状态&#xff08;属性&#xff09;和行为&#xff08;方法&#xff09;绑定在一起&#xff0c;并且将…...

2019年美赛题目Problem A: Game of Ecology

本题分析&#xff1a; 本题想要要求从实际生物角度出发&#xff0c;对权力游戏中龙这种虚拟生物的生态环境和生物特性进行建模&#xff0c;感觉属于比较开放类型的题目&#xff0c;重点在于参考生物的选择&#xff0c;龙虽然是虚拟的但是龙的生态特性可以参考目前生物圈里存在…...

沙龙回顾|MongoDB如何充当企业开发加速器?

数据不仅是企业发展转型的驱动力&#xff0c;也是开发者最棘手的问题。前日&#xff0c;MongoDB携手阿里云、NineData在杭州成功举办了“数据驱动&#xff0c;敏捷前行——MongoDB企业开发加速器”技术沙龙。此次活动吸引了来自各行各业的专业人员&#xff0c;共同探讨MongoDB的…...

云端编码:将您的技术API文档安全存储在iCloud的最佳实践

云端编码&#xff1a;将您的技术API文档安全存储在iCloud的最佳实践 作为一名技术专业人士&#xff0c;管理不断增长的API文档库是一项挑战。iCloud提供了一个无缝的解决方案&#xff0c;允许您在所有设备上存储、同步和访问您的个人技术API文档。本文将指导您如何在iCloud中高…...

在Spring Boot项目中集成单点登录解决方案

在Spring Boot项目中集成单点登录解决方案 大家好&#xff0c;我是微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01; 在现代的企业应用中&#xff0c;单点登录&#xff08;Single Sign-On, SSO&#xff09;解决方案是确保用户…...

Java-常用API

1-Java API &#xff1a; 指的就是 JDK 中提供的各种功能的 Java类。 2-Scanner基本使用 Scanner&#xff1a; 一个简单的文本扫描程序&#xff0c;可以获取基本类型数据和字符串数据 构造方法&#xff1a; Scanner(InputStream source)&#xff1a;创建 Scanner 对象 Sy…...

Python从Excel表中查找指定数据填入新表

#读取xls文件中的数据 import xlrd file "原表.xls" wb xlrd.open_workbook(file) #读取工作簿 ws wb.sheets()[0] #选第一个工作表 data [] for row in range(7, ws.nrows): name ws.cell(row, 1).value.strip() #科室名称 total1 ws.cell(row, 2…...

从零开始实现大语言模型(三):Token Embedding与位置编码

1. 前言 Embedding是深度学习领域一种常用的类别特征数值化方法。在自然语言处理领域&#xff0c;Embedding用于将对自然语言文本做tokenization后得到的tokens映射成实数域上的向量。 本文介绍Embedding的基本原理&#xff0c;将训练大语言模型文本数据对应的tokens转换成Em…...

视频怎么压缩变小?最佳视频压缩器

即使在云存储和廉价硬盘空间时代&#xff0c;大视频文件使用起来仍然不方便。无论是存储、发送到电子邮件帐户还是刻录到 DVD&#xff0c;拥有最好的免费压缩软件可以确保您快速缩小文件大小&#xff0c;而不必担心视频质量下降。继续阅读以探索一些顶级最佳 免费视频压缩器选项…...

LLM - 绝对与相对位置编码 与 RoPE 旋转位置编码 源码

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/140281680 免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。 Transformer 是基于 MHSA (多头自注意力),然而,MHSA 对于位置是不敏感…...

B3917 [语言月赛 202401] 小跳蛙

OK 挠~ stop here~ 好啊&#xff0c;现在呢&#xff0c;把手头的事情先放一放啊&#xff0c;我们来做道练习 OK&#xff1f; 好啊来&#xff1a; 小跳蛙 题目描述 有 &#x1d45b;−1 只小跳蛙在池塘中&#xff0c;依次被编号为 1,2,⋯ ,&#x1d45b;−1。池塘里有 &am…...

Bash ——shell

Bash作为用户与操作系统之间的接口&#xff0c;让用户通过命令行输入各种指令来控制和操作计算机系统。 shell的两种解释&#xff1a; 1.linux命令解释器 Terminal 终端 ——》shell命令 ——》 Linux kernel &#xff08;内核&#xff09; Linux内核的作用&#xff1a; 1.…...

PyTorch复现PointNet——模型训练+可视化测试显示

因为项目涉及到3D点云项目&#xff0c;故学习下PointNet这个用来处理点云的神经网络 论文的话&#xff0c;大致都看了下&#xff0c;网络结构有了一定的了解&#xff0c;本博文主要为了下载调试PointNet网络源码&#xff0c;训练和测试调通而已。 我是在Anaconda下创建一个新的…...

大数据学习栈记——Neo4j的安装与使用

本文介绍图数据库Neofj的安装与使用&#xff0c;操作系统&#xff1a;Ubuntu24.04&#xff0c;Neofj版本&#xff1a;2025.04.0。 Apt安装 Neofj可以进行官网安装&#xff1a;Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...

模型参数、模型存储精度、参数与显存

模型参数量衡量单位 M&#xff1a;百万&#xff08;Million&#xff09; B&#xff1a;十亿&#xff08;Billion&#xff09; 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的&#xff0c;但是一个参数所表示多少字节不一定&#xff0c;需要看这个参数以什么…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架&#xff0c;相比 MapReduce 具有以下核心优势&#xff1a; 内存计算&#xff1a;数据可常驻内存&#xff0c;迭代计算性能提升 10-100 倍&#xff08;文档段落&#xff1a;3-79…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序

一、开发准备 ​​环境搭建​​&#xff1a; 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 ​​项目创建​​&#xff1a; File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...

【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力

引言&#xff1a; 在人工智能快速发展的浪潮中&#xff0c;快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型&#xff08;LLM&#xff09;。该模型代表着该领域的重大突破&#xff0c;通过独特方式融合思考与非思考…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

return this;返回的是谁

一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请&#xff0c;不同级别的经理有不同的审批权限&#xff1a; // 抽象处理者&#xff1a;审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

QT3D学习笔记——圆台、圆锥

类名作用Qt3DWindow3D渲染窗口容器QEntity场景中的实体&#xff08;对象或容器&#xff09;QCamera控制观察视角QPointLight点光源QConeMesh圆锥几何网格QTransform控制实体的位置/旋转/缩放QPhongMaterialPhong光照材质&#xff08;定义颜色、反光等&#xff09;QFirstPersonC…...