当前位置: 首页 > news >正文

cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新惯性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新参数weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 计时开始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 计时结束// 获取结束时间点auto end = std::chrono::high_resolution_clock::now();// 计算持续时间std::chrono::duration<double> duration = end - start;// 输出执行时间std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

相关文章:

cuda从零开始手搓PB神经网络

cuda实现PB神经网络 基于上一篇的矩阵点乘&#xff0c;实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下&#xff1a; #ifndef __BP_NETWORK_HPP__ #define __BP_NETWORK_HPP__ #include "matrix.hpp&quo…...

mac 安装mongodb

本文分享2种mac本地安装mongodb的方法&#xff0c;一种是通过homebrew安装&#xff0c;一种是通过tar包安装 homebrew安装 brew tap mongodb/brew brew upate brew install mongodb-community8.0tar包安装 安装mongodb 1.下载mongodb社区版的tar包 mongdb tar包下载地址 2…...

K8S-Pod资源清单的编写,资源的增删改查,镜像的下载策略

1. Pod资源清单的编写 1.1 Pod运行单个容器的资源清单 ##创建工作目录 mkdir -p /root/manifests/pods && cd /root/manifests/pods vim 01-nginx.yaml ##指定api版本 apiVersion: v1 ##指定资源类型 kind: Pod ##指定元数据 metadata:##指定名称name: myweb ##用户…...

【Maui】视图界面与数据模型绑定

文章目录 前言一、问题描述二、解决方案三、软件开发&#xff08;源码&#xff09;3.1 创建模型3.2 视图界面3.3 控制器逻辑层 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架&#xff0c;用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI&…...

JavaScript笔记基础篇02——运算符、语句、数组

黑马程序员视频地址&#xff1a;黑马程序员前端JavaScript入门到精通全套视频教程https://www.bilibili.com/video/BV1Y84y1L7Nn?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes 目录 运算符 赋值运算符 ​编辑​编辑 一元运算符…...

心法利器[127] | 24年算法思考-特征工程和经典深度学习

心法利器 本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍&#xff1a;仓颉专项&#xff1a;飞机大炮我都会&#xff0c;利器心法我还有。 2023年新的文章合集已经发布&#xff0c;获取方式看这里&#xff1a;又添十万字-CS的陋室2023年文章合集来袭&#xff0c;更…...

ASP.NET Core 中的 JWT 鉴权实现

在当今的软件开发中&#xff0c;安全性和用户认证是至关重要的方面。JSON Web Token&#xff08;JWT&#xff09;作为一种流行的身份验证机制&#xff0c;因其简洁性和无状态特性而被广泛应用于各种应用中&#xff0c;尤其是在 ASP.NET Core 项目里。本文将详细介绍如何在 ASP.…...

PyTorch基本功能与实现代码

PyTorch是一个开源的深度学习框架&#xff0c;提供了丰富的函数和工具&#xff0c;以下为其主要功能的归纳&#xff1a; 核心数据结构&#xff1a; • 张量&#xff08;Tensor&#xff09;&#xff1a;类似于Numpy的ndarray&#xff0c;是PyTorch中基本的数据结构&#xff0c…...

SparkSQL数据模型综合实践

文章目录 1. 实战概述2. 实战步骤2.1 创建数据集2.2 创建数据模型对象2.2.1 创建常量2.2.2 创建加载数据方法2.2.3 创建过滤年龄方法2.2.4 创建平均薪水方法2.2.5 创建主方法2.2.6 查看完整代码 2.3 运行程序&#xff0c;查看结果 3. 实战小结 1. 实战概述 在本次实战中&#…...

3 查找重复的电子邮箱(having与where区别,distinct去重使用)

3 查找重复的电子邮箱&#xff08;having与where区别&#xff0c;distinct去重使用&#xff09; 表: Person ---------------------- | Column Name | Type | ---------------------- | id | int | | email | varchar | ---------------------- id 是该…...

uniapp——App 监听下载文件状态,打开文件(三)

5 实现下载文件并打开 这里演示&#xff0c;导出Excel 表格 文章目录 5 实现下载文件并打开DEMO监听下载进度效果图为什么 totalSize 一直为0&#xff1f; 相关Api&#xff1a; downloader DEMO 提示&#xff1a; 请求方式支持&#xff1a;GET、POST&#xff1b;POST 方式需要…...

循环队列(C语言)

从今天开始我会开启一个专栏leetcode每日一题&#xff0c;大家互相交流代码经验&#xff0c;也当作我每天练习的自我回顾。第一天的内容是leetcode622.设计循环队列。 一、题目详细 设计你的循环队列实现。 循环队列是一种线性数据结构&#xff0c;其操作表现基于 FIFO&#…...

数据可视化:让数据讲故事的艺术

目录 1 前言2 数据可视化的基本概念2.1 可视化的核心目标2.2 传统可视化手段 3 数据可视化在知识图谱中的应用3.1 知识图谱的可视化需求3.2 知识图谱的可视化方法 4 数据可视化叙事&#xff1a;让数据讲故事4.1 叙事可视化的关键要素4.2 数据可视化叙事的实现方法 5 数据可视化…...

雷电9最新版安装Magisk+LSPosd(新手速通)

大家好啊&#xff01;我是NiJiMingCheng 我的博客&#xff1a;NiJiMingCheng 在安卓系统的定制与拓展过程中&#xff0c;获取 ROOT 权限以及安装各类框架是进阶玩家常用的操作&#xff0c;这可以帮助我们实现更多系统层面的个性化功能。今天&#xff0c;我将为大家详细介绍如何…...

Ubuntu 24.04 LTS 开启 SMB 服务,并通过 windows 访问

Ubuntu 24.04 LTS 背景资料 Ubuntu服务器折腾集Ubuntu linux 文件权限Ubuntu 空闲硬盘挂载到 文件管理器的 other locations Ubuntu开启samba和window共享文件 Ubuntu 配置 SMB 服务 安装 Samba 确保 Samba 已安装。如果未安装&#xff0c;运行以下命令进行安装&#xff…...

使用Websocket进行前后端实时通信

1、引入jar&#xff0c;spring-websocket-starter <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dependency> 2、配置websocket config import org.springframe…...

vue2使用flv.js在浏览器打开flv格式视频

组件地址&#xff1a;GitHub - bilibili/flv.js: HTML5 FLV Player flv.js 仅支持 H.264 和 AAC/MP3 编码的 FLV 文件。如果视频文件使用了其他编码格式就打不开。 flv.vue <template><div><el-dialog :visible.sync"innerVisibleFlv" :close-on-pre…...

OpenCV相机标定与3D重建(61)处理未校准的立体图像对函数stereoRectifyUncalibrated()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 为未校准的立体相机计算一个校正变换。 cv::stereoRectifyUncalibrated 是 OpenCV 库中的一个函数&#xff0c;用于处理未校准的立体图像对。该函…...

[cg] glProgramBinary

参考&#xff1a; glProgramBinary - OpenGL 4 Reference Pages opengl 通过gpu编译好的 shader 可以存储到二进制文件中&#xff0c;第二次使用的时候直接加载二进制文件即可&#xff0c; glProgramBinary用于加载shader的二进制数据 实列代码如下&#xff1a; // 假设已经…...

LeetCode hot 力扣热题100 二叉树的最大深度

class Solution { public:int maxDepth(TreeNode* root) {if (root nullptr) {return 0;}int l_depth maxDepth(root->left);int r_depth maxDepth(root->right);return max(l_depth, r_depth) 1;} }; 代码作用 该函数通过递归计算二叉树的最大深度&#xff08;从根节…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好&#xff0c;欢迎来到《云原生核心技术》系列的第七篇&#xff01; 在上一篇&#xff0c;我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在&#xff0c;我们就像一个拥有了一块崭新数字土地的农场主&#xff0c;是时…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习&#xff08;Reinforcement Learning, RL&#xff09;是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程&#xff0c;然后使用强化学习的Actor-Critic机制&#xff08;中文译作“知行互动”机制&#xff09;&#xff0c;逐步迭代求解…...

黑马Mybatis

Mybatis 表现层&#xff1a;页面展示 业务层&#xff1a;逻辑处理 持久层&#xff1a;持久数据化保存 在这里插入图片描述 Mybatis快速入门 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6501c2109c4442118ceb6014725e48e4.png //logback.xml <?xml ver…...

汽车生产虚拟实训中的技能提升与生产优化​

在制造业蓬勃发展的大背景下&#xff0c;虚拟教学实训宛如一颗璀璨的新星&#xff0c;正发挥着不可或缺且日益凸显的关键作用&#xff0c;源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例&#xff0c;汽车生产线上各类…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

基于数字孪生的水厂可视化平台建设:架构与实践

分享大纲&#xff1a; 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年&#xff0c;数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段&#xff0c;基于数字孪生的水厂可视化平台的…...

【单片机期末】单片机系统设计

主要内容&#xff1a;系统状态机&#xff0c;系统时基&#xff0c;系统需求分析&#xff0c;系统构建&#xff0c;系统状态流图 一、题目要求 二、绘制系统状态流图 题目&#xff1a;根据上述描述绘制系统状态流图&#xff0c;注明状态转移条件及方向。 三、利用定时器产生时…...