cuda从零开始手搓PB神经网络
cuda实现PB神经网络
基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:
#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;} // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif
下面实验一下我们的bp神经网络。
#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}
以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:

可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。
相关文章:
cuda从零开始手搓PB神经网络
cuda实现PB神经网络 基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下: #ifndef __BP_NETWORK_HPP__ #define __BP_NETWORK_HPP__ #include "matrix.hpp&quo…...
mac 安装mongodb
本文分享2种mac本地安装mongodb的方法,一种是通过homebrew安装,一种是通过tar包安装 homebrew安装 brew tap mongodb/brew brew upate brew install mongodb-community8.0tar包安装 安装mongodb 1.下载mongodb社区版的tar包 mongdb tar包下载地址 2…...
K8S-Pod资源清单的编写,资源的增删改查,镜像的下载策略
1. Pod资源清单的编写 1.1 Pod运行单个容器的资源清单 ##创建工作目录 mkdir -p /root/manifests/pods && cd /root/manifests/pods vim 01-nginx.yaml ##指定api版本 apiVersion: v1 ##指定资源类型 kind: Pod ##指定元数据 metadata:##指定名称name: myweb ##用户…...
【Maui】视图界面与数据模型绑定
文章目录 前言一、问题描述二、解决方案三、软件开发(源码)3.1 创建模型3.2 视图界面3.3 控制器逻辑层 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI&…...
JavaScript笔记基础篇02——运算符、语句、数组
黑马程序员视频地址:黑马程序员前端JavaScript入门到精通全套视频教程https://www.bilibili.com/video/BV1Y84y1L7Nn?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes 目录 运算符 赋值运算符 编辑编辑 一元运算符…...
心法利器[127] | 24年算法思考-特征工程和经典深度学习
心法利器 本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。 2023年新的文章合集已经发布,获取方式看这里:又添十万字-CS的陋室2023年文章合集来袭,更…...
ASP.NET Core 中的 JWT 鉴权实现
在当今的软件开发中,安全性和用户认证是至关重要的方面。JSON Web Token(JWT)作为一种流行的身份验证机制,因其简洁性和无状态特性而被广泛应用于各种应用中,尤其是在 ASP.NET Core 项目里。本文将详细介绍如何在 ASP.…...
PyTorch基本功能与实现代码
PyTorch是一个开源的深度学习框架,提供了丰富的函数和工具,以下为其主要功能的归纳: 核心数据结构: • 张量(Tensor):类似于Numpy的ndarray,是PyTorch中基本的数据结构,…...
SparkSQL数据模型综合实践
文章目录 1. 实战概述2. 实战步骤2.1 创建数据集2.2 创建数据模型对象2.2.1 创建常量2.2.2 创建加载数据方法2.2.3 创建过滤年龄方法2.2.4 创建平均薪水方法2.2.5 创建主方法2.2.6 查看完整代码 2.3 运行程序,查看结果 3. 实战小结 1. 实战概述 在本次实战中&#…...
3 查找重复的电子邮箱(having与where区别,distinct去重使用)
3 查找重复的电子邮箱(having与where区别,distinct去重使用) 表: Person ---------------------- | Column Name | Type | ---------------------- | id | int | | email | varchar | ---------------------- id 是该…...
uniapp——App 监听下载文件状态,打开文件(三)
5 实现下载文件并打开 这里演示,导出Excel 表格 文章目录 5 实现下载文件并打开DEMO监听下载进度效果图为什么 totalSize 一直为0? 相关Api: downloader DEMO 提示: 请求方式支持:GET、POST;POST 方式需要…...
循环队列(C语言)
从今天开始我会开启一个专栏leetcode每日一题,大家互相交流代码经验,也当作我每天练习的自我回顾。第一天的内容是leetcode622.设计循环队列。 一、题目详细 设计你的循环队列实现。 循环队列是一种线性数据结构,其操作表现基于 FIFO&#…...
数据可视化:让数据讲故事的艺术
目录 1 前言2 数据可视化的基本概念2.1 可视化的核心目标2.2 传统可视化手段 3 数据可视化在知识图谱中的应用3.1 知识图谱的可视化需求3.2 知识图谱的可视化方法 4 数据可视化叙事:让数据讲故事4.1 叙事可视化的关键要素4.2 数据可视化叙事的实现方法 5 数据可视化…...
雷电9最新版安装Magisk+LSPosd(新手速通)
大家好啊!我是NiJiMingCheng 我的博客:NiJiMingCheng 在安卓系统的定制与拓展过程中,获取 ROOT 权限以及安装各类框架是进阶玩家常用的操作,这可以帮助我们实现更多系统层面的个性化功能。今天,我将为大家详细介绍如何…...
Ubuntu 24.04 LTS 开启 SMB 服务,并通过 windows 访问
Ubuntu 24.04 LTS 背景资料 Ubuntu服务器折腾集Ubuntu linux 文件权限Ubuntu 空闲硬盘挂载到 文件管理器的 other locations Ubuntu开启samba和window共享文件 Ubuntu 配置 SMB 服务 安装 Samba 确保 Samba 已安装。如果未安装,运行以下命令进行安装ÿ…...
使用Websocket进行前后端实时通信
1、引入jar,spring-websocket-starter <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency> 2、配置websocket config import org.springframe…...
vue2使用flv.js在浏览器打开flv格式视频
组件地址:GitHub - bilibili/flv.js: HTML5 FLV Player flv.js 仅支持 H.264 和 AAC/MP3 编码的 FLV 文件。如果视频文件使用了其他编码格式就打不开。 flv.vue <template><div><el-dialog :visible.sync"innerVisibleFlv" :close-on-pre…...
OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数,用于处理未校准的立体图像对。该函…...
[cg] glProgramBinary
参考: glProgramBinary - OpenGL 4 Reference Pages opengl 通过gpu编译好的 shader 可以存储到二进制文件中,第二次使用的时候直接加载二进制文件即可, glProgramBinary用于加载shader的二进制数据 实列代码如下: // 假设已经…...
LeetCode hot 力扣热题100 二叉树的最大深度
class Solution { public:int maxDepth(TreeNode* root) {if (root nullptr) {return 0;}int l_depth maxDepth(root->left);int r_depth maxDepth(root->right);return max(l_depth, r_depth) 1;} }; 代码作用 该函数通过递归计算二叉树的最大深度(从根节…...
哔哩下载姬Downkyi:5分钟解锁B站视频批量下载新境界
哔哩下载姬Downkyi:5分钟解锁B站视频批量下载新境界 【免费下载链接】downkyi 哔哩下载姬downkyi,哔哩哔哩网站视频下载工具,支持批量下载,支持8K、HDR、杜比视界,提供工具箱(音视频提取、去水印等…...
落地即能用!声振温监测部署全流程:设备在线状态监控搭建指南
设备在线状态监控的核心,是通过声振温三大核心数据,捕捉设备隐性故障前兆,实现“早发现、早预警、早处置”,避免非计划停机。而声振温监测的部署,并非简单的“装传感器、连系统”,需遵循科学流程࿰…...
2.1SQL 学习:先懂数据库概念再学 SQL
2.1SQL 学习:先懂数据库概念再学 SQL 开篇:为什么学SQL前要先搞懂数据库概念 我入行第一年,领导丢给我一个数据库账号,说“去把昨天的订单数据查出来”。我打开Navicat,看到左边一长串陌生的表名,完全不知道…...
Meta推出由高薪超级智能实验室研发的全新AI模型
Meta于本周三正式发布了其最新人工智能模型,这也是该公司组建一支高薪团队以在AI赛道上与竞争对手展开较量后推出的首个重磅成果。这款名为Muse Spark的新模型由Meta超级智能实验室打造。该实验室汇聚了一批来自各大AI公司的顶尖人才,于去年正式成立&…...
Java虚拟线程落地避坑指南(生产环境血泪总结:从Spring Boot 3.3集成到Project Loom异常传播链断裂修复)
第一章:Java 25虚拟线程核心原理与高并发演进全景Java 25正式将虚拟线程(Virtual Threads)从预览特性转为标准特性,标志着JVM并发模型进入轻量级线程时代。虚拟线程由JVM在用户态调度,底层复用有限的平台线程ÿ…...
数据库回顾
题目:584. 寻找用户推荐人 表: Customer ---------------------- | Column Name | Type | ---------------------- | id | int | | name | varchar | | referee_id | int | ---------------------- 在 SQL 中,id 是该表的…...
为什么你的GraalVM镜像比JVM运行时多占62%内存?20年HotSpot/Graal双栈专家首次公开12项静态编译内存压缩清单
第一章:GraalVM静态镜像内存膨胀的本质归因GraalVM 静态原生镜像(Native Image)在启动性能与资源占用方面具有显著优势,但实践中常观察到生成的二进制文件体积远超预期,且运行时堆外内存(尤其是元数据区、字…...
OpenHand:自适应抓取技术的开源硬件革新
OpenHand:自适应抓取技术的开源硬件革新 【免费下载链接】openhand-hardware CAD files for the OpenHand hand designs 项目地址: https://gitcode.com/gh_mirrors/op/openhand-hardware 在工业自动化与协作机器人领域,传统抓取系统面临着适应性…...
3个步骤掌握Ryujinx模拟器高级配置:从入门到精通指南
3个步骤掌握Ryujinx模拟器高级配置:从入门到精通指南 【免费下载链接】Ryujinx 用 C# 编写的实验性 Nintendo Switch 模拟器 项目地址: https://gitcode.com/GitHub_Trending/ry/Ryujinx Ryujinx作为一款用C#编写的实验性Nintendo Switch模拟器,为…...
Qwen2.5-0.5B如何快速上手?新手入门必看部署实操指南
Qwen2.5-0.5B如何快速上手?新手入门必看部署实操指南 你是不是也对最近火热的开源大模型Qwen2.5系列感到好奇?特别是那个号称“小身材大能量”的Qwen2.5-0.5B-Instruct模型。它只有5亿参数,却继承了阿里通义千问系列强大的指令跟随和多语言能…...
