当前位置: 首页 > article >正文

神经网络解决非线性二分类

这份 Python 代码实现了一个简单的神经网络,用于解决复杂的非线性二分类问题。具体步骤包含生成数据集、定义神经网络模型、训练模型、测试模型以及可视化决策边界。

依赖库说明

python

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

  • numpy:用于数值计算,如数组操作、矩阵运算等。
  • matplotlib.pyplot:用于数据可视化,如绘制决策边界和数据点。
  • make_moons:从sklearn.datasets导入,用于生成半月形的非线性分类数据集。
  • train_test_split:从sklearn.model_selection导入,用于将数据集划分为训练集和测试集。

详细步骤说明

1. 生成复杂的非线性分类数据集

python

np.random.seed(42)
X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

  • np.random.seed(42):设置随机数种子,保证结果可复现。
  • make_moons(n_samples=1000, noise=0.2, random_state=42):生成包含 1000 个样本的半月形数据集,噪声水平为 0.2。
  • train_test_split(X, y, test_size=0.2, random_state=42):将数据集按 80:20 的比例划分为训练集和测试集。
2. 定义神经网络(带隐藏层的非线性分类模型)

python

class SimpleNN:def __init__(self, input_size=2, hidden_size=10, output_size=1):self.w1 = np.random.randn(input_size, hidden_size)self.b1 = np.random.randn(hidden_size)self.w2 = np.random.randn(hidden_size, output_size)self.b2 = np.random.randn(output_size)

  • __init__方法:初始化神经网络的权重和偏置。input_size为输入层神经元数量,hidden_size为隐藏层神经元数量,output_size为输出层神经元数量。

python

    def sigmoid(self, x):return 1 / (1 + np.exp(-x))def relu(self, x):return np.maximum(0, x)

  • sigmoid方法:实现 Sigmoid 激活函数,将输入值映射到 (0, 1) 区间,常用于二分类问题的输出层。
  • relu方法:实现 ReLU 激活函数,将小于 0 的值置为 0,大于等于 0 的值保持不变,可缓解梯度消失问题。

python

    def forward(self, x):self.z1 = np.dot(x, self.w1) + self.b1self.a1 = self.relu(self.z1)self.z2 = np.dot(self.a1, self.w2) + self.b2self.a2 = self.sigmoid(self.z2)return self.a2

  • forward方法:实现神经网络的前向传播过程。输入数据经过输入层到隐藏层的线性变换,再通过 ReLU 激活函数;然后经过隐藏层到输出层的线性变换,最后通过 Sigmoid 激活函数得到输出。

python

    def binary_cross_entropy(self, y_true, y_pred):return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

  • binary_cross_entropy方法:计算二分类交叉熵损失,衡量模型预测值与真实值之间的差异。

python

    def gradient(self, x, y_true, y_pred):m = x.shape[0]d_z2 = y_pred - y_trued_w2 = np.dot(self.a1.T, d_z2) / md_b2 = np.sum(d_z2, axis=0) / md_a1 = np.dot(d_z2, self.w2.T)d_z1 = d_a1 * (self.z1 > 0)d_w1 = np.dot(x.T, d_z1) / md_b1 = np.sum(d_z1, axis=0) / mreturn d_w1, d_b1, d_w2, d_b2

  • gradient方法:实现反向传播算法,计算权重和偏置的梯度。

python

    def train(self, x, y, lr=0.01, epochs=1000):for epoch in range(epochs):y_pred = self.forward(x)dw1, db1, dw2, db2 = self.gradient(x, y, y_pred)self.w1 -= lr * dw1self.b1 -= lr * db1self.w2 -= lr * dw2self.b2 -= lr * db2if (epoch + 1) % 100 == 0:loss = self.binary_cross_entropy(y, y_pred)acc = self.accuracy(y, y_pred)print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss:.4f}, Accuracy: {acc:.4f}')

  • train方法:训练神经网络。在每个 epoch 中,先进行前向传播得到预测值,再通过反向传播计算梯度,最后更新权重和偏置。每 100 个 epoch 打印一次损失和准确率。

python

    def accuracy(self, y_true, y_pred):y_pred_class = (y_pred > 0.5).astype(int)return np.mean(y_pred_class == y_true)

  • accuracy方法:计算分类准确率,将预测概率大于 0.5 的样本判定为正类,小于等于 0.5 的判定为负类,然后计算预测正确的样本比例。
3. 训练模型

python

model = SimpleNN(input_size=2, hidden_size=10, output_size=1)
model.train(X_train, y_train.reshape(-1, 1), lr=0.01, epochs=2000)

  • 创建SimpleNN类的实例model,并调用train方法对模型进行训练,学习率为 0.01,训练 2000 个 epoch。
4. 测试模型

python

y_test_pred = model.forward(X_test)
test_acc = model.accuracy(y_test.reshape(-1, 1), y_test_pred)
print(f'Test Accuracy: {test_acc:.4f}')

  • 使用训练好的模型对测试集进行预测,计算测试集的准确率并打印。
5. 可视化决策边界

python

def plot_decision_boundary(model, X, y):x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),np.arange(y_min, y_max, 0.01))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.forward(grid).reshape(xx.shape)plt.contourf(xx, yy, probs, levels=[0, 0.5, 1], alpha=0.8, cmap=plt.cm.RdBu)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.RdBu)plt.title("Decision Boundary")plt.show()plot_decision_boundary(model, X_test, y_test)
  • plot_decision_boundary函数:绘制模型的决策边界。首先创建一个网格,然后对网格中的每个点进行预测,最后使用contourf函数绘制决策区域,并使用scatter函数绘制数据点。
  • 调用plot_decision_boundary函数,传入训练好的模型、测试集数据和标签,可视化决策边界。

总结

该代码实现了一个简单的两层神经网络,用于解决复杂的非线性二分类问题。通过生成数据集、定义模型、训练模型、测试模型和可视化决策边界等步骤,展示了神经网络在非线性分类任务中的应用。

完整代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split# 1. 生成复杂的非线性分类数据集
np.random.seed(42)
X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 2. 定义神经网络(带隐藏层的非线性分类模型)
class SimpleNN:def __init__(self, input_size=2, hidden_size=10, output_size=1):# 初始化权重和偏置self.w1 = np.random.randn(input_size, hidden_size)  # 输入层到隐藏层的权重self.b1 = np.random.randn(hidden_size)              # 隐藏层的偏置self.w2 = np.random.randn(hidden_size, output_size)  # 隐藏层到输出层的权重self.b2 = np.random.randn(output_size)              # 输出层的偏置def sigmoid(self, x):return 1 / (1 + np.exp(-x))  # Sigmoid激活函数def relu(self, x):return np.maximum(0, x)  # ReLU激活函数def forward(self, x):# 前向传播self.z1 = np.dot(x, self.w1) + self.b1  # 隐藏层输入self.a1 = self.relu(self.z1)            # 隐藏层输出(应用ReLU)self.z2 = np.dot(self.a1, self.w2) + self.b2  # 输出层输入self.a2 = self.sigmoid(self.z2)         # 输出层输出(应用Sigmoid)return self.a2def binary_cross_entropy(self, y_true, y_pred):# 二分类交叉熵损失return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))def gradient(self, x, y_true, y_pred):# 反向传播计算梯度m = x.shape[0]# 输出层的梯度d_z2 = y_pred - y_trued_w2 = np.dot(self.a1.T, d_z2) / md_b2 = np.sum(d_z2, axis=0) / m# 隐藏层的梯度d_a1 = np.dot(d_z2, self.w2.T)d_z1 = d_a1 * (self.z1 > 0)  # ReLU的导数d_w1 = np.dot(x.T, d_z1) / md_b1 = np.sum(d_z1, axis=0) / mreturn d_w1, d_b1, d_w2, d_b2def train(self, x, y, lr=0.01, epochs=2000):for epoch in range(epochs):y_pred = self.forward(x)dw1, db1, dw2, db2 = self.gradient(x, y, y_pred)# 更新权重和偏置self.w1 -= lr * dw1self.b1 -= lr * db1self.w2 -= lr * dw2self.b2 -= lr * db2if (epoch + 1) % 100 == 0:loss = self.binary_cross_entropy(y, y_pred)acc = self.accuracy(y, y_pred)print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss:.4f}, Accuracy: {acc:.4f}')def accuracy(self, y_true, y_pred):# 计算分类准确率y_pred_class = (y_pred > 0.5).astype(int)return np.mean(y_pred_class == y_true)# 3. 训练模型
model = SimpleNN(input_size=2, hidden_size=10, output_size=1)
model.train(X_train, y_train.reshape(-1, 1), lr=0.01, epochs=2000)# 4. 测试模型
y_test_pred = model.forward(X_test)
test_acc = model.accuracy(y_test.reshape(-1, 1), y_test_pred)
print(f'Test Accuracy: {test_acc:.4f}')# 5. 可视化决策边界
def plot_decision_boundary(model, X, y):x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),np.arange(y_min, y_max, 0.01))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.forward(grid).reshape(xx.shape)plt.contourf(xx, yy, probs, levels=[0, 0.5, 1], alpha=0.8, cmap=plt.cm.RdBu)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.RdBu)plt.title("Decision Boundary")plt.show()plot_decision_boundary(model, X_test, y_test)

相关文章:

神经网络解决非线性二分类

这份 Python 代码实现了一个简单的神经网络,用于解决复杂的非线性二分类问题。具体步骤包含生成数据集、定义神经网络模型、训练模型、测试模型以及可视化决策边界。 依赖库说明 python import numpy as np import matplotlib.pyplot as plt from sklearn.datase…...

CentOS 8.2 上安装 JDK 17 和 Nginx

AI越来越火了,我们想要不被淘汰就得主动拥抱。推荐一个人工智能学习网站,通俗易懂,风趣幽默,最重要的屌图甚多,忍不住分享一下给大家。点击跳转到网站 一、安装 JDK 17 1. 使用 dnf 安装(推荐&#xff09…...

Python 爬虫(4)HTTP协议

文章目录 一、HTTP协议1、HTTP特点2、HTTP工作原理3、HTTP与HTTPS的区别 前言: HTTP(HyperText Transfer Protocol,超文本传输协议)是互联网上应用最为广泛的一种网络协议,用于在客户端和服务器之间传输超文本&#xf…...

Simple-BEV的bilinear_sample 作为view_transformer的解析,核心是3D-2D关联点生成

文件路径models/view_transformers 父类 是class BiLinearSample(nn.Module)基于https://github.com/aharley/simple_bev。 函数解析 函数bev_coord_to_feature_coord的功能 将鸟瞰图3D坐标通过多相机(针孔/鱼眼)内外参投影到图像特征平面&#xff0…...

Midscene.js自然语言驱动的网页自动化全指南

一、概述 网页自动化在数据抓取、UI 测试和业务流程优化中发挥着重要作用。然而,传统工具如 Selenium 和 Puppeteer 要求用户具备编程技能,编写复杂的选择器和脚本维护成本高昂。Midscene.js 通过自然语言接口革新了这一领域,用户只需描述任…...

同一个局域网的话 如何访问另一台电脑的ip

在局域网内访问另一台电脑,可以通过以下几种常见的方法来实现: ‌直接通过IP地址访问‌: 首先,确保两台电脑都连接在同一个局域网内。获取目标电脑的IP地址,这可以通过在目标电脑上打开命令提示符(Windows系…...

基于SpringBoot的名著阅读网站

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…...

Excel(实战):INDEX函数和MATCH函数、INDEX函数实战题

目录 经典用法两者嵌套查值题目解题分析 INDEX巧妙用法让数组公式,自动填充所有、有数据的行/列INDEX函数和SEQUENCE函数 经典用法两者嵌套查值 题目 根据左表查询这三个人的所有数据 解题分析 INDEX函数的参数:第1个参数是选定查找范围&#xff0c…...

希尔排序中的Hibbard序列

一 定义 Hibbard序列的每个元素由以下公式生成: h_k = 2^k - 1 其中k从1开始递增,序列为:1, 3, 7, 15, 31, 63, … 二 生成方式 起始条件:k=1,对应h_1=2^1-1=1 递推公式:每次k增加1,计算 h_{k+1}=2^{k+1}-1 示例:前5项…...

uniapp超简单ios截屏和上传app store构建版本方法

​ 假如使用windows开发ios的应用,上架的时候,你会发现,上架需要ios应用多种尺寸的ios设备的截图,和需要xcode等工具将打包好的ipa文件上传到app store的构建版本。 大部分情况下,我们的公司都没有这么多款ios设备来…...

Netty源码—5.Pipeline和Handler一

大纲 1.Pipeline和Handler的作用和构成 2.ChannelHandler的分类 3.几个特殊的ChannelHandler 4.ChannelHandler的生命周期 5.ChannelPipeline的事件处理 6.关于ChannelPipeline的问题整理 7.ChannelPipeline主要包括三部分内容 8.ChannelPipeline的初始化 9.ChannelPi…...

Netlify 的深度解析及使用指南

以下是关于 Netlify 的深度解析及使用指南,结合其核心功能与用户需求,提供一站式解决方案: 一、Netlify 核心优势 全托管静态网站服务Netlify 提供从代码托管、自动化构建到全球 CDN 加速的全流程服务,支持 HTML/CSS/JS 静态资源及…...

MySQL小练习

目录 一、单表查询 二、多表查询 一、单表查询 素材: 表名:worker-- 表中字段均为中文,比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float…...

Apache Hive:基于Hadoop的分布式数据仓库

Apache Hive 是一个基于 Apache Hadoop 构建的开源分布式数据仓库系统,支持使用 SQL 执行 PB 级大规模数据分析与查询。 主要功能 Apache Hive 提供的主要功能如下。 HiveServer2 HiveServer2 服务用于支持接收客户端连接和查询请求。 HiveServer2 支持多客户端…...

推荐算法分析

一、性能分析指标 1. 准确性指标(Accuracy Metrics) 衡量推荐系统预测评分的准确性,包括: ✅ RMSE(均方根误差, Root Mean Squared Error) 解释:衡量预测评分 (\hat{r}_i) 和真实评分 (r_i)…...

vllm 离线推理Qwen2.5-VL-Instruct,API部署,支持max_pixels

使用这里的最新镜像: https://www.dong-blog.fun/post/1799 启动环境 docker run -it --rm --gpus "device=1,2" \ --net host \ -v ./zizhi_merge_2025-1/:/Qwen2.5-VL-Instruct \ -v ./test:/test \...

检波、限幅、钳位电路

检波电路: 类似调制收音机信号:输入的基波和载波叠加成调制信号(信号需要长距离里传输,频率要高,M级别的频率,所以要把低频信号叠在高频信号,才能把低频信号长距离传输,最后到达接收…...

学习threejs,使用TextGeometry文本几何体

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.TextGeometry1.1.1 ☘…...

Go红队开发—CLI框架(一)

CLI开发框架 命令行工具开发,主要是介绍开发用到的包,集成了一个框架,只要学会了基本每个人都能开发安全工具了。 该文章先学flags包,是比较经典的一个包,相比后面要学习的集成框架这个比较自由比较细化点&#xff0…...

解决点击按钮页面自动刷新

在React中&#xff0c;当你点击按钮时&#xff0c;如果按钮的type属性没有明确指定&#xff0c;它的默认值是submit。这意味着如果这个按钮被放置在一个<form>表单中&#xff0c;点击它会触发表单的提交行为&#xff0c;导致页面刷新。 在你的代码中&#xff0c;展开/折叠…...

高效团队开发的工具与方法 引言

引言 在现代软件开发领域&#xff0c;团队协作的效率和质量直接决定了项目的成败。随着项目规模的扩大和技术复杂度的增加&#xff0c;如何实现高效团队开发成为每个开发团队必须面对的挑战。高效团队开发不仅仅是个人技术能力的简单叠加&#xff0c;更需要借助合适的工具和方…...

【Java全栈进阶架构师实战:从设计模式到SpringCloudAlibaba,打造高可用系统】

&#x1f31f; 分享一个教程&#xff0c;助刚踏入IT行业、工作几年的老油条、或热爱学习的工作党们更上一层楼的&#xff01; &#x1f31f; ​适合人群&#xff1a;初中级Java开发者、求职面试备战者、技术提升党&#xff01; &#x1f4da; ​内容亮点&#xff1a; 1️⃣ ​…...

[蓝桥杯 2023 省 A] 异或和之和

题目来自洛谷网站&#xff1a; 暴力思路&#xff1a; 先进性预处理&#xff0c;找到每个点位置的前缀异或和&#xff0c;在枚举区间。 暴力代码&#xff1a; #include<bits/stdc.h> #define int long long using namespace std; const int N 1e520;int n; int arr[N…...

TDengine 3.3.2.0 集群报错 Post “http://buildkitsandbox:6041/rest/sql“

修复&#xff1a; vi /etc/hosts将buildkitsandbox映射为本机节点...

vue数据重置

前言 大家在开发后台管理系统的过程中&#xff0c;一定会遇到一个表格的条件查询重置功能吧&#xff0c;如果说查询条件少&#xff0c;重置起来还算是比较简单&#xff0c;如果元素特别多呢&#xff0c;那玩意写起来可遭老罪喽&#xff0c;那今天就给大家整一个如何快速重置数…...

22、web前端开发之html5(三)

六. 离线存储与缓存 在网络环境不稳定或需要优化资源加载速度的场景下&#xff0c;离线存储与缓存技术显得尤为重要。HTML5引入了多种离线存储和缓存机制&#xff0c;帮助开发者提升用户体验。本节将详细介绍Application Cache、localStorage、sessionStorage以及IndexedDB等技…...

git revert 用法实战:撤销一个 commit 或 merge

git revert 1 区别 • 常规的 commit &#xff08;使用 git commit 提交的 commit&#xff09; • merge commit 2 首先构建场景 master上的代码 dev开发分支上&#xff0c;添加一个a标签&#xff0c;并commit这次提交 切到master上&#xff0c;再次进行改动和提交 将de…...

修形还是需要再研究一下

最近有不少小伙伴问到修形和蜗杆砂轮的问题&#xff0c;之前虽然研究过一段时间&#xff0c;但是由于时间问题放下了&#xff0c;最近想再捡起来。 之前计算的砂轮齿形是一整段的&#xff0c;但是似乎这种对于有些小伙伴来说不太容易接受&#xff0c;希望按照修形的区域进行分…...

AI本地部署之dify

快捷目录 Windows 系统一、环境准备:首先windows 需要准备docker 环1. 安装Docker desktop2. 安装Docker3. 配置Docker 镜像路径4. 配置Docker 下载镜像源5. 重启Docker服务二、Dify 下载和安装1. Dify下载2. Dify 配置3. Dify 安装附件知识:4. Dify创建账号三、下载Ollama d…...

安恒春招一面

《网安面试指南》https://mp.weixin.qq.com/s/RIVYDmxI9g_TgGrpbdDKtA?token1860256701&langzh_CN 5000篇网安资料库https://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247486065&idx2&snb30ade8200e842743339d428f414475e&chksmc0e4732df793fa3bf39…...