当前位置: 首页 > news >正文

cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

相关文章:

cuda从零开始手搓PB神经网络

cuda实现PB神经网络 基于上一篇的矩阵点乘&#xff0c;实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下&#xff1a; #ifndef __BP_NETWORK_HPP__ #define __BP_NETWORK_HPP__ #include "matrix.hpp&quo…...

mac 安装mongodb

本文分享2种mac本地安装mongodb的方法&#xff0c;一种是通过homebrew安装&#xff0c;一种是通过tar包安装 homebrew安装 brew tap mongodb/brew brew upate brew install mongodb-community8.0tar包安装 安装mongodb 1.下载mongodb社区版的tar包 mongdb tar包下载地址 2…...

K8S-Pod资源清单的编写,资源的增删改查,镜像的下载策略

1. Pod资源清单的编写 1.1 Pod运行单个容器的资源清单 ##创建工作目录 mkdir -p /root/manifests/pods && cd /root/manifests/pods vim 01-nginx.yaml ##指定api版本 apiVersion: v1 ##指定资源类型 kind: Pod ##指定元数据 metadata:##指定名称name: myweb ##用户…...

【Maui】视图界面与数据模型绑定

文章目录 前言一、问题描述二、解决方案三、软件开发&#xff08;源码&#xff09;3.1 创建模型3.2 视图界面3.3 控制器逻辑层 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架&#xff0c;用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI&…...

JavaScript笔记基础篇02——运算符、语句、数组

黑马程序员视频地址&#xff1a;黑马程序员前端JavaScript入门到精通全套视频教程https://www.bilibili.com/video/BV1Y84y1L7Nn?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes 目录 运算符 赋值运算符 ​编辑​编辑 一元运算符…...

心法利器[127] | 24年算法思考-特征工程和经典深度学习

心法利器 本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍&#xff1a;仓颉专项&#xff1a;飞机大炮我都会&#xff0c;利器心法我还有。 2023年新的文章合集已经发布&#xff0c;获取方式看这里&#xff1a;又添十万字-CS的陋室2023年文章合集来袭&#xff0c;更…...

ASP.NET Core 中的 JWT 鉴权实现

在当今的软件开发中&#xff0c;安全性和用户认证是至关重要的方面。JSON Web Token&#xff08;JWT&#xff09;作为一种流行的身份验证机制&#xff0c;因其简洁性和无状态特性而被广泛应用于各种应用中&#xff0c;尤其是在 ASP.NET Core 项目里。本文将详细介绍如何在 ASP.…...

PyTorch基本功能与实现代码

PyTorch是一个开源的深度学习框架&#xff0c;提供了丰富的函数和工具&#xff0c;以下为其主要功能的归纳&#xff1a; 核心数据结构&#xff1a; • 张量&#xff08;Tensor&#xff09;&#xff1a;类似于Numpy的ndarray&#xff0c;是PyTorch中基本的数据结构&#xff0c…...

SparkSQL数据模型综合实践

文章目录 1. 实战概述2. 实战步骤2.1 创建数据集2.2 创建数据模型对象2.2.1 创建常量2.2.2 创建加载数据方法2.2.3 创建过滤年龄方法2.2.4 创建平均薪水方法2.2.5 创建主方法2.2.6 查看完整代码 2.3 运行程序&#xff0c;查看结果 3. 实战小结 1. 实战概述 在本次实战中&#…...

3 查找重复的电子邮箱(having与where区别,distinct去重使用)

3 查找重复的电子邮箱&#xff08;having与where区别&#xff0c;distinct去重使用&#xff09; 表: Person ---------------------- | Column Name | Type | ---------------------- | id | int | | email | varchar | ---------------------- id 是该…...

uniapp——App 监听下载文件状态,打开文件(三)

5 实现下载文件并打开 这里演示&#xff0c;导出Excel 表格 文章目录 5 实现下载文件并打开DEMO监听下载进度效果图为什么 totalSize 一直为0&#xff1f; 相关Api&#xff1a; downloader DEMO 提示&#xff1a; 请求方式支持&#xff1a;GET、POST&#xff1b;POST 方式需要…...

循环队列(C语言)

从今天开始我会开启一个专栏leetcode每日一题&#xff0c;大家互相交流代码经验&#xff0c;也当作我每天练习的自我回顾。第一天的内容是leetcode622.设计循环队列。 一、题目详细 设计你的循环队列实现。 循环队列是一种线性数据结构&#xff0c;其操作表现基于 FIFO&#…...

数据可视化:让数据讲故事的艺术

目录 1 前言2 数据可视化的基本概念2.1 可视化的核心目标2.2 传统可视化手段 3 数据可视化在知识图谱中的应用3.1 知识图谱的可视化需求3.2 知识图谱的可视化方法 4 数据可视化叙事&#xff1a;让数据讲故事4.1 叙事可视化的关键要素4.2 数据可视化叙事的实现方法 5 数据可视化…...

雷电9最新版安装Magisk+LSPosd(新手速通)

大家好啊&#xff01;我是NiJiMingCheng 我的博客&#xff1a;NiJiMingCheng 在安卓系统的定制与拓展过程中&#xff0c;获取 ROOT 权限以及安装各类框架是进阶玩家常用的操作&#xff0c;这可以帮助我们实现更多系统层面的个性化功能。今天&#xff0c;我将为大家详细介绍如何…...

Ubuntu 24.04 LTS 开启 SMB 服务,并通过 windows 访问

Ubuntu 24.04 LTS 背景资料 Ubuntu服务器折腾集Ubuntu linux 文件权限Ubuntu 空闲硬盘挂载到 文件管理器的 other locations Ubuntu开启samba和window共享文件 Ubuntu 配置 SMB 服务 安装 Samba 确保 Samba 已安装。如果未安装&#xff0c;运行以下命令进行安装&#xff…...

使用Websocket进行前后端实时通信

1、引入jar&#xff0c;spring-websocket-starter <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency> 2、配置websocket config import org.springframe…...

vue2使用flv.js在浏览器打开flv格式视频

组件地址&#xff1a;GitHub - bilibili/flv.js: HTML5 FLV Player flv.js 仅支持 H.264 和 AAC/MP3 编码的 FLV 文件。如果视频文件使用了其他编码格式就打不开。 flv.vue <template><div><el-dialog :visible.sync"innerVisibleFlv" :close-on-pre…...

OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数&#xff0c;用于处理未校准的立体图像对。该函…...

[cg] glProgramBinary

参考&#xff1a; glProgramBinary - OpenGL 4 Reference Pages opengl 通过gpu编译好的 shader 可以存储到二进制文件中&#xff0c;第二次使用的时候直接加载二进制文件即可&#xff0c; glProgramBinary用于加载shader的二进制数据 实列代码如下&#xff1a; // 假设已经…...

LeetCode hot 力扣热题100 二叉树的最大深度

class Solution { public:int maxDepth(TreeNode* root) {if (root nullptr) {return 0;}int l_depth maxDepth(root->left);int r_depth maxDepth(root->right);return max(l_depth, r_depth) 1;} }; 代码作用 该函数通过递归计算二叉树的最大深度&#xff08;从根节…...

用Python和MNE库玩转BCI Competition IV 2a脑电数据集:从数据加载到可视化全流程

用Python和MNE库玩转BCI Competition IV 2a脑电数据集&#xff1a;从数据加载到可视化全流程当你第一次接触脑电信号处理时&#xff0c;面对原始数据文件可能会感到无从下手。BCI Competition IV 2a数据集作为脑机接口领域的经典基准数据&#xff0c;包含了9名受试者四种运动想…...

30岁裸辞后,我用两个月拿下AI应用认证,现在OFFER选择困难症犯了

30岁裸辞那天&#xff0c;我最怕的不是没收入&#xff0c;而是突然发现&#xff1a;过去积累的经验&#xff0c;正在被AI重新定价。以前会写方案、做表格、跟项目&#xff0c;算是职场硬通货&#xff1b;到了2026年&#xff0c;招聘JD里开始频繁出现AI工具应用、智能工作流、Pr…...

智慧树自动刷课助手:3步告别手动操作的学习效率工具

智慧树自动刷课助手&#xff1a;3步告别手动操作的学习效率工具 【免费下载链接】zhihuishu 智慧树刷课插件&#xff0c;自动播放下一集、1.5倍速度、无声 项目地址: https://gitcode.com/gh_mirrors/zh/zhihuishu 还在为智慧树平台的重复刷课操作而烦恼吗&#xff1f;智…...

Scroll Reverser:让Mac的多设备滚动体验回归直觉的免费神器

Scroll Reverser&#xff1a;让Mac的多设备滚动体验回归直觉的免费神器 【免费下载链接】Scroll-Reverser Per-device scrolling prefs on macOS. 项目地址: https://gitcode.com/gh_mirrors/sc/Scroll-Reverser 你是否曾经在MacBook的触控板和鼠标之间切换时&#xff0…...

Web渗透测试能力成长地图:从工具使用到漏洞认知跃迁

1. 这不是工具清单&#xff0c;而是一张Web渗透测试的“能力成长地图”你刚点开这篇文章&#xff0c;大概率正站在两个路口之间&#xff1a;一边是网上铺天盖地的“十大免费扫描器推荐”&#xff0c;点进去全是截图下载链接一句“一键扫漏洞”&#xff0c;结果装完跑两下&#…...

taotoken如何帮助ubuntu开发者应对大模型api的频繁更新与版本迭代

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 Taotoken如何帮助Ubuntu开发者应对大模型API的频繁更新与版本迭代 对于在Ubuntu环境下进行开发的工程师而言&#xff0c;大模型API…...

用图神经网络做缺陷定位,准确率比传统方法高出30%

在现代软件工程的复杂迷宫中&#xff0c;缺陷定位始终是测试团队面临的核心挑战。想象这样一个场景&#xff1a;一个电商系统在特定压力条件下偶发订单丢失&#xff0c;日志中只留下泛泛的超时错误&#xff0c;问题可能深藏在上百个微服务的调用链、分布式事务的竞态条件或某个…...

DeepSeek代码风格检查避坑指南(内部审计报告首次披露:37个被忽略的合规红线)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;DeepSeek代码风格检查的合规性本质与审计背景 DeepSeek代码风格检查并非单纯的技术偏好约束&#xff0c;而是嵌入研发治理链条中的合规性控制节点。其本质是将编程实践与组织级安全策略、行业监管要求&…...

别再瞎拖拽了!Unity Prefab从创建到批量修改的保姆级工作流(含变体与嵌套实战)

Unity Prefab高效工作流&#xff1a;从创建到批量修改的实战指南在Unity项目开发中&#xff0c;Prefab&#xff08;预制体&#xff09;是最基础也最强大的工具之一。但很多开发者&#xff0c;尤其是初学者&#xff0c;往往停留在简单的"拖拽-修改"阶段&#xff0c;没…...

自然语言处理的实战项目:从0到1搭建属于自己的文本分类系统

对于软件测试从业者而言&#xff0c;日常工作中我们每天都会接触大量的文本数据&#xff1a;缺陷管理系统中的bug描述、测试用例的步骤说明、用户反馈的问题报告、需求文档的规格描述&#xff0c;甚至是接口返回的异常信息文本。这些非结构化文本往往隐含着关键业务信息&#x…...