基于简单神经网络的线性回归
一、概述
本代码实现了一个简单的神经网络进行线性回归任务。通过生成包含噪声的线性数据集,定义一个简单的神经网络类,使用梯度下降算法训练网络以拟合数据,并最终通过可视化展示原始数据、真实线性关系以及模型的预测结果。
二、依赖库
numpy:用于数值计算,包括生成数组、进行随机数操作、执行数学运算等。matplotlib.pyplot:用于数据可视化,绘制散点图和折线图以展示数据和模型的预测结果。
三、代码详解
1. 生成数据集
python
np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape) # 添加噪声
np.random.seed(42):设置随机数种子,确保每次运行代码时生成的随机数序列相同,从而使结果可复现。np.linspace(-10, 10, 100):生成一个包含 100 个元素的一维数组x,元素均匀分布在 - 10 到 10 之间。x + np.random.normal(0, 1, x.shape):生成因变量y,它基于真实的线性关系y = x,并添加了均值为 0、标准差为 1 的高斯噪声。np.random.normal(0, 1, x.shape)生成与x形状相同的随机噪声数组。
2. 定义神经网络(线性回归)
python
class SimpleNN:def __init__(self):self.w = np.random.randn() # 权重self.b = np.random.randn() # 偏置def forward(self, x):return self.w * x + self.b # 前向传播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2) # 均方误差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred)) # 权重的梯度db = -2 * np.mean(y_true - y_pred) # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw # 更新权重self.b -= lr * db # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')
__init__方法:初始化神经网络的权重self.w和偏置self.b,使用np.random.randn()生成随机的初始值。forward方法:实现前向传播,根据输入x、权重self.w和偏置self.b计算输出y_pred,即y_pred = self.w * x + self.b。loss方法:计算预测值y_pred和真实值y_true之间的均方误差(MSE),公式为np.mean((y_true - y_pred) ** 2)。gradient方法:计算权重self.w和偏置self.b的梯度。dw是权重的梯度,计算公式为-2 * np.mean(x * (y_true - y_pred));db是偏置的梯度,计算公式为-2 * np.mean(y_true - y_pred)。train方法:使用梯度下降算法训练神经网络。在指定的epochs(训练轮数)内,每次迭代进行前向传播计算预测值y_pred,然后计算梯度dw和db,根据学习率lr更新权重self.w和偏置self.b。每 100 轮打印一次当前轮数和损失值。
3. 训练模型
python
model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)
SimpleNN():创建一个SimpleNN类的实例model。model.train(x, y, lr=0.01, epochs=1000):调用model的train方法,使用生成的数据集x和y,学习率lr=0.01,训练轮数epochs=1000进行训练。
4. 可视化结果
python
y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()
model.forward(x):使用训练好的模型model对数据集x进行前向传播,得到预测值y_pred。plt.scatter(x, y, label='Data points'):绘制原始数据集的散点图,标签为Data points。plt.plot(x, x, color='red', label='y = x'):绘制真实的线性关系y = x的折线图,颜色为红色,标签为y = x。plt.plot(x, y_pred, color='green', label='Predicted'):绘制模型预测结果的折线图,颜色为绿色,标签为Predicted。plt.legend():显示图例,方便区分不同的图形。plt.show():显示绘制好的图形。
四、注意事项
- 本代码实现的是一个简单的线性回归神经网络,实际应用中可能需要更复杂的模型结构和优化方法。
- 学习率
lr和训练轮数epochs是超参数,可能需要根据具体数据和任务进行调整以获得更好的训练效果。 - 代码中使用的均方误差损失函数和梯度计算公式是针对线性回归问题的常见选择,但在其他问题中可能需要使用不同的损失函数和梯度计算方法。
完整代码
import numpy as np
import matplotlib.pyplot as plt# 1. 生成数据集
np.random.seed(42)
x = np.linspace(-10, 10, 100)
y = x + np.random.normal(0, 1, x.shape) # 添加噪声# 2. 定义神经网络(线性回归)
class SimpleNN:def __init__(self):self.w = np.random.randn() # 权重self.b = np.random.randn() # 偏置def forward(self, x):return self.w * x + self.b # 前向传播def loss(self, y_true, y_pred):return np.mean((y_true - y_pred) **2) # 均方误差def gradient(self, x, y_true, y_pred):dw = -2 * np.mean(x * (y_true - y_pred)) # 权重的梯度db = -2 * np.mean(y_true - y_pred) # 偏置的梯度return dw, dbdef train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw, db = self.gradient(x, y, y_pred)self.w -= lr * dw # 更新权重self.b -= lr * db # 更新偏置if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {self.loss(y, y_pred):.4f}')# 3. 训练模型
model = SimpleNN()
model.train(x, y, lr=0.01, epochs=1000)# 4. 可视化结果
y_pred = model.forward(x)
plt.scatter(x, y, label='Data points')
plt.plot(x, x, color='red', label='y = x')
plt.plot(x, y_pred, color='green', label='Predicted')
plt.legend()
plt.show()
相关文章:
基于简单神经网络的线性回归
一、概述 本代码实现了一个简单的神经网络进行线性回归任务。通过生成包含噪声的线性数据集,定义一个简单的神经网络类,使用梯度下降算法训练网络以拟合数据,并最终通过可视化展示原始数据、真实线性关系以及模型的预测结果。 二、依赖库 …...
【nvidia】Windows 双 A6000 显卡双显示器驱动更新问题修复
问题描述:windows自动更新nvidia驱动会导致只检测得到一个A6000显卡。 解决方法 下载 A6000 驱动 572.83-quadro-rtx-desktop-notebook-win10-win11-64bit-international-dch-whql.exehttps://download.csdn.net/download/qq_18846849/90554276 不要直接安装。如…...
《SRv6 网络编程:开启IP网络新时代》第2章、第3章:SRv6基本原理和基础协议
背景 根据工作要求、本人掌握的知识情况,仅针对《SRv6 网络编程:开启IP网络新时代》书籍中涉及的部分知识点进行总结梳理,并与工作小组进行分享,不涉及对原作的逐字搬运。 问题 组内同事提出的问题:本文缺扩展头描述…...
如何将AI模型返回的字符串转为html元素?
场景: 接入deepseek模型的api到我们平台,返回的字符串需要做下格式化处理。 返回的数据是这样的: {"role": "assistant","content": "<think>\n嗯,用户问的是“星体是什么”。首先&am…...
Citus源码(1)分布式表行为测试
最近对citus的实现非常好奇,本篇对citus的行为做一些测试。本篇只测行为,不分析源码。后面会继续写一系列文章分析citus源码。 环境:3节点 PG17 with citus。 SELECT citus_set_coordinator_host(127.0.0.1, 3001); SELECT citus_add_node(1…...
装饰器模式与模板方法模式实现MyBatis-Plus QueryWrapper 扩展
pom <dependency><groupId>com.github.yulichang</groupId><artifactId>mybatis-plus-join-boot-starter</artifactId> <!-- MyBatis 联表查询 --> </dependency>MPJLambdaWrapperX /*** 拓展 MyBatis Plus Join QueryWrapper 类&…...
【PCIE711-214】基于PCIe总线架构的4路HD-SDI/3G-SDI视频图像模拟源
产品概述 PCIE711-214是一款基于PCIE总线架构的4路SDI视频模拟源。该板卡为标准的PCIE插卡,全高尺寸,适合与PCIE总线的工控机或者服务器,板载协议处理器,可以通过PCIE总线将上位机的YUV 422格式视频数据下发通过SDI接口播放出去&…...
突破反爬困境:SDK开发,浏览器模块(七)
声明 本文所讨论的内容及技术均纯属学术交流与技术研究目的,旨在探讨和总结互联网数据流动、前后端技术架构及安全防御中的技术演进。文中提及的各类技术手段和策略均仅供技术人员在合法与合规的前提下进行研究、学习与防御测试之用。 作者不支持亦不鼓励任何未经授…...
rce操作
Linux命令长度突破限制 源码 <?php $param $_REQUEST[param];if ( strlen($param) < 8 ) {echo shell_exec($param); } echo执行函数,$_REQUEST可以接post、get、cookie传参 源码中对参数长度做了限制,小于8位,可以利用临时函数&…...
LabVIEW高效溢流阀测试系统
开发了一种基于LabVIEW软件和PLC硬件的溢流阀测试系统。通过集成神经网络优化的自适应PID控制器,该系统能自动进行压力稳定性、寿命以及动静态性能测试。该设计不仅提升了测试效率,还通过智能化控制提高了数据的精确性和操作的便捷性。 项目背景&…...
Spring Boot 中 JdbcTemplate 处理枚举类型转换 和 减少数据库连接的方法 的详细说明,包含代码示例和关键要点
以下是 Spring Boot 中 JdbcTemplate 处理枚举类型转换 和 减少数据库连接的方法 的详细说明,包含代码示例和关键要点: 一、JdbcTemplate 处理枚举类型转换 1. 场景说明 假设数据库存储的是枚举的 String 或 int 值,但 Java 实体类使用 enu…...
DataGear 5.3.0 制作支持导出表格数据的数据可视化看板
DataGear 内置表格图表底层采用的是DataTable表格组件,默认并未引入导出数据的JS支持库,如果有导出表格数据需求,则可以在看板中引入导出相关JS支持库,制作具有导出CSV、Excel、PDF功能的表格数据看板。 在新发布的5.3.0版本中&a…...
Web网页内嵌 Adobe Pdf Reader 谷歌Chrome在线预览编辑PDF文档
随着数字化办公的普及,PDF文档已成为信息处理的核心载体,虽然桌面端有很多软件可以实现预览编辑PDF文档,而在线在线预览编辑PDF也日益成为一个难题。 作为网页内嵌本地程序的佼佼者——猿大师中间件,之前发布的猿大师办公助手&am…...
歌词json
绽放(4:17) {"lyrics": [{time: 00:00, text: 作词:郑润泽},{time: 00:01, text: 作曲:郑润泽},{time: 00:02, text: 编曲:赵建飞},{time: 00:03, text: 制作人:李淘/赵建飞},{time: 00:09, tex…...
CNG汽车加气站操作工备考真题及答案解析【判断题】
1、燃气经营许可证按照燃气经营规模和类别实行分级审批。(√) 解析:不同规模和类别的燃气经营,其许可证审批级别不同,以确保经营活动的规范和安全。 2、依照《安全生产法》的规定,安全生产监督检查人员对检…...
Sentinel[超详细讲解]-1
定义一系列 规则 👺,对资源进行 保护 👺, 如果违反的了规则,则抛出异常,看是否有fallback兜底处理,如果没有则直接返回异常信息😎 1. 快速入门 1.1 引入 Sentinel 依赖 <depend…...
CUDA专题8—CUDA L2缓存完全指南:从持久化策略到性能优化实战
1. 设备内存L2缓存访问管理 当CUDA内核反复访问全局内存中的某个数据区域时,此类数据访问可视为持久化(persisting)访问。反之,若数据仅被访问一次,则可视为流式(streaming)访问。 从CUDA 11.0开始,计算能力8.0及以上的设备能够调控L2缓存中数据的持久性,从而可能实现更…...
如何让 SQL2API 进化为 Text2API:自然语言生成 API 的深度解析?
在过去的十年里,技术的进步日新月异,尤其是在自动化、人工智能与自然语言处理(NLP)方面。 随着“低代码”平台的崛起,开发者和非技术人员能够更轻松地构建强大而复杂的应用程序。然而,尽管技术门槛降低了&…...
OCCT(2)Windows平台编译OCCT
文章目录 一、Windows平台编译OCCT1、准备环境2、下载源码3、下载第三方库4、使用 CMake 配置5、编译OCCT源码6、运行示例 一、Windows平台编译OCCT 1、准备环境 安装工具: Visual Studio(推荐 VS2019/2022,选择 C 桌面开发 组件࿰…...
【蓝桥杯—单片机】通信总线专项 | 真题整理、解析与拓展 (更新ing...)
通信总线专项 前言SPI第十五届省赛题 UART/RS485/RS232UARTRS485RS232第十三届省赛题小结和拓展:传输方式的分类第十三届省赛 其他相关考点网络传输速率第十五届省赛题第十二届省赛题 前言 在本文中我会把 蓝桥杯单片机赛道 历年真题 中涉及到通信总线的题目整理出…...
【Golang】泛型与类型约束
文章目录 一、环境二、没有泛型的Go三、泛型的优点四、理解泛型(一)定义(二)调用(三)类型约束(Type Constraint)1)接口与约束2)结构体类型约束3)类…...
Uni-app页面信息与元素影响解析
获取窗口信息uni.getWindowInfo {pixelRatio: 3safeArea:{bottom: 778height: 731left: 0right: 375top: 47width: 375}safeAreaInsets: {top: 47, left: 0, right: 0, bottom: 34},screenHeight: 812,screenTop: 0,screenWidth: 375,statusBarHeight: 47,windowBottom: 0,win…...
CentOS(最小化)安装之后,快速搭建Docker环境
本文以VMware虚拟机中安装最小化centos完成后开始。 1. 检查网络 打开网卡/启用网卡 执行命令ip a查看当前的网络连接是否正常: 如果得到的结果和我一样,有ens网卡但是没有ip地址,说明网卡未打开 手动启用: nmcli device sta…...
【身份证证件OCR识别】批量OCR识别身份证照片复印件图片里的文字信息保存表格或改名字,基于QT和腾讯云api_ocr的实现方式
项目背景 在许多业务场景中,需要处理大量身份证照片复印件,手动输入其中的文字信息效率低下且容易出错。利用 OCR(光学字符识别)技术可以自动识别身份证图片中的文字信息,结合 QT 构建图形用户界面,方便用户操作,同时使用腾讯 OCR API 能够保证较高的识别准确率。 界面…...
IP属地和发作品的地址不一样吗
在当今这个数字化时代,互联网已经成为人们日常生活不可或缺的一部分。随着各大社交平台功能的不断完善,一个新功能——IP属地显示,逐渐走进大众视野。这一功能在微博、抖音、快手等各大平台上得到广泛应用,旨在帮助公众识别虚假信…...
Redis - 概述
目录 编辑 一、什么是redis 二、redis能做什么(有什么特点)? 三、redis有什么优势 四、Redis与其他key-value存储有什么不同 五、Redis命令 六、Redis数据结构 1、基础数据结构 2、高级数据结构 一、什么是redis 1、redis&#x…...
vue3 根据城市名称计算城市之间的距离
<template><div class"distance-calculator"><h1>城市距离计算器</h1><!-- 城市输入框 --><div class"input-group"><inputv-model"city1"placeholder"请输入第一个城市"keyup.enter"cal…...
html 列表循环滚动,动态初始化字段数据
html <div class"layui-row"><div class"layui-col-md4"><div class"boxall"><div class"alltitle">超时菜品排行</div><div class"marquee-container"><div class"scroll-…...
QT基础:安装与简介
QT初级 1、简介1.1 安装1.2 设置1.3 在VS中配置Qt1.3 帮助文档 2、Qt项目2.1 创建项目2.1 项目文件2.2 Qt中的窗口类窗口显示 2.3 坐标体系2.4 内存回收 1、简介 QT是一个跨平台的C应用程序开发框架。几乎支持所有的平台, 可用于桌面程序开发以及嵌入式开发。 Qt是标准 C 的扩…...
41、当你在 index.html 中引用了一个公共文件(比如 common.js),修改这个文件后,用户访问页面时仍然看到旧内容,因为浏览器缓存了旧版本
由于浏览器缓存导致公共文件无法更新。当用户修改了公共文件(如 JavaScript 或 CSS),但 index.html 中引用的文件名没有变化,浏览器会认为文件没有更新,继续使用缓存的旧版本。因此,需要通过某种方式让浏览…...
