当前位置: 首页 > news >正文

ESP32开发进阶: 训练神经网络

一、网络设定

        我们设定一个简单的前馈神经网络,其结构如下:

  1. 输入层:节点数:2,接收输入数据,每个输入样本包含2个特征,例如 {1.0, 0.0}{0.0, 1.0} 等。

  2. 隐藏层:节点数:2,处理和提取输入数据的特征,

    1. 激活函数:使用 Sigmoid 激活函数 sigmoid(x) = 1 / (1 + exp(-x))

    2. 权重矩阵(Weights from Input to Hidden):

      float weights_input_hidden[INPUT_NODES][HIDDEN_NODES] = {{0.15, 0.25},{0.20, 0.30}
      };

      这是一个 2x2 的权重矩阵,用于连接输入层和隐藏层。

    3. 偏置(Biases for Hidden Layer)

      float bias_hidden[HIDDEN_NODES] = {0.35, 0.35};
      

      这是一个包含2个偏置值的数组,分别对应每个隐藏层节点。

    4. 每个隐藏层节点计算如下

      h_j = sigmoid\left(\sum_{i=1}^{n} (x_i \cdot w_{ij} + b_j)\right)

      其中 x_i是输入节点值,w_{ij}是权重,b_j是偏置。

  3. 输出层:节点数:1,给出最终的预测结果。

    1. 激活函数:使用 Sigmoid 激活函数

    2. 权重矩阵:

      float weights_hidden_output[HIDDEN_NODES][OUTPUT_NODES] = {{0.40},{0.50}
      };
      

      这是一个 2x1 的权重矩阵,用于连接隐藏层和输出层。

    3. 偏置

      float bias_output[OUTPUT_NODES] = {0.60};
      这是一个包含1个偏置值的数组,分别对应每个隐藏层节点。
    4. 输出层节点计算如下y = sigmoid\left(\sum_{j=1}^{m} (h_j \cdot w_{jo} + b_o)\right)其中 h_j是隐藏层节点值,w_{jo}是权重,b_o 是偏置。

        整体网络设定如下图所示:

二、Arduino端代码

        首先,是初始化部分(权重和偏置的定义)

float weights_input_hidden[INPUT_NODES][HIDDEN_NODES] = {{0.15, 0.25},{0.20, 0.30}
};float weights_hidden_output[HIDDEN_NODES][OUTPUT_NODES] = {{0.40},{0.50}
};float bias_hidden[HIDDEN_NODES] = {0.35, 0.35};
float bias_output[OUTPUT_NODES] = {0.60};

        接着,前向传播通过计算每一层的加权输入和激活输出来推断输入样本的预测值。

void forward_propagation(float input[]) {for (int i = 0; i < INPUT_NODES; i++) {input_layer[i] = input[i];}for (int j = 0; j < HIDDEN_NODES; j++) {hidden_layer[j] = 0;for (int i = 0; i < INPUT_NODES; i++) {hidden_layer[j] += input_layer[i] * weights_input_hidden[i][j];}hidden_layer[j] += bias_hidden[j];hidden_layer[j] = sigmoid(hidden_layer[j]);}for (int k = 0; k < OUTPUT_NODES; k++) {output_layer[k] = 0;for (int j = 0; j < HIDDEN_NODES; j++) {output_layer[k] += hidden_layer[j] * weights_hidden_output[j][k];}output_layer[k] += bias_output[k];output_layer[k] = sigmoid(output_layer[k]);}
}

        最后,反向传播通过计算误差并根据误差调整权重和偏置,以最小化损失函数。

void backward_propagation(float input[], float target) {float output_error = target - output_layer[0];float output_delta = output_error * sigmoid_derivative(output_layer[0]);float hidden_error[HIDDEN_NODES];float hidden_delta[HIDDEN_NODES];for (int j = 0; j < HIDDEN_NODES; j++) {hidden_error[j] = output_delta * weights_hidden_output[j][0];hidden_delta[j] = hidden_error[j] * sigmoid_derivative(hidden_layer[j]);}for (int j = 0; j < HIDDEN_NODES; j++) {weights_hidden_output[j][0] += learning_rate * output_delta * hidden_layer[j];}bias_output[0] += learning_rate * output_delta;for (int i = 0; i < INPUT_NODES; i++) {for (int j = 0; j < HIDDEN_NODES; j++) {weights_input_hidden[i][j] += learning_rate * hidden_delta[j] * input_layer[i];}}for (int j = 0; j < HIDDEN_NODES; j++) {bias_hidden[j] += learning_rate * hidden_delta[j];}
}

        定义四组输入样本及其目标输出,用于训练神经网络, 通过多次迭代训练神经网络,并在每次训练后输出当前的权重和F1-score。完整代码如下:

#include <Arduino.h>
#include <cmath>// 定义神经网络结构
#define INPUT_NODES 2
#define HIDDEN_NODES 2
#define OUTPUT_NODES 1// 定义神经网络参数
float input_layer[INPUT_NODES];
float hidden_layer[HIDDEN_NODES];
float output_layer[OUTPUT_NODES];float weights_input_hidden[INPUT_NODES][HIDDEN_NODES] = {{0.15, 0.25},{0.20, 0.30}
};float weights_hidden_output[HIDDEN_NODES][OUTPUT_NODES] = {{0.40},{0.50}
};float bias_hidden[HIDDEN_NODES] = {0.35, 0.35};
float bias_output[OUTPUT_NODES] = {0.60};float learning_rate = 0.1;// 激活函数和其导数(sigmoid)
float sigmoid(float x) {return 1.0 / (1.0 + exp(-x));
}float sigmoid_derivative(float x) {return x * (1.0 - x);
}// 计算预测值
void forward_propagation(float input[]) {for (int i = 0; i < INPUT_NODES; i++) {input_layer[i] = input[i];}for (int j = 0; j < HIDDEN_NODES; j++) {hidden_layer[j] = 0;for (int i = 0; i < INPUT_NODES; i++) {hidden_layer[j] += input_layer[i] * weights_input_hidden[i][j];}hidden_layer[j] += bias_hidden[j];hidden_layer[j] = sigmoid(hidden_layer[j]);}for (int k = 0; k < OUTPUT_NODES; k++) {output_layer[k] = 0;for (int j = 0; j < HIDDEN_NODES; j++) {output_layer[k] += hidden_layer[j] * weights_hidden_output[j][k];}output_layer[k] += bias_output[k];output_layer[k] = sigmoid(output_layer[k]);}
}// 更新权重和偏置
void backward_propagation(float input[], float target) {float output_error = target - output_layer[0];float output_delta = output_error * sigmoid_derivative(output_layer[0]);float hidden_error[HIDDEN_NODES];float hidden_delta[HIDDEN_NODES];for (int j = 0; j < HIDDEN_NODES; j++) {hidden_error[j] = output_delta * weights_hidden_output[j][0];hidden_delta[j] = hidden_error[j] * sigmoid_derivative(hidden_layer[j]);}for (int j = 0; j < HIDDEN_NODES; j++) {weights_hidden_output[j][0] += learning_rate * output_delta * hidden_layer[j];}bias_output[0] += learning_rate * output_delta;for (int i = 0; i < INPUT_NODES; i++) {for (int j = 0; j < HIDDEN_NODES; j++) {weights_input_hidden[i][j] += learning_rate * hidden_delta[j] * input_layer[i];}}for (int j = 0; j < HIDDEN_NODES; j++) {bias_hidden[j] += learning_rate * hidden_delta[j];}
}// 计算F1-score
float compute_f1_score(float tp, float fp, float fn) {float precision = tp / (tp + fp);float recall = tp / (tp + fn);return 2 * (precision * recall) / (precision + recall);
}void print_weights() {Serial.println("Weights Input-Hidden:");for (int i = 0; i < INPUT_NODES; i++) {for (int j = 0; j < HIDDEN_NODES; j++) {Serial.printf("w[%d][%d] = %f ", i, j, weights_input_hidden[i][j]);}Serial.println();}Serial.println("Weights Hidden-Output:");for (int j = 0; j < HIDDEN_NODES; j++) {Serial.printf("w[%d][0] = %f ", j, weights_hidden_output[j][0]);}Serial.println();
}void setup() {// 初始化串口Serial.begin(115200);while (!Serial) {}// 打印欢迎信息Serial.println("Hello, ESP32 Neural Network with Training!");
}void loop() {// 输入样本和目标float input[][INPUT_NODES] = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}, {0.0, 0.0}};float target[] = {1.0, 1.0, 0.0, 0.0};// 初始化统计量float tp = 0, fp = 0, fn = 0;// 训练for (int epoch = 0; epoch < 1000; epoch++) {for (int i = 0; i < 4; i++) {forward_propagation(input[i]);backward_propagation(input[i], target[i]);// 更新统计量float prediction = output_layer[0] > 0.5 ? 1.0 : 0.0;if (prediction == 1.0 && target[i] == 1.0) {tp++;} else if (prediction == 1.0 && target[i] == 0.0) {fp++;} else if (prediction == 0.0 && target[i] == 1.0) {fn++;}}// 打印权重和F1-scoreSerial.printf("Epoch %d\n", epoch);print_weights();float f1_score = compute_f1_score(tp, fp, fn);Serial.printf("F1-Score: %f\n", f1_score);// 重置统计量tp = 0;fp = 0;fn = 0;// 延迟一段时间delay(100);}// 停止程序while (true) {}
}

        部分打印结果如下:

相关文章:

ESP32开发进阶: 训练神经网络

一、网络设定 我们设定一个简单的前馈神经网络&#xff0c;其结构如下&#xff1a; 输入层&#xff1a;节点数&#xff1a;2&#xff0c;接收输入数据&#xff0c;每个输入样本包含2个特征&#xff0c;例如 {1.0, 0.0}, {0.0, 1.0} 等。 隐藏层&#xff1a;节点数&#xff1a;…...

全国区块链职业技能大赛国赛考题前端功能开发

任务3-1:区块链应用前端功能开发 1.请基于前端系统的开发模板,在登录组件login.js、组件管理文件components.js中添加对应的逻辑代码,实现对前端的角色选择功能,并测试功能完整性,示例页面如下: 具体要求如下: (1)有明确的提示,提示用户选择角色; (2)用户可看…...

直接插入排序算法详解

直接插入排序&#xff08;Straight Insertion Sort&#xff09;是一种简单直观的排序算法。它的工作原理是通过构建有序序列&#xff0c;对于未排序数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。插入排序在实现上&#xff0c;通常采用in-place排…...

sql手动自增id

有时候在运维处理数据的时候&#xff0c;需要给某张表插入新的记录&#xff0c;那么需要知道最新插入数据的id,并在最新id的基础上加上id增长步长获取新的id,这个过程往往需要现将max出来加1,再手动补充到sql语句中&#xff0c;很麻烦&#xff0c;而且数据多的时候容易出错。有…...

10_TypeScript中的泛型

TypeScript中的泛型&#xff09; 一、泛型的定义二、泛型函数三、泛型类&#xff1a;比如有个最小堆算法&#xff0c;需要同时支持返回数字和字符串两种类型。通过类的泛型来实现四、泛型接口五、泛型类 --扩展 把类作为参数类型的泛型类1、实现&#xff1a;定义一个 User 的类…...

Unity3D之TextMeshPro使用

文章目录 1. TextMeshPro简介2. TextMeshPro创建3. TextMeshPro脚本中调用4. TextMeshPro字体设置及中文支持过程中出现的一些问题 1. TextMeshPro简介 【官网文档】https://docs.unity.cn/cn/2020.3/Manual/com.unity.textmeshpro.html TextMeshPro 是 Unity 的最终文本解决…...

K8S 上部署 Prometheus + Grafana

文章目录 一、使用 Helm 安装 Prometheus1. 配置源2. 下载 prometheus 包3. 安装 prometheus4. 卸载 二、使用 Helm 安装 Grafana1. 配置源2. 安装 grafana3. 访问4. 卸载 一、使用 Helm 安装 Prometheus 1. 配置源 地址&#xff1a;https://artifacthub.io/packages/helm/pro…...

雷军的逆天改命与顺势而为

雷军年度演讲前&#xff0c;朋友李翔提了一个问题&#xff1a;雷军造车是属于顺势而为还是逆势而为&#xff1f;评论互动区有一个总结&#xff0c;很有意思&#xff0c;叫“顺势逆袭”。 大致意思是产业趋势下小米从手机到IOT再切入汽车&#xff0c;是战略的必然&#xff0c;不…...

Leetcode 11. 盛最多水的容器

Leetcode 11. 盛最多水的容器 Leetcode 11. 盛最多水的容器 一、题目描述二、我的想法 一、题目描述 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成…...

Java笔试分享

1、设计模式&#xff08;写>3种常用的设计模式&#xff09; 设计模式是在软件工程中解决常见问题的经验性解决方案。以下是一些常用的设计模式&#xff1a; 单例模式&#xff08;Singleton&#xff09;&#xff1a; 意图&#xff1a;确保一个类只有一个实例&#xff0c;并…...

LeetCode:对称的二叉树(C语言)

1、问题概述&#xff1a;给一个二叉树&#xff0c;看是否按轴对称 2、示例 示例 1&#xff1a; 输入&#xff1a;root [1,2,2,3,4,4,3] 输出&#xff1a;true 示例 2&#xff1a; 输入&#xff1a;root [1,2,2,null,3,null,3] 输出&#xff1a;false 3、分析 &#xff08;1&a…...

Postman中的API Schema验证:确保响应精准无误

Postman中的API Schema验证&#xff1a;确保响应精准无误 在API开发和测试过程中&#xff0c;验证响应数据的准确性和一致性是至关重要的。Postman提供了一个强大的功能——API Schema验证&#xff0c;它允许开发者根据预定义的JSON Schema来检查API响应。本文将详细介绍如何在…...

深入浅出WebRTC—GCC

GoogCcNetworkController 是 GCC 的控制中心&#xff0c;它由 RtpTransportControllerSend 通过定时器和 TransportFeedback 来驱动。GoogCcNetworkController 不断更新内部各个组件的状态&#xff0c;并协调组件之间相互配合&#xff0c;向外输出目标码率等重要参数&#xff0…...

leetcode日记(49)旋转链表

其实不难&#xff0c;就是根据kk%len判断需要旋转的位置&#xff0c;再将后半段接在前半段前面就行。 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : …...

InteliJ IDEA最新2024版下载安装与快速配置激活使用教程+jdk下载配置

第一步&#xff1a;下载ideaIC-2024.1.4 方法1&#xff1a;在线链接 IntelliJ IDEA – the Leading Java and Kotlin IDE (jetbrains.com) 选择社区版进行下载 方法2&#xff1a;百度网盘 链接&#xff1a;https://pan.baidu.com/s/1ydS6krUX6eE_AdW4uGV_6w?pwdsbfm 提取…...

【23】Android高级知识之Window(四) - ThreadedRenderer

一、概述 在上一篇文章中已经讲了setView整个流程中&#xff0c;最开始的addToDisplay和WMS跨进程通信的整个过程做了什么。继文章Android基础知识之Window(二)&#xff0c;这算是另外一个分支了&#xff0c;接着讲分析在performTraversals的三个操作中&#xff0c;最后触发pe…...

Java-根据前缀-日期-数字-生成流水号(不重复)

&#x1f388;边走、边悟&#x1f388;迟早会好 小伙伴们在日常开发时可能会遇到的业务-生成流水号&#xff0c;在企业中可以说是比较常见的需求&#xff0c; 可以采用"前缀日期数字"的方式&#xff08;ps:此方式是需要用到缓存的&#xff09;前缀&#xff1a;为了…...

跟李沐学AI:卷积层

从全连接层到卷积 多层感知机十分适合处理表格数据&#xff0c;其中行对应样本&#xff0c;列对应特征。但对于图片等数据&#xff0c;全连接层会导致参数过多。卷积神经网络&#xff08;convolutional neural networks&#xff0c;CNN&#xff09;是机器学习利用自然图像中一…...

使用RedisTemplate操作executePipelined

前言 RedisTemplate 是 Spring 提供的用于操作 Redis 的模板类&#xff0c;它封装了 Redis 的连接、连接池等管理&#xff0c;并提供了一系列的操作方法来简化 Redis 的使用。其中&#xff0c;executePipelined 方法是 RedisTemplate 中的一个高级特性&#xff0c;用于支持 Re…...

react-native从入门到实战系列教程一环境安装篇

充分阅读官网的环境配置指南&#xff0c;严格按照他的指导作业&#xff0c;不然你一直只能在web或沙箱环境下玩玩 极快的网络和科学上网&#xff0c;必备其中的一个较好的心理忍受能力&#xff0c;因为上面一点就可以让你放弃坚持不懈&#xff0c;努力尝试 成功效果 三大件 …...

树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频

使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源&#xff1a; http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI

前一阵子在百度 AI 开发者大会上&#xff0c;看到基于小智 AI DIY 玩具的演示&#xff0c;感觉有点意思&#xff0c;想着自己也来试试。 如果只是想烧录现成的固件&#xff0c;乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外&#xff0c;还提供了基于网页版的 ESP LA…...

反射获取方法和属性

Java反射获取方法 在Java中&#xff0c;反射&#xff08;Reflection&#xff09;是一种强大的机制&#xff0c;允许程序在运行时访问和操作类的内部属性和方法。通过反射&#xff0c;可以动态地创建对象、调用方法、改变属性值&#xff0c;这在很多Java框架中如Spring和Hiberna…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

PostgreSQL——环境搭建

一、Linux # 安装 PostgreSQL 15 仓库 sudo dnf install -y https://download.postgresql.org/pub/repos/yum/reporpms/EL-$(rpm -E %{rhel})-x86_64/pgdg-redhat-repo-latest.noarch.rpm# 安装之前先确认是否已经存在PostgreSQL rpm -qa | grep postgres# 如果存在&#xff0…...

数据库——redis

一、Redis 介绍 1. 概述 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的、高性能的内存键值数据库系统&#xff0c;具有以下核心特点&#xff1a; 内存存储架构&#xff1a;数据主要存储在内存中&#xff0c;提供微秒级的读写响应 多数据结构支持&…...

LUA+Reids实现库存秒杀预扣减 记录流水 以及自己的思考

目录 lua脚本 记录流水 记录流水的作用 流水什么时候删除 我们在做库存扣减的时候&#xff0c;显示基于Lua脚本和Redis实现的预扣减 这样可以在秒杀扣减的时候保证操作的原子性和高效性 lua脚本 // ... 已有代码 ...Overridepublic InventoryResponse decrease(Inventor…...