当前位置: 首页 > news >正文

Pytorch在cuda、AMD DirectML和AMD CPU下性能比较

一、测试环境

CUDA环境: i7-8550u + 16G DDR4 2133MHz + nVidia MX150 2GB

AMD DirectML环境: Ryzen 5 5600G + 32G DDR4 3200MHz + Vega7 4GB

AMD 纯CPU环境:Ryzen 5 5600G + 32G DDR4 3200MHz 

其他硬件配置的硬盘、电源均一致。Pytorch版本为2.0.0,Python环境为3.7.11,Win10 LTSC。

二、测试代码

拟合一个100万点数的函数,并计算从神经网络被传入内存/显存开始,到计算结果出来,所耗费的时间。不含前面准备时间、出图时间。计算三次手动记录平均值。代码如下:

CUDA测试代码

# -*- coding: utf-8 -*-
# @Time    : 19/12/9 16:38
# @Author  : JL
# @File    : pytorchTest.py
# @Software: PyCharmimport matplotlib.pyplot as plt
import torch
import timex = torch.unsqueeze(torch.linspace(-1, 1, 1000000), dim=1).cuda()
y = x.pow(2) + 0.3 * torch.rand(x.size()).cuda()net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
lossFunc = torch.nn.MSELoss()device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("当前使用的设备是:" + str(torch.cuda.get_device_name(torch.cuda.current_device())))
print("当前CUDA、CUDNN版本号分别为:"+str(torch.version.cuda)+"、"+str(torch.backends.cudnn.version()))
print("当前Pytorch版本号为:"+str(torch.__version__))startTime = time.perf_counter()net1.to(device)for t in range(100):prediction = net1(x)loss = lossFunc(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()print(loss.data.cpu().numpy())endTime = time.perf_counter()
delta = endTime-startTimeprint("Treat a net in %0.2f s." % delta)plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
plt.show()

DirectML、AMD CPU测试代码:

# -*- coding: utf-8 -*-
# @Time    : 19/12/9 16:38
# @Author  : JL
# @File    : pytorchTest.py
# @Software: PyCharmimport matplotlib.pyplot as plt
import torch
import torch_directml
import timedml = torch_directml.device()  # 如果使用DirectML,则分配到dml上
cpuML = torch.device("cpu")  # 如果仅使用CPU,则选择分配到cupML上# 注意修改dml或cpuML
x = torch.unsqueeze(torch.linspace(-1, 1, 1000000), dim=1).to(dml)
y = x.pow(2) + 0.3 * torch.rand(x.size()).to(dml)net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
lossFunc = torch.nn.MSELoss()print("当前Pytorch版本号为:" + str(torch.__version__))startTime = time.perf_counter()net1.to(dml)  # 注意修改dml或cpuMLfor t in range(100):prediction = net1(x)loss = lossFunc(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()print(loss.data.cpu().numpy())endTime = time.perf_counter()
delta = endTime - startTimeprint("Treat a net in %0.2f s." % delta)plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
plt.show()

三、测试结论

测试类型 耗费时间(秒,越小越好)
CUDA3.57
DirectML4.48
纯CPU5.31

看起来DirectML有点加速效果,但是还是和CUDA有差距,更何况这个是笔记本上最弱的MX150显卡的CUDA。微软要加油了。另外AMD的CPU,还是安心打游戏好了。

相关文章:

Pytorch在cuda、AMD DirectML和AMD CPU下性能比较

一、测试环境 CUDA环境: i7-8550u 16G DDR4 2133MHz nVidia MX150 2GB AMD DirectML环境: Ryzen 5 5600G 32G DDR4 3200MHz Vega7 4GB AMD 纯CPU环境:Ryzen 5 5600G 32G DDR4 3200MHz 其他硬件配置的硬盘、电源均一致。Pytorch版本为2.0.0,Pyt…...

哈工大计算机网络课程局域网详解之:交换机概念

哈工大计算机网络课程局域网详解之:交换机概念 文章目录 哈工大计算机网络课程局域网详解之:交换机概念以太网交换机(switch)交换机:多端口间同时传输交换机转发表:交换表交换机:自学习交换机互…...

Jenkins Pipeline的hasProperty函数

函数的作用 用于判断某个参数或者字段是否存在。 用法 例子一 def projectStr "P1,P2,P3" pipeline {agent anyparameters {extendedChoice(defaultValue: "${projectStr}",description: 选择要发布的项目,multiSelectDelimiter: ,,name: SELECT_PROJ…...

芯片制造详解.净洁室的秘密.学习笔记(三)

这是芯片制造系列的第三期跟学up主三圈,这里对其视频内容做了一下整理和归纳,喜欢的可以看原视频。 芯片制造详解03: 洁净室的秘密|为何芯片厂缺人? 芯片制造详解.净洁室的秘密.学习笔记 三 简介一、干净的级别二、芯片…...

可解释的 AI:在transformer中可视化注意力

Visualizing Attention in Transformers | Generative AI (medium.com) 一、说明 在本文中,我们将探讨可视化变压器架构核心区别特征的最流行的工具之一:注意力机制。继续阅读以了解有关BertViz的更多信息,以及如何将此注意力可视化工具整合到…...

k8s Webhook 使用java springboot实现webhook 学习总结

k8s Webhook 使用java springboot实现webhook 学习总结 大纲 基础概念准入控制器(Admission Controllers)ValidatingWebhookConfiguration 与 MutatingWebhookConfiguration准入检查(AdmissionReview)使用Springboot实现k8s-Web…...

JS逆向之猿人学爬虫第20题-wasm

文章目录 题目地址sign参数分析python算法还原往期逆向文章推荐题目地址 https://match.yuanrenxue.cn/match/20第20题被置顶到了第1页,题目难度 写的是中等 算法很简单,就一个标准的md5算法,主要是盐值不确定, 而盐值就在wasm里面,可以说难点就在于wasm分析 sign参数分…...

【双指针优化DP】The 2022 Hangzhou Normal U Summer Trials H

Problem - H - Codeforces 题意: 思路: 首先很明显是DP 因为只有1e6个站点,因此可以以站点作为阶段 注意到K很小,因此可以尝试把这个当作第二维 设dp[i][j]为到达第i个站点,已经花了j元钱的最小步数 然后就想了一…...

[论文笔记] LLM数据集——金融数据集

一、chatglm_金融 ModelScope 魔搭社区 请将modelscope sdk升级到v1.7.2rc0,执行: ​ pip3 install "modelscope1.7.2rc0" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html # 方式1 git clone http://www.modelscope…...

在亚马逊平台,如何有效举报违规行为?

众所周知,在每个行业都有一些违规现象,甚至这些违规现象还会给自己带来利益方面的损失,一旦触犯到自己的利益的话,那自己是需要想办法解决的,想办法规避。 就拿开亚马逊店铺来说,比较容易遇到的就是产品侵…...

深度学习入门教学——神经网络

深度学习就是训练神经网络。 1、神经网络 举个最简单的例子,以下是一个使用线性回归来预测房屋价格的函数。这样一个用于预测房屋价格的函数被称作是一单个神经元。大一点的神经网络,就是将这些单个神经元叠加起来。例如:神经网络根据多个相…...

阿里Java开发手册~OOP 规约

1. 【强制】避免通过一个类的对象引用访问此类的静态变量或静态方法,无谓增加编译器解析成 本,直接用 类名 来访问即可。 2. 【强制】所有的覆写方法,必须加 Override 注解。 说明: getObject() 与 get 0 bject() 的问题。…...

【Mysql数据库面试01】内连接 左连接 右连接 全连接

【Mysql数据库】内连接 左连接 右连接 全连接 0.准备1.内连接1.1 SQL(不带where)1.2 SQL(带where)1.3总结 2.左连接2.1SQL(不带where)2.2SQL(带where)2.3总结 3.右连接3.1 SQL(不带where&#x…...

事务隔离:为什么你改了我还看不见

前提概要 你肯定不陌生,和数据库打交道的时候,我们总是会用到事务。最经典的例子就 是转账,你要给朋友小王转 100 块钱,而此时你的银行卡只有 100 块钱。 转账过程具体到程序里会有一系列的操作,比如查询余额、做加减法…...

吴恩达ChatGPT《LangChain Chat with Your Data》笔记

文章目录 1. Introduction2. Document Loading2.1 Retrieval Augmented Generation(RAG)2.2 Load PDFs2.3 Load YouTube2.4 Load URLs2.5 Load Notion 3. Document Splitting3.1 Splitter Flow3.2 Character Splitter3.3 Token Splitter3.4 Markdown Spl…...

https和http有什么区别

https和http有什么区别 简要 区别如下: ​ https的端口是443.而http的端口是80,且二者连接方式不同;http传输时明文,而https是用ssl进行加密的,https的安全性更高;https是需要申请证书的,而h…...

振弦采集仪及在线监测系统完整链条的岩土工程隧道安全监测

振弦采集仪及在线监测系统完整链条的岩土工程隧道安全监测 近年来,随着城市化的不断推进和基础设施建设的不断发展,隧道建设也日益成为城市交通发展的必需品。然而,隧道建设中存在着一定的安全隐患,如地质灾害、地下水涌流等&…...

linux基础学习

1.day1 2.day2 1、VIM配置; 2、安装SSH,调用putty接入终端; 3、shell命令; *:匹配任意长度的字符 ?:匹配一个长度的字符 [...]:匹配其中指定的一个字符 [-]:匹配指定…...

android 前端常用布局文件升级总结(二)

问题一: android:name“android.support.v4.content.FileProvider” 报红 问题解决方案: 把xml布局文件里面: android.support.v4.content.FileProvider 更换成 androidx.core.content.FileProvider 问题二: android.support.design.wid…...

Linux复习——基础知识

作者简介:一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭:低头赶路,敬事如仪 个人主页:网络豆的主页​​​​​ 1. 有关早期linux系统中 sysvin的init的7个级别描述正确的是( )[选择1项] A. init 1 关机状态 B. init 2 字符界面多用户模式 …...

gemma-3-12b-it镜像开箱即用:3分钟完成多模态服务启动与测试

gemma-3-12b-it镜像开箱即用:3分钟完成多模态服务启动与测试 1. 快速了解Gemma-3-12b-it 如果你正在寻找一个既能理解文字又能看懂图片的AI模型,而且希望它能在普通电脑上运行,那么Gemma-3-12b-it就是为你准备的。 Gemma是Google推出的轻量…...

AML启动器:智能管理XCOM 2模组的一站式解决方案

AML启动器:智能管理XCOM 2模组的一站式解决方案 【免费下载链接】xcom2-launcher The Alternative Mod Launcher (AML) is a replacement for the default game launchers from XCOM 2 and XCOM Chimera Squad. 项目地址: https://gitcode.com/gh_mirrors/xc/xcom…...

掌握上下文工程,小白也能轻松驾驭大模型(收藏版)

本文深入解析了上下文工程的概念及其与提示工程的核心区别。随着AI进入Agent时代,上下文工程成为构建高效AI应用的关键。文章详细阐述了如何通过优化系统提示、设计高效工具和运用Few-shot Prompting来提升上下文管理能力,并介绍了应对长时程任务的压缩、…...

Youtu-Parsing镜像免配置:预置outputs目录权限+日志轮转自动配置

Youtu-Parsing镜像免配置:预置outputs目录权限日志轮转自动配置 1. 引言:告别繁琐配置,专注文档解析 如果你用过一些AI模型,肯定遇到过这样的麻烦:好不容易把服务跑起来了,结果发现生成的图片没地方保存&…...

NEURAL MASK 模型调试技巧:使用IDE进行Python代码跟踪与问题定位

NEURAL MASK 模型调试技巧:使用IDE进行Python代码跟踪与问题定位 调试代码,尤其是涉及复杂模型加载和推理的代码,有时候就像在黑暗的房间里找一颗掉落的螺丝钉。你大概知道它就在那儿,但就是看不见摸不着。对于NEURAL MASK这类模…...

MedGemma 1.5企业应用:三甲医院科研组如何用其加速文献摘要与机制推演

MedGemma 1.5企业应用:三甲医院科研组如何用其加速文献摘要与机制推演 1. 引言:当科研遇上AI助手 想象一下这个场景:深夜的医院科研办公室里,桌上堆满了待读的医学文献,电脑屏幕上同时打开了十几篇PDF。一位研究员正…...

Xinference-v1.17.1GPU算力优化:显存自动分片+KV Cache压缩,72B模型显存占用降40%

Xinference v1.17.1 GPU算力优化:显存自动分片KV Cache压缩,72B模型显存占用降40% 1. 引言:大模型部署的显存困境与曙光 如果你尝试过在单张消费级显卡上部署一个超过70B参数的大语言模型,大概率会看到一个熟悉的错误提示&#…...

造相 Z-Image 电商提效:淘宝主图/拼多多详情页/小红书种草图量产

造相 Z-Image 电商提效:淘宝主图/拼多多详情页/小红书种草图量产 1. 电商视觉内容生产的痛点与机遇 电商卖家每天面临的最大挑战之一就是视觉内容的生产。无论是淘宝主图、拼多多详情页还是小红书种草图文,都需要大量高质量的图片来吸引用户眼球。传统…...

惊艳!Pi0具身智能v1动作轨迹可视化:关节控制曲线清晰呈现

惊艳!Pi0具身智能v1动作轨迹可视化:关节控制曲线清晰呈现 1. 具身智能的动作可视化革命 在机器人实验室里,工程师小李正盯着屏幕上一堆杂乱的数据点发愁——这是他们最新研发的机械臂在执行抓取任务时生成的关节角度数据。理论上这些数字应…...

从TKMath到STL导出:一份OCCTProxy for .NET的模块化封装实战笔记

从TKMath到STL导出:OCCTProxy for .NET的模块化封装实战 在工业软件开发的深水区,几何内核的封装从来都不是简单的语法转换。当我们需要将OpenCASCADE这样的庞然大物引入.NET生态时,C/CLI就像一座精心设计的悬索桥,既要承受原生代…...