当前位置: 首页 > news >正文

CUDNN详解

文章目录

  • CUDNN详解
    • 一、引言
    • 二、cuDNN的基本使用
      • 1、初始化cuDNN句柄
      • 2、创建和设置描述符
    • 三、执行卷积操作
      • 1、设置卷积参数
      • 2、选择卷积算法
      • 3、执行卷积
    • 四、使用示例
    • 五、总结

CUDNN详解

在这里插入图片描述

一、引言

cuDNN(CUDA Deep Neural Network library)是NVIDIA为深度神经网络开发的GPU加速库。它提供了高效实现深度学习算法所需的基本构建块,如卷积、池化、激活函数等。cuDNN通过优化这些操作,显著提高了深度学习模型的训练和推理速度,是深度学习框架(如TensorFlow、PyTorch)在GPU上高效运行的关键组件。

二、cuDNN的基本使用

1、初始化cuDNN句柄

在使用cuDNN之前,需要创建一个cuDNN句柄,该句柄用于管理cuDNN的上下文。例如:

cudnnHandle_t cudnn;
checkCUDNN(cudnnCreate(&cudnn));

这里使用了checkCUDNN宏来检查cuDNN函数调用的返回值,确保操作成功。

2、创建和设置描述符

cuDNN使用描述符(Descriptor)来描述张量(Tensor)、卷积核(Filter)和卷积操作(Convolution)等。例如,创建一个输入张量描述符并设置其属性:

cudnnTensorDescriptor_t input_descriptor;
checkCUDNN(cudnnCreateTensorDescriptor(&input_descriptor));
checkCUDNN(cudnnSetTensor4dDescriptor(input_descriptor,CUDNN_TENSOR_NCHW,CUDNN_DATA_FLOAT,batch_size, channels, height, width));

这里设置了输入张量的格式(NCHW)、数据类型(float)和维度(批量大小、通道数、高度、宽度)。

三、执行卷积操作

1、设置卷积参数

在执行卷积操作之前,需要设置卷积核描述符和卷积描述符。例如:

cudnnFilterDescriptor_t kernel_descriptor;
checkCUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor));
checkCUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor,CUDNN_DATA_FLOAT,CUDNN_TENSOR_NCHW,output_channels, input_channels, kernel_height, kernel_width));cudnnConvolutionDescriptor_t convolution_descriptor;
checkCUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor));
checkCUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor,padding_height, padding_width,vertical_stride, horizontal_stride,dilation_height, dilation_width,CUDNN_CROSS_CORRELATION,CUDNN_DATA_FLOAT));

这里设置了卷积核的大小、输入和输出通道数、步长、填充等参数。

2、选择卷积算法

cuDNN提供了多种卷积算法,可以选择最适合当前硬件和数据的算法。例如:

int algo_count;
checkCUDNN(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn, &algo_count));
cudnnConvolutionFwdAlgo_t algo;
checkCUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn,input_descriptor,kernel_descriptor,convolution_descriptor,output_descriptor,CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,0,&algo));

这里选择了最快的卷积算法。

3、执行卷积

最后,使用选择的算法执行卷积操作:

float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnConvolutionForward(cudnn,&alpha,input_descriptor, input_data,kernel_descriptor, kernel_data,convolution_descriptor,algo,workspace, workspace_size,&beta,output_descriptor, output_data));

这里alphabeta是缩放因子,input_datakernel_dataoutput_data分别是输入、卷积核和输出数据的指针。

四、使用示例

以下是一个完整的cuDNN卷积操作示例,包括初始化、设置描述符、执行卷积和清理资源:

#include <iostream>
#include <cuda_runtime.h>
#include <cudnn.h>#define CHECK_CUDNN(call) \
{ \cudnnStatus_t status = call; \if (status != CUDNN_STATUS_SUCCESS) { \std::cerr << "cuDNN error: " << cudnnGetErrorString(status) << std::endl; \exit(1); \} \
}int main() {int batch_size = 1, channels = 1, height = 28, width = 28;int output_channels = 16, kernel_height = 3, kernel_width = 3;// 创建cuDNN句柄cudnnHandle_t cudnn;CHECK_CUDNN(cudnnCreate(&cudnn));// 创建输入张量描述符cudnnTensorDescriptor_t input_descriptor;CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor));CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor,CUDNN_TENSOR_NCHW,CUDNN_DATA_FLOAT,batch_size, channels, height, width));// 创建输出张量描述符cudnnTensorDescriptor_t output_descriptor;CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor));CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor,CUDNN_TENSOR_NCHW,CUDNN_DATA_FLOAT,batch_size, output_channels, height, width));// 创建卷积核描述符cudnnFilterDescriptor_t kernel_descriptor;CHECK_CUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor));CHECK_CUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor,CUDNN_DATA_FLOAT,CUDNN_TENSOR_NCHW,output_channels, channels, kernel_height, kernel_width));// 创建卷积描述符cudnnConvolutionDescriptor_t convolution_descriptor;CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor));CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor,1, 1, 1, 1, 1, 1,CUDNN_CROSS_CORRELATION,CUDNN_DATA_FLOAT));// 选择卷积算法cudnnConvolutionFwdAlgo_t algo;CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(cudnn,input_descriptor,kernel_descriptor,convolution_descriptor,output_descriptor,CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,0,&algo));// 分配内存float *input_data, *kernel_data, *output_data, *workspace;size_t workspace_size;cudaMalloc(&input_data, batch_size * channels * height * width * sizeof(float));cudaMalloc(&kernel_data, output_channels * channels * kernel_height * kernel_width * sizeof(float));cudaMalloc(&output_data, batch_size * output_channels * height * width * sizeof(float));CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn,input_descriptor,kernel_descriptor,convolution_descriptor,output_descriptor,algo,&workspace_size));cudaMalloc(&workspace, workspace_size);// 执行卷积float alpha = 1.0f, beta = 0.0f;CHECK_CUDNN(cudnnConvolutionForward(cudnn,&alpha,input_descriptor, input_data,kernel_descriptor, kernel_data,convolution_descriptor,algo,workspace, workspace_size,&beta,output_descriptor, output_data));// 清理资源cudaFree(input_data);cudaFree(kernel_data);cudaFree(output_data);cudaFree(workspace);cudnnDestroyTensorDescriptor(input_descriptor);cudnnDestroyTensorDescriptor(output_descriptor);cudnnDestroyFilterDescriptor(kernel_descriptor);cudnnDestroyConvolutionDescriptor(convolution_descriptor);cudnnDestroy(cudnn);return 0;
}

这个示例展示了如何使用cuDNN进行二维卷积操作,包括初始化、设置描述符、选择算法、执行卷积和清理资源。

五、总结

cuDNN是深度学习中不可或缺的加速库,通过优化卷积、池化、激活等操作,显著提高了模型的训练和推理速度。掌握cuDNN的基本使用方法,可以帮助开发者更高效地实现深度学习模型。在实际应用中,cuDNN与CUDA、深度学习框架(如TensorFlow、PyTorch)紧密配合,提供了强大的计算支持。


版权声明:本博客内容为原创,转载请保留原文链接及作者信息。

参考文章

  • [cuDNN API的使用与测试-以二维卷积+Relu激活函数为例](https://www.hbblog.cn/cuda%E7%9B%B8%E5%85%B3/2022%E5%B9%B407%E6%9C%8823%E6%97%A5%2023%E

相关文章:

CUDNN详解

文章目录 CUDNN详解一、引言二、cuDNN的基本使用1、初始化cuDNN句柄2、创建和设置描述符 三、执行卷积操作1、设置卷积参数2、选择卷积算法3、执行卷积 四、使用示例五、总结 CUDNN详解 一、引言 cuDNN&#xff08;CUDA Deep Neural Network library&#xff09;是NVIDIA为深度…...

下载并安装MySQL

在Linux系统上下载并安装数据库&#xff08;以MySQL为例&#xff09;的步骤如下&#xff1a; 一、下载MySQL 访问MySQL官网 打开浏览器&#xff0c;访问MySQL的官方网站&#xff1a;https://www.mysql.com/。 进入下载页面 在MySQL官网首页&#xff0c;找到并点击“Downloads…...

Linux ffmpeg 基础用法

简介 FFmpeg 是一个强大的开源多媒体框架&#xff0c;用于处理视频、音频和其他多媒体文件和流。它允许转换、录制、编辑、流媒体等等。 安装 Debian/Ubuntu sudo apt update sudo apt install ffmpegRed Hat/CentOS sudo dnf install ffmpegmacOS (via Homebrew) brew i…...

【C++入门】详解(中)

目录 &#x1f495;1.函数的重载 &#x1f495;2.引用的定义 &#x1f495;3.引用的一些常见问题 &#x1f495;4.引用——权限的放大/缩小/平移 &#x1f495;5. 不存在的空引用 &#x1f495;6.引用作为函数参数的速度之快&#xff08;代码体现&#xff09; &#x1f4…...

深度学习的加速器:Horovod,让分布式训练更简单高效!

什么是 Horovod&#xff1f; Horovod 是 Uber 开发的一个专注于深度学习分布式训练的开源框架&#xff0c;旨在简化和加速多 GPU、多节点环境下的训练过程。它以轻量级、易用、高性能著称&#xff0c;特别适合需要快速部署分布式训练的场景。Horovod 的名字来源于俄罗斯传统舞…...

计算机的错误计算(二百零八)

摘要 用两个大模型计算 arccot(0.9911588354432518e10) . 保留16位有效数字。两个的输出均是错误的。代码的输出格式亦均出错。 本节题目为一读者来信提议&#xff08;不知该题目有何玄机&#xff1f;&#xff09;。 例1. 计算 arccot(0.9911588354432518e10) . 保留16位有…...

海康机器人IPO,又近了一步

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。欢迎大家到本文底部评论区留言。 海康机器人的IPO之路&#xff0c;一路跌宕起伏&#xff0c;让无数投资者和业内人士关注。这不仅仅是一家企业的上市之旅&#xff0c;更是中国智能制造…...

【环境搭建】Metersphere v2.x 容器部署教程踩坑总结

前言 Metersphere部署过程中遇到的问题有点多&#xff0c;原因是其容器的架构蛮复杂的&#xff0c;比较容易踩坑&#xff0c;所以记录一下。 介绍 MeterSphere 是开源持续测试平台&#xff0c;遵循 GPL v3 开源许可协议&#xff0c;涵盖测试管理、接口测试、UI 测试和性能测…...

系统看门狗配置--以ubuntu为例

linux系统配置看门狗 以 ubuntu 系统配置看门狗为例 配置看门狗使用的脚本文件&#xff0c;需要使用管理员权限来执行&#xff1a; 配置是&#xff1a;系统每 30S 喂一次狗&#xff0c;超过 60S 不进行投喂&#xff0c;就会自动重启。 1. 系统脚本内容&#xff1a; #!/bin/b…...

阅读笔记——《A survey of protocol fuzzing》

【参考文献】Zhang X, Zhang C, Li X, et al. A survey of protocol fuzzing[J]. ACM Computing Surveys, 2024, 57(2): 1-36.【注】本文仅为作者个人学习笔记&#xff0c;如有冒犯&#xff0c;请联系作者删除。 目录 1、Introduction 2、Background 2.1、Communication Pro…...

C# 语法中级

总目录 C# 语法总目录 C# 语法中级 lambda 表达式1. 捕获外部变量2. 捕获迭代变量 匿名类型匿名方法异常相关1. 枚举器2. 可枚举对象3. 迭代器3. 迭代器语义4. yield break 语句5. 组合序列 可空类型1. Nullable< T > 结构体 lambda 表达式 编译器在内部将lambda表达式编…...

STORM:从多时间点2D图像中快速重建动态3D场景的技术突破

随着计算机视觉和机器学习技术的迅猛发展,我们已经能够利用AI来解决许多复杂的问题。然而,在处理大规模室外动态3D场景重建时,现有的方法往往面临着诸多挑战,如需要大量人工标注数据、处理速度慢以及难以准确捕捉移动物体等。为了解决这些问题,研究者们开发了STORM(Spati…...

excel前缀和(递增求和)

方法一&#xff1a;https://www.zhihu.com/zvideo/1382164996659515392?utm_id0 假设输入数据在B2:B10&#xff0c;选中单元格C2&#xff0c;输入SUM(B2:B2&#xff0c;然后选中其中的B2&#xff0c;按F4&#xff08;或者直接输入SUM(B$2:B2&#xff09;&#xff0c;回车确认&…...

【AI日记】25.01.11 Weights Biases | AI 笔记 notion

【AI论文解读】【AI知识点】【AI小项目】【AI战略思考】【AI日记】【读书与思考】 AI kaggle 比赛&#xff1a;Forecasting Sticker Sales笔记&#xff1a;我的 AI 笔记主要记在两个地方 有道云笔记&#xff1a;数学公式和符号比较多的笔记notion&#xff1a;没什么数学公式的…...

P8772 [蓝桥杯 2022 省 A] 求和

题目描述 给定 &#x1d45b; 个整数 &#x1d44e;1,&#x1d44e;2,⋯ ,&#x1d44e;&#x1d45b; 求它们两两相乘再相加的和&#xff0c;即 &#x1d446;&#x1d44e;1⋅&#x1d44e;2&#x1d44e;1⋅&#x1d44e;3⋯&#x1d44e;1⋅&#x1d44e;&#x1d45b;&…...

【Oracle篇】深入了解执行计划中的访问路径(含表级别、B树索引、位图索引、簇表四大类访问路径)

&#x1f4ab;《博主介绍》&#xff1a;✨又是一天没白过&#xff0c;我是奈斯&#xff0c;从事IT领域✨ &#x1f4ab;《擅长领域》&#xff1a;✌️擅长阿里云AnalyticDB for MySQL(分布式数据仓库)、Oracle、MySQL、Linux、prometheus监控&#xff1b;并对SQLserver、NoSQL(…...

WSDL的基本概念

《WSDL 语法》这篇文章将详细介绍WSDL&#xff08;Web Services Description Language&#xff09;的语法。WSDL是一种基于XML的语言&#xff0c;用于描述Web服务及其访问方式。它允许开发者将Web服务定义为服务访问点或端口的集合&#xff0c;这些服务访问点可以通过特定的协议…...

RabbitMQ解决消息积压的方法

目录 减少发送mq的消息体内容 增加消费者数量 批量消费消息 临时队列转移 监控和预警机制 分阶段实施 最后还有一个方法就是开启队列的懒加载 这篇文章总结一下自己知道的解决消息积压得方法。 减少发送mq的消息体内容 像我们没有必要知道一个的中间状态&#xff0c;只需…...

Android 网络层相关介绍

关注 Android 默认支持的网络管理行为,默认支持的网络服务功能。 功能术语 术语缩写全称释义DHCPv6Dynamic Host Configuration Protocol for IPv6动态主机配置协议的第六版,用于在IPv6网络中动态分配IP地址和其他网络配置参数。DNS Domain Name System域名系统。LLALink-Loc…...

2025年第三届“华数杯”国际赛B题解题思路与代码(Matlab版)

问题1&#xff1a;产业关联性分析 在 question1.m 文件中&#xff0c;我们分析了中国主要产业之间的相互关系。以下是代码的详细解读&#xff1a; % 问题1&#xff1a;分析中国主要产业之间的相互关系function question1()% 清空工作区和命令窗口clear;clc;% 设置中文显示set…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习&#xff08;Reinforcement Learning, RL&#xff09;是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程&#xff0c;然后使用强化学习的Actor-Critic机制&#xff08;中文译作“知行互动”机制&#xff09;&#xff0c;逐步迭代求解…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出&#xff1a;JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中&#xff0c;随机数的生成看似简单&#xff0c;却隐藏着许多玄机。无论是生成密码、加密密钥&#xff0c;还是创建安全令牌&#xff0c;随机数的质量直接关系到系统的安全性。Jav…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

基于数字孪生的水厂可视化平台建设:架构与实践

分享大纲&#xff1a; 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年&#xff0c;数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段&#xff0c;基于数字孪生的水厂可视化平台的…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...