当前位置: 首页 > news >正文

cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

相关文章:

cuda从零开始手搓PB神经网络

cuda实现PB神经网络 基于上一篇的矩阵点乘&#xff0c;实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下&#xff1a; #ifndef __BP_NETWORK_HPP__ #define __BP_NETWORK_HPP__ #include "matrix.hpp&quo…...

mac 安装mongodb

本文分享2种mac本地安装mongodb的方法&#xff0c;一种是通过homebrew安装&#xff0c;一种是通过tar包安装 homebrew安装 brew tap mongodb/brew brew upate brew install mongodb-community8.0tar包安装 安装mongodb 1.下载mongodb社区版的tar包 mongdb tar包下载地址 2…...

K8S-Pod资源清单的编写,资源的增删改查,镜像的下载策略

1. Pod资源清单的编写 1.1 Pod运行单个容器的资源清单 ##创建工作目录 mkdir -p /root/manifests/pods && cd /root/manifests/pods vim 01-nginx.yaml ##指定api版本 apiVersion: v1 ##指定资源类型 kind: Pod ##指定元数据 metadata:##指定名称name: myweb ##用户…...

【Maui】视图界面与数据模型绑定

文章目录 前言一、问题描述二、解决方案三、软件开发&#xff08;源码&#xff09;3.1 创建模型3.2 视图界面3.3 控制器逻辑层 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架&#xff0c;用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI&…...

JavaScript笔记基础篇02——运算符、语句、数组

黑马程序员视频地址&#xff1a;黑马程序员前端JavaScript入门到精通全套视频教程https://www.bilibili.com/video/BV1Y84y1L7Nn?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes 目录 运算符 赋值运算符 ​编辑​编辑 一元运算符…...

心法利器[127] | 24年算法思考-特征工程和经典深度学习

心法利器 本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍&#xff1a;仓颉专项&#xff1a;飞机大炮我都会&#xff0c;利器心法我还有。 2023年新的文章合集已经发布&#xff0c;获取方式看这里&#xff1a;又添十万字-CS的陋室2023年文章合集来袭&#xff0c;更…...

ASP.NET Core 中的 JWT 鉴权实现

在当今的软件开发中&#xff0c;安全性和用户认证是至关重要的方面。JSON Web Token&#xff08;JWT&#xff09;作为一种流行的身份验证机制&#xff0c;因其简洁性和无状态特性而被广泛应用于各种应用中&#xff0c;尤其是在 ASP.NET Core 项目里。本文将详细介绍如何在 ASP.…...

PyTorch基本功能与实现代码

PyTorch是一个开源的深度学习框架&#xff0c;提供了丰富的函数和工具&#xff0c;以下为其主要功能的归纳&#xff1a; 核心数据结构&#xff1a; • 张量&#xff08;Tensor&#xff09;&#xff1a;类似于Numpy的ndarray&#xff0c;是PyTorch中基本的数据结构&#xff0c…...

SparkSQL数据模型综合实践

文章目录 1. 实战概述2. 实战步骤2.1 创建数据集2.2 创建数据模型对象2.2.1 创建常量2.2.2 创建加载数据方法2.2.3 创建过滤年龄方法2.2.4 创建平均薪水方法2.2.5 创建主方法2.2.6 查看完整代码 2.3 运行程序&#xff0c;查看结果 3. 实战小结 1. 实战概述 在本次实战中&#…...

3 查找重复的电子邮箱(having与where区别,distinct去重使用)

3 查找重复的电子邮箱&#xff08;having与where区别&#xff0c;distinct去重使用&#xff09; 表: Person ---------------------- | Column Name | Type | ---------------------- | id | int | | email | varchar | ---------------------- id 是该…...

uniapp——App 监听下载文件状态,打开文件(三)

5 实现下载文件并打开 这里演示&#xff0c;导出Excel 表格 文章目录 5 实现下载文件并打开DEMO监听下载进度效果图为什么 totalSize 一直为0&#xff1f; 相关Api&#xff1a; downloader DEMO 提示&#xff1a; 请求方式支持&#xff1a;GET、POST&#xff1b;POST 方式需要…...

循环队列(C语言)

从今天开始我会开启一个专栏leetcode每日一题&#xff0c;大家互相交流代码经验&#xff0c;也当作我每天练习的自我回顾。第一天的内容是leetcode622.设计循环队列。 一、题目详细 设计你的循环队列实现。 循环队列是一种线性数据结构&#xff0c;其操作表现基于 FIFO&#…...

数据可视化:让数据讲故事的艺术

目录 1 前言2 数据可视化的基本概念2.1 可视化的核心目标2.2 传统可视化手段 3 数据可视化在知识图谱中的应用3.1 知识图谱的可视化需求3.2 知识图谱的可视化方法 4 数据可视化叙事&#xff1a;让数据讲故事4.1 叙事可视化的关键要素4.2 数据可视化叙事的实现方法 5 数据可视化…...

雷电9最新版安装Magisk+LSPosd(新手速通)

大家好啊&#xff01;我是NiJiMingCheng 我的博客&#xff1a;NiJiMingCheng 在安卓系统的定制与拓展过程中&#xff0c;获取 ROOT 权限以及安装各类框架是进阶玩家常用的操作&#xff0c;这可以帮助我们实现更多系统层面的个性化功能。今天&#xff0c;我将为大家详细介绍如何…...

Ubuntu 24.04 LTS 开启 SMB 服务,并通过 windows 访问

Ubuntu 24.04 LTS 背景资料 Ubuntu服务器折腾集Ubuntu linux 文件权限Ubuntu 空闲硬盘挂载到 文件管理器的 other locations Ubuntu开启samba和window共享文件 Ubuntu 配置 SMB 服务 安装 Samba 确保 Samba 已安装。如果未安装&#xff0c;运行以下命令进行安装&#xff…...

使用Websocket进行前后端实时通信

1、引入jar&#xff0c;spring-websocket-starter <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency> 2、配置websocket config import org.springframe…...

vue2使用flv.js在浏览器打开flv格式视频

组件地址&#xff1a;GitHub - bilibili/flv.js: HTML5 FLV Player flv.js 仅支持 H.264 和 AAC/MP3 编码的 FLV 文件。如果视频文件使用了其他编码格式就打不开。 flv.vue <template><div><el-dialog :visible.sync"innerVisibleFlv" :close-on-pre…...

OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数&#xff0c;用于处理未校准的立体图像对。该函…...

[cg] glProgramBinary

参考&#xff1a; glProgramBinary - OpenGL 4 Reference Pages opengl 通过gpu编译好的 shader 可以存储到二进制文件中&#xff0c;第二次使用的时候直接加载二进制文件即可&#xff0c; glProgramBinary用于加载shader的二进制数据 实列代码如下&#xff1a; // 假设已经…...

LeetCode hot 力扣热题100 二叉树的最大深度

class Solution { public:int maxDepth(TreeNode* root) {if (root nullptr) {return 0;}int l_depth maxDepth(root->left);int r_depth maxDepth(root->right);return max(l_depth, r_depth) 1;} }; 代码作用 该函数通过递归计算二叉树的最大深度&#xff08;从根节…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好&#xff0c;欢迎来到《云原生核心技术》系列的第七篇&#xff01; 在上一篇&#xff0c;我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在&#xff0c;我们就像一个拥有了一块崭新数字土地的农场主&#xff0c;是时…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

优选算法第十二讲:队列 + 宽搜 优先级队列

优选算法第十二讲&#xff1a;队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、&#x1f44b;&#x1f3fb;前言 二、&#x1f608;sinx波动的基本原理 三、&#x1f608;波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、&#x1f30a;波动优化…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

在 Spring Boot 中使用 JSP

jsp&#xff1f; 好多年没用了。重新整一下 还费了点时间&#xff0c;记录一下。 项目结构&#xff1a; pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...

【LeetCode】算法详解#6 ---除自身以外数组的乘积

1.题目介绍 给定一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O…...