神经网络解决非线性二分类
这份 Python 代码实现了一个简单的神经网络,用于解决复杂的非线性二分类问题。具体步骤包含生成数据集、定义神经网络模型、训练模型、测试模型以及可视化决策边界。
依赖库说明
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
numpy:用于数值计算,如数组操作、矩阵运算等。matplotlib.pyplot:用于数据可视化,如绘制决策边界和数据点。make_moons:从sklearn.datasets导入,用于生成半月形的非线性分类数据集。train_test_split:从sklearn.model_selection导入,用于将数据集划分为训练集和测试集。
详细步骤说明
1. 生成复杂的非线性分类数据集
python
np.random.seed(42)
X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
np.random.seed(42):设置随机数种子,保证结果可复现。make_moons(n_samples=1000, noise=0.2, random_state=42):生成包含 1000 个样本的半月形数据集,噪声水平为 0.2。train_test_split(X, y, test_size=0.2, random_state=42):将数据集按 80:20 的比例划分为训练集和测试集。
2. 定义神经网络(带隐藏层的非线性分类模型)
python
class SimpleNN:def __init__(self, input_size=2, hidden_size=10, output_size=1):self.w1 = np.random.randn(input_size, hidden_size)self.b1 = np.random.randn(hidden_size)self.w2 = np.random.randn(hidden_size, output_size)self.b2 = np.random.randn(output_size)
__init__方法:初始化神经网络的权重和偏置。input_size为输入层神经元数量,hidden_size为隐藏层神经元数量,output_size为输出层神经元数量。
python
def sigmoid(self, x):return 1 / (1 + np.exp(-x))def relu(self, x):return np.maximum(0, x)
sigmoid方法:实现 Sigmoid 激活函数,将输入值映射到 (0, 1) 区间,常用于二分类问题的输出层。relu方法:实现 ReLU 激活函数,将小于 0 的值置为 0,大于等于 0 的值保持不变,可缓解梯度消失问题。
python
def forward(self, x):self.z1 = np.dot(x, self.w1) + self.b1self.a1 = self.relu(self.z1)self.z2 = np.dot(self.a1, self.w2) + self.b2self.a2 = self.sigmoid(self.z2)return self.a2
forward方法:实现神经网络的前向传播过程。输入数据经过输入层到隐藏层的线性变换,再通过 ReLU 激活函数;然后经过隐藏层到输出层的线性变换,最后通过 Sigmoid 激活函数得到输出。
python
def binary_cross_entropy(self, y_true, y_pred):return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
binary_cross_entropy方法:计算二分类交叉熵损失,衡量模型预测值与真实值之间的差异。
python
def gradient(self, x, y_true, y_pred):m = x.shape[0]d_z2 = y_pred - y_trued_w2 = np.dot(self.a1.T, d_z2) / md_b2 = np.sum(d_z2, axis=0) / md_a1 = np.dot(d_z2, self.w2.T)d_z1 = d_a1 * (self.z1 > 0)d_w1 = np.dot(x.T, d_z1) / md_b1 = np.sum(d_z1, axis=0) / mreturn d_w1, d_b1, d_w2, d_b2
gradient方法:实现反向传播算法,计算权重和偏置的梯度。
python
def train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw1, db1, dw2, db2 = self.gradient(x, y, y_pred)self.w1 -= lr * dw1self.b1 -= lr * db1self.w2 -= lr * dw2self.b2 -= lr * db2if (epoch + 1) % 100 == 0:loss = self.binary_cross_entropy(y, y_pred)acc = self.accuracy(y, y_pred)print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss:.4f}, Accuracy: {acc:.4f}')
train方法:训练神经网络。在每个 epoch 中,先进行前向传播得到预测值,再通过反向传播计算梯度,最后更新权重和偏置。每 100 个 epoch 打印一次损失和准确率。
python
def accuracy(self, y_true, y_pred):y_pred_class = (y_pred > 0.5).astype(int)return np.mean(y_pred_class == y_true)
accuracy方法:计算分类准确率,将预测概率大于 0.5 的样本判定为正类,小于等于 0.5 的判定为负类,然后计算预测正确的样本比例。
3. 训练模型
python
model = SimpleNN(input_size=2, hidden_size=10, output_size=1)
model.train(X_train, y_train.reshape(-1, 1), lr=0.01, epochs=2000)
- 创建
SimpleNN类的实例model,并调用train方法对模型进行训练,学习率为 0.01,训练 2000 个 epoch。
4. 测试模型
python
y_test_pred = model.forward(X_test)
test_acc = model.accuracy(y_test.reshape(-1, 1), y_test_pred)
print(f'Test Accuracy: {test_acc:.4f}')
- 使用训练好的模型对测试集进行预测,计算测试集的准确率并打印。
5. 可视化决策边界
python
def plot_decision_boundary(model, X, y):x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),np.arange(y_min, y_max, 0.01))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.forward(grid).reshape(xx.shape)plt.contourf(xx, yy, probs, levels=[0, 0.5, 1], alpha=0.8, cmap=plt.cm.RdBu)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.RdBu)plt.title("Decision Boundary")plt.show()plot_decision_boundary(model, X_test, y_test)
plot_decision_boundary函数:绘制模型的决策边界。首先创建一个网格,然后对网格中的每个点进行预测,最后使用contourf函数绘制决策区域,并使用scatter函数绘制数据点。- 调用
plot_decision_boundary函数,传入训练好的模型、测试集数据和标签,可视化决策边界。
总结
该代码实现了一个简单的两层神经网络,用于解决复杂的非线性二分类问题。通过生成数据集、定义模型、训练模型、测试模型和可视化决策边界等步骤,展示了神经网络在非线性分类任务中的应用。
完整代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split# 1. 生成复杂的非线性分类数据集
np.random.seed(42)
X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 2. 定义神经网络(带隐藏层的非线性分类模型)
class SimpleNN:def __init__(self, input_size=2, hidden_size=10, output_size=1):# 初始化权重和偏置self.w1 = np.random.randn(input_size, hidden_size) # 输入层到隐藏层的权重self.b1 = np.random.randn(hidden_size) # 隐藏层的偏置self.w2 = np.random.randn(hidden_size, output_size) # 隐藏层到输出层的权重self.b2 = np.random.randn(output_size) # 输出层的偏置def sigmoid(self, x):return 1 / (1 + np.exp(-x)) # Sigmoid激活函数def relu(self, x):return np.maximum(0, x) # ReLU激活函数def forward(self, x):# 前向传播self.z1 = np.dot(x, self.w1) + self.b1 # 隐藏层输入self.a1 = self.relu(self.z1) # 隐藏层输出(应用ReLU)self.z2 = np.dot(self.a1, self.w2) + self.b2 # 输出层输入self.a2 = self.sigmoid(self.z2) # 输出层输出(应用Sigmoid)return self.a2def binary_cross_entropy(self, y_true, y_pred):# 二分类交叉熵损失return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))def gradient(self, x, y_true, y_pred):# 反向传播计算梯度m = x.shape[0]# 输出层的梯度d_z2 = y_pred - y_trued_w2 = np.dot(self.a1.T, d_z2) / md_b2 = np.sum(d_z2, axis=0) / m# 隐藏层的梯度d_a1 = np.dot(d_z2, self.w2.T)d_z1 = d_a1 * (self.z1 > 0) # ReLU的导数d_w1 = np.dot(x.T, d_z1) / md_b1 = np.sum(d_z1, axis=0) / mreturn d_w1, d_b1, d_w2, d_b2def train(self, x, y, lr=0.01, epochs=2000):for epoch in range(epochs):y_pred = self.forward(x)dw1, db1, dw2, db2 = self.gradient(x, y, y_pred)# 更新权重和偏置self.w1 -= lr * dw1self.b1 -= lr * db1self.w2 -= lr * dw2self.b2 -= lr * db2if (epoch + 1) % 100 == 0:loss = self.binary_cross_entropy(y, y_pred)acc = self.accuracy(y, y_pred)print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss:.4f}, Accuracy: {acc:.4f}')def accuracy(self, y_true, y_pred):# 计算分类准确率y_pred_class = (y_pred > 0.5).astype(int)return np.mean(y_pred_class == y_true)# 3. 训练模型
model = SimpleNN(input_size=2, hidden_size=10, output_size=1)
model.train(X_train, y_train.reshape(-1, 1), lr=0.01, epochs=2000)# 4. 测试模型
y_test_pred = model.forward(X_test)
test_acc = model.accuracy(y_test.reshape(-1, 1), y_test_pred)
print(f'Test Accuracy: {test_acc:.4f}')# 5. 可视化决策边界
def plot_decision_boundary(model, X, y):x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),np.arange(y_min, y_max, 0.01))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.forward(grid).reshape(xx.shape)plt.contourf(xx, yy, probs, levels=[0, 0.5, 1], alpha=0.8, cmap=plt.cm.RdBu)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.RdBu)plt.title("Decision Boundary")plt.show()plot_decision_boundary(model, X_test, y_test)
相关文章:
神经网络解决非线性二分类
这份 Python 代码实现了一个简单的神经网络,用于解决复杂的非线性二分类问题。具体步骤包含生成数据集、定义神经网络模型、训练模型、测试模型以及可视化决策边界。 依赖库说明 python import numpy as np import matplotlib.pyplot as plt from sklearn.datase…...
CentOS 8.2 上安装 JDK 17 和 Nginx
AI越来越火了,我们想要不被淘汰就得主动拥抱。推荐一个人工智能学习网站,通俗易懂,风趣幽默,最重要的屌图甚多,忍不住分享一下给大家。点击跳转到网站 一、安装 JDK 17 1. 使用 dnf 安装(推荐)…...
Python 爬虫(4)HTTP协议
文章目录 一、HTTP协议1、HTTP特点2、HTTP工作原理3、HTTP与HTTPS的区别 前言: HTTP(HyperText Transfer Protocol,超文本传输协议)是互联网上应用最为广泛的一种网络协议,用于在客户端和服务器之间传输超文本…...
Simple-BEV的bilinear_sample 作为view_transformer的解析,核心是3D-2D关联点生成
文件路径models/view_transformers 父类 是class BiLinearSample(nn.Module)基于https://github.com/aharley/simple_bev。 函数解析 函数bev_coord_to_feature_coord的功能 将鸟瞰图3D坐标通过多相机(针孔/鱼眼)内外参投影到图像特征平面࿰…...
Midscene.js自然语言驱动的网页自动化全指南
一、概述 网页自动化在数据抓取、UI 测试和业务流程优化中发挥着重要作用。然而,传统工具如 Selenium 和 Puppeteer 要求用户具备编程技能,编写复杂的选择器和脚本维护成本高昂。Midscene.js 通过自然语言接口革新了这一领域,用户只需描述任…...
同一个局域网的话 如何访问另一台电脑的ip
在局域网内访问另一台电脑,可以通过以下几种常见的方法来实现: 直接通过IP地址访问: 首先,确保两台电脑都连接在同一个局域网内。获取目标电脑的IP地址,这可以通过在目标电脑上打开命令提示符(Windows系…...
基于SpringBoot的名著阅读网站
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…...
Excel(实战):INDEX函数和MATCH函数、INDEX函数实战题
目录 经典用法两者嵌套查值题目解题分析 INDEX巧妙用法让数组公式,自动填充所有、有数据的行/列INDEX函数和SEQUENCE函数 经典用法两者嵌套查值 题目 根据左表查询这三个人的所有数据 解题分析 INDEX函数的参数:第1个参数是选定查找范围,…...
希尔排序中的Hibbard序列
一 定义 Hibbard序列的每个元素由以下公式生成: h_k = 2^k - 1 其中k从1开始递增,序列为:1, 3, 7, 15, 31, 63, … 二 生成方式 起始条件:k=1,对应h_1=2^1-1=1 递推公式:每次k增加1,计算 h_{k+1}=2^{k+1}-1 示例:前5项…...
uniapp超简单ios截屏和上传app store构建版本方法
假如使用windows开发ios的应用,上架的时候,你会发现,上架需要ios应用多种尺寸的ios设备的截图,和需要xcode等工具将打包好的ipa文件上传到app store的构建版本。 大部分情况下,我们的公司都没有这么多款ios设备来…...
Netty源码—5.Pipeline和Handler一
大纲 1.Pipeline和Handler的作用和构成 2.ChannelHandler的分类 3.几个特殊的ChannelHandler 4.ChannelHandler的生命周期 5.ChannelPipeline的事件处理 6.关于ChannelPipeline的问题整理 7.ChannelPipeline主要包括三部分内容 8.ChannelPipeline的初始化 9.ChannelPi…...
Netlify 的深度解析及使用指南
以下是关于 Netlify 的深度解析及使用指南,结合其核心功能与用户需求,提供一站式解决方案: 一、Netlify 核心优势 全托管静态网站服务Netlify 提供从代码托管、自动化构建到全球 CDN 加速的全流程服务,支持 HTML/CSS/JS 静态资源及…...
MySQL小练习
目录 一、单表查询 二、多表查询 一、单表查询 素材: 表名:worker-- 表中字段均为中文,比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float…...
Apache Hive:基于Hadoop的分布式数据仓库
Apache Hive 是一个基于 Apache Hadoop 构建的开源分布式数据仓库系统,支持使用 SQL 执行 PB 级大规模数据分析与查询。 主要功能 Apache Hive 提供的主要功能如下。 HiveServer2 HiveServer2 服务用于支持接收客户端连接和查询请求。 HiveServer2 支持多客户端…...
推荐算法分析
一、性能分析指标 1. 准确性指标(Accuracy Metrics) 衡量推荐系统预测评分的准确性,包括: ✅ RMSE(均方根误差, Root Mean Squared Error) 解释:衡量预测评分 (\hat{r}_i) 和真实评分 (r_i)…...
vllm 离线推理Qwen2.5-VL-Instruct,API部署,支持max_pixels
使用这里的最新镜像: https://www.dong-blog.fun/post/1799 启动环境 docker run -it --rm --gpus "device=1,2" \ --net host \ -v ./zizhi_merge_2025-1/:/Qwen2.5-VL-Instruct \ -v ./test:/test \...
检波、限幅、钳位电路
检波电路: 类似调制收音机信号:输入的基波和载波叠加成调制信号(信号需要长距离里传输,频率要高,M级别的频率,所以要把低频信号叠在高频信号,才能把低频信号长距离传输,最后到达接收…...
学习threejs,使用TextGeometry文本几何体
👨⚕️ 主页: gis分享者 👨⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.TextGeometry1.1.1 ☘…...
Go红队开发—CLI框架(一)
CLI开发框架 命令行工具开发,主要是介绍开发用到的包,集成了一个框架,只要学会了基本每个人都能开发安全工具了。 该文章先学flags包,是比较经典的一个包,相比后面要学习的集成框架这个比较自由比较细化点࿰…...
解决点击按钮页面自动刷新
在React中,当你点击按钮时,如果按钮的type属性没有明确指定,它的默认值是submit。这意味着如果这个按钮被放置在一个<form>表单中,点击它会触发表单的提交行为,导致页面刷新。 在你的代码中,展开/折叠…...
高效团队开发的工具与方法 引言
引言 在现代软件开发领域,团队协作的效率和质量直接决定了项目的成败。随着项目规模的扩大和技术复杂度的增加,如何实现高效团队开发成为每个开发团队必须面对的挑战。高效团队开发不仅仅是个人技术能力的简单叠加,更需要借助合适的工具和方…...
【Java全栈进阶架构师实战:从设计模式到SpringCloudAlibaba,打造高可用系统】
🌟 分享一个教程,助刚踏入IT行业、工作几年的老油条、或热爱学习的工作党们更上一层楼的! 🌟 适合人群:初中级Java开发者、求职面试备战者、技术提升党! 📚 内容亮点: 1️⃣ …...
[蓝桥杯 2023 省 A] 异或和之和
题目来自洛谷网站: 暴力思路: 先进性预处理,找到每个点位置的前缀异或和,在枚举区间。 暴力代码: #include<bits/stdc.h> #define int long long using namespace std; const int N 1e520;int n; int arr[N…...
TDengine 3.3.2.0 集群报错 Post “http://buildkitsandbox:6041/rest/sql“
修复: vi /etc/hosts将buildkitsandbox映射为本机节点...
vue数据重置
前言 大家在开发后台管理系统的过程中,一定会遇到一个表格的条件查询重置功能吧,如果说查询条件少,重置起来还算是比较简单,如果元素特别多呢,那玩意写起来可遭老罪喽,那今天就给大家整一个如何快速重置数…...
22、web前端开发之html5(三)
六. 离线存储与缓存 在网络环境不稳定或需要优化资源加载速度的场景下,离线存储与缓存技术显得尤为重要。HTML5引入了多种离线存储和缓存机制,帮助开发者提升用户体验。本节将详细介绍Application Cache、localStorage、sessionStorage以及IndexedDB等技…...
git revert 用法实战:撤销一个 commit 或 merge
git revert 1 区别 • 常规的 commit (使用 git commit 提交的 commit) • merge commit 2 首先构建场景 master上的代码 dev开发分支上,添加一个a标签,并commit这次提交 切到master上,再次进行改动和提交 将de…...
修形还是需要再研究一下
最近有不少小伙伴问到修形和蜗杆砂轮的问题,之前虽然研究过一段时间,但是由于时间问题放下了,最近想再捡起来。 之前计算的砂轮齿形是一整段的,但是似乎这种对于有些小伙伴来说不太容易接受,希望按照修形的区域进行分…...
AI本地部署之dify
快捷目录 Windows 系统一、环境准备:首先windows 需要准备docker 环1. 安装Docker desktop2. 安装Docker3. 配置Docker 镜像路径4. 配置Docker 下载镜像源5. 重启Docker服务二、Dify 下载和安装1. Dify下载2. Dify 配置3. Dify 安装附件知识:4. Dify创建账号三、下载Ollama d…...
安恒春招一面
《网安面试指南》https://mp.weixin.qq.com/s/RIVYDmxI9g_TgGrpbdDKtA?token1860256701&langzh_CN 5000篇网安资料库https://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247486065&idx2&snb30ade8200e842743339d428f414475e&chksmc0e4732df793fa3bf39…...
