cuda从零开始手搓PB神经网络
cuda实现PB神经网络
基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:
#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;} // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif
下面实验一下我们的bp神经网络。
#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}
以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:

可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。
相关文章:
cuda从零开始手搓PB神经网络
cuda实现PB神经网络 基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下: #ifndef __BP_NETWORK_HPP__ #define __BP_NETWORK_HPP__ #include "matrix.hpp&quo…...
mac 安装mongodb
本文分享2种mac本地安装mongodb的方法,一种是通过homebrew安装,一种是通过tar包安装 homebrew安装 brew tap mongodb/brew brew upate brew install mongodb-community8.0tar包安装 安装mongodb 1.下载mongodb社区版的tar包 mongdb tar包下载地址 2…...
K8S-Pod资源清单的编写,资源的增删改查,镜像的下载策略
1. Pod资源清单的编写 1.1 Pod运行单个容器的资源清单 ##创建工作目录 mkdir -p /root/manifests/pods && cd /root/manifests/pods vim 01-nginx.yaml ##指定api版本 apiVersion: v1 ##指定资源类型 kind: Pod ##指定元数据 metadata:##指定名称name: myweb ##用户…...
【Maui】视图界面与数据模型绑定
文章目录 前言一、问题描述二、解决方案三、软件开发(源码)3.1 创建模型3.2 视图界面3.3 控制器逻辑层 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI&…...
JavaScript笔记基础篇02——运算符、语句、数组
黑马程序员视频地址:黑马程序员前端JavaScript入门到精通全套视频教程https://www.bilibili.com/video/BV1Y84y1L7Nn?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes 目录 运算符 赋值运算符 编辑编辑 一元运算符…...
心法利器[127] | 24年算法思考-特征工程和经典深度学习
心法利器 本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。 2023年新的文章合集已经发布,获取方式看这里:又添十万字-CS的陋室2023年文章合集来袭,更…...
ASP.NET Core 中的 JWT 鉴权实现
在当今的软件开发中,安全性和用户认证是至关重要的方面。JSON Web Token(JWT)作为一种流行的身份验证机制,因其简洁性和无状态特性而被广泛应用于各种应用中,尤其是在 ASP.NET Core 项目里。本文将详细介绍如何在 ASP.…...
PyTorch基本功能与实现代码
PyTorch是一个开源的深度学习框架,提供了丰富的函数和工具,以下为其主要功能的归纳: 核心数据结构: • 张量(Tensor):类似于Numpy的ndarray,是PyTorch中基本的数据结构,…...
SparkSQL数据模型综合实践
文章目录 1. 实战概述2. 实战步骤2.1 创建数据集2.2 创建数据模型对象2.2.1 创建常量2.2.2 创建加载数据方法2.2.3 创建过滤年龄方法2.2.4 创建平均薪水方法2.2.5 创建主方法2.2.6 查看完整代码 2.3 运行程序,查看结果 3. 实战小结 1. 实战概述 在本次实战中&#…...
3 查找重复的电子邮箱(having与where区别,distinct去重使用)
3 查找重复的电子邮箱(having与where区别,distinct去重使用) 表: Person ---------------------- | Column Name | Type | ---------------------- | id | int | | email | varchar | ---------------------- id 是该…...
uniapp——App 监听下载文件状态,打开文件(三)
5 实现下载文件并打开 这里演示,导出Excel 表格 文章目录 5 实现下载文件并打开DEMO监听下载进度效果图为什么 totalSize 一直为0? 相关Api: downloader DEMO 提示: 请求方式支持:GET、POST;POST 方式需要…...
循环队列(C语言)
从今天开始我会开启一个专栏leetcode每日一题,大家互相交流代码经验,也当作我每天练习的自我回顾。第一天的内容是leetcode622.设计循环队列。 一、题目详细 设计你的循环队列实现。 循环队列是一种线性数据结构,其操作表现基于 FIFO&#…...
数据可视化:让数据讲故事的艺术
目录 1 前言2 数据可视化的基本概念2.1 可视化的核心目标2.2 传统可视化手段 3 数据可视化在知识图谱中的应用3.1 知识图谱的可视化需求3.2 知识图谱的可视化方法 4 数据可视化叙事:让数据讲故事4.1 叙事可视化的关键要素4.2 数据可视化叙事的实现方法 5 数据可视化…...
雷电9最新版安装Magisk+LSPosd(新手速通)
大家好啊!我是NiJiMingCheng 我的博客:NiJiMingCheng 在安卓系统的定制与拓展过程中,获取 ROOT 权限以及安装各类框架是进阶玩家常用的操作,这可以帮助我们实现更多系统层面的个性化功能。今天,我将为大家详细介绍如何…...
Ubuntu 24.04 LTS 开启 SMB 服务,并通过 windows 访问
Ubuntu 24.04 LTS 背景资料 Ubuntu服务器折腾集Ubuntu linux 文件权限Ubuntu 空闲硬盘挂载到 文件管理器的 other locations Ubuntu开启samba和window共享文件 Ubuntu 配置 SMB 服务 安装 Samba 确保 Samba 已安装。如果未安装,运行以下命令进行安装ÿ…...
使用Websocket进行前后端实时通信
1、引入jar,spring-websocket-starter <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency> 2、配置websocket config import org.springframe…...
vue2使用flv.js在浏览器打开flv格式视频
组件地址:GitHub - bilibili/flv.js: HTML5 FLV Player flv.js 仅支持 H.264 和 AAC/MP3 编码的 FLV 文件。如果视频文件使用了其他编码格式就打不开。 flv.vue <template><div><el-dialog :visible.sync"innerVisibleFlv" :close-on-pre…...
OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数,用于处理未校准的立体图像对。该函…...
[cg] glProgramBinary
参考: glProgramBinary - OpenGL 4 Reference Pages opengl 通过gpu编译好的 shader 可以存储到二进制文件中,第二次使用的时候直接加载二进制文件即可, glProgramBinary用于加载shader的二进制数据 实列代码如下: // 假设已经…...
LeetCode hot 力扣热题100 二叉树的最大深度
class Solution { public:int maxDepth(TreeNode* root) {if (root nullptr) {return 0;}int l_depth maxDepth(root->left);int r_depth maxDepth(root->right);return max(l_depth, r_depth) 1;} }; 代码作用 该函数通过递归计算二叉树的最大深度(从根节…...
用Python和MNE库玩转BCI Competition IV 2a脑电数据集:从数据加载到可视化全流程
用Python和MNE库玩转BCI Competition IV 2a脑电数据集:从数据加载到可视化全流程当你第一次接触脑电信号处理时,面对原始数据文件可能会感到无从下手。BCI Competition IV 2a数据集作为脑机接口领域的经典基准数据,包含了9名受试者四种运动想…...
30岁裸辞后,我用两个月拿下AI应用认证,现在OFFER选择困难症犯了
30岁裸辞那天,我最怕的不是没收入,而是突然发现:过去积累的经验,正在被AI重新定价。以前会写方案、做表格、跟项目,算是职场硬通货;到了2026年,招聘JD里开始频繁出现AI工具应用、智能工作流、Pr…...
智慧树自动刷课助手:3步告别手动操作的学习效率工具
智慧树自动刷课助手:3步告别手动操作的学习效率工具 【免费下载链接】zhihuishu 智慧树刷课插件,自动播放下一集、1.5倍速度、无声 项目地址: https://gitcode.com/gh_mirrors/zh/zhihuishu 还在为智慧树平台的重复刷课操作而烦恼吗?智…...
Scroll Reverser:让Mac的多设备滚动体验回归直觉的免费神器
Scroll Reverser:让Mac的多设备滚动体验回归直觉的免费神器 【免费下载链接】Scroll-Reverser Per-device scrolling prefs on macOS. 项目地址: https://gitcode.com/gh_mirrors/sc/Scroll-Reverser 你是否曾经在MacBook的触控板和鼠标之间切换时࿰…...
Web渗透测试能力成长地图:从工具使用到漏洞认知跃迁
1. 这不是工具清单,而是一张Web渗透测试的“能力成长地图”你刚点开这篇文章,大概率正站在两个路口之间:一边是网上铺天盖地的“十大免费扫描器推荐”,点进去全是截图下载链接一句“一键扫漏洞”,结果装完跑两下&#…...
taotoken如何帮助ubuntu开发者应对大模型api的频繁更新与版本迭代
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 Taotoken如何帮助Ubuntu开发者应对大模型API的频繁更新与版本迭代 对于在Ubuntu环境下进行开发的工程师而言,大模型API…...
用图神经网络做缺陷定位,准确率比传统方法高出30%
在现代软件工程的复杂迷宫中,缺陷定位始终是测试团队面临的核心挑战。想象这样一个场景:一个电商系统在特定压力条件下偶发订单丢失,日志中只留下泛泛的超时错误,问题可能深藏在上百个微服务的调用链、分布式事务的竞态条件或某个…...
DeepSeek代码风格检查避坑指南(内部审计报告首次披露:37个被忽略的合规红线)
更多请点击: https://intelliparadigm.com 第一章:DeepSeek代码风格检查的合规性本质与审计背景 DeepSeek代码风格检查并非单纯的技术偏好约束,而是嵌入研发治理链条中的合规性控制节点。其本质是将编程实践与组织级安全策略、行业监管要求&…...
别再瞎拖拽了!Unity Prefab从创建到批量修改的保姆级工作流(含变体与嵌套实战)
Unity Prefab高效工作流:从创建到批量修改的实战指南在Unity项目开发中,Prefab(预制体)是最基础也最强大的工具之一。但很多开发者,尤其是初学者,往往停留在简单的"拖拽-修改"阶段,没…...
自然语言处理的实战项目:从0到1搭建属于自己的文本分类系统
对于软件测试从业者而言,日常工作中我们每天都会接触大量的文本数据:缺陷管理系统中的bug描述、测试用例的步骤说明、用户反馈的问题报告、需求文档的规格描述,甚至是接口返回的异常信息文本。这些非结构化文本往往隐含着关键业务信息&#x…...
