Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)
实现功能
前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。
实现代码
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 设置超参数
learning_rate = 0.001
num_epochs = 100
batch_size = 32# 定义输入和输出的维度
input_dim = X.shape[1]
output_dim = len(set(y))# 定义权重和偏置项
W1 = tf.Variable(tf.random.normal(shape=(input_dim, 64), dtype=tf.float64))
b1 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W2 = tf.Variable(tf.random.normal(shape=(64, 64), dtype=tf.float64))
b2 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W3 = tf.Variable(tf.random.normal(shape=(64, output_dim), dtype=tf.float64))
b3 = tf.Variable(tf.zeros(shape=(output_dim,), dtype=tf.float64))# 定义前向传播函数
def forward_pass(X):X = tf.cast(X, tf.float64)h1 = tf.nn.relu(tf.matmul(X, W1) + b1)h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)logits = tf.matmul(h2, W3) + b3return logits# 定义损失函数
def loss_fn(logits, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate)# 定义准确率指标
accuracy_metric = tf.metrics.SparseCategoricalAccuracy()# 定义训练步骤
def train_step(inputs, labels):with tf.GradientTape() as tape:logits = forward_pass(inputs)loss_value = loss_fn(logits, labels)gradients = tape.gradient(loss_value, [W1, b1, W2, b2, W3, b3])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2, W3, b3]))accuracy_metric(labels, logits)return loss_value# 进行训练
for epoch in range(num_epochs):epoch_loss = 0.0accuracy_metric.reset_states()for batch_start in range(0, len(X_train), batch_size):batch_end = batch_start + batch_sizebatch_X = X_train[batch_start:batch_end]batch_y = y_train[batch_start:batch_end]loss = train_step(batch_X, batch_y)epoch_loss += losstrain_loss = epoch_loss / (len(X_train) // batch_size)train_accuracy = accuracy_metric.result()print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")# 进行评估
logits = forward_pass(X_test)
test_loss = loss_fn(logits, y_test)
test_accuracy = accuracy_metric(y_test, logits)print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
实现效果
本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。
致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。
邀请三个朋友关注V订阅号:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。
相关文章:

Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)
实现功能 前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。 实现代码 import tensorflow as tf from sklearn.datasets import…...

PDF 文档处理:使用 Java 对比 PDF 找出内容差异
不论是在团队写作还是在个人工作中,PDF 文档往往会经过多次修订和更新。掌握 PDF 文档内容的变化对于管理文档有极大的帮助。通过对比 PDF 文档,用户可以快速找出文档增加、删除和修改的内容,更好地了解文档的演变过程,轻松地管理…...

压敏电阻有哪些原理?|深圳比创达电子EMC
压敏电阻是一种金属氧化物陶瓷半导体电阻器。它以氧化锌(ZnO)为基料,加入多种(一般5~10种)其它添加剂,经压制成坯体,高温烧结,成为具有晶界特性的多晶半导体陶瓷组件。氧化锌压敏电阻器的微观结构如下图1所示。 氧化锌…...

【计算机网络笔记】Web应用之HTTP协议(涉及HTTP连接类型和HTTP消息格式)
系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…...

IDEA 2023.2.2 使用 Scala 编译报错 No scalac found to compile scala sources
一、问题 scala: No scalac found to compile scala sources 官网 Bug 链接 二、临时解决方案 Incrementality Type 先变成 IDEA 类型 Please go to Settings > Build, Execution, Deployment > Compiler > Scala Compiler and change the Incrementality type to …...
C51--PWN-舵机控制
PWM开发sg90舵机 1、简介 PWM(pulse width modulation)是脉冲宽度调制缩写。 通过对一系列脉冲的宽度进行调制,等效出所需要的波形(包含形状以及幅值)。对模拟信号电平进行数字编码,通过调节占空比的变化来…...

electron27+react18集成搭建跨平台应用|electron窗口多开
基于Electron27集成React18创建一个桌面端exe程序。 electron27-vite4-react18基于electron27结合vite4构建工具快速创建react18跨端应用实践。 版本列表 "vite": "^4.4.5" "react": "^18.2.0" "electron": "^27.0.1&…...

【k8s】kubeadm安装k8s集群
一、环境部署 master192.168.88.10docker、kubeadm、kubelet、kubectl、flannelnode01192.168.88.20docker、kubeadm、kubelet、kubectl、flannelnode02192.168.88.30docker、kubeadm、kubelet、kubectl、flannelhub.lp.com192.168.88.40 docker、docker-compose harbor-offli…...

三、虚拟机的迁移和删除
虚拟机的本质就是文件(放在文件夹的)。因此虚拟机的迁移很方便,可以把安装好的虚拟系统这个文件夹整体拷贝或者剪切到另外的位置使用。删除也很简单,使用vmware进行移除,再点菜单->从磁盘删除即可,或者手动删除虚拟系统对应的文…...

RabbitMQ的交换机(原理及代码实现)
1.交换机类型 Fanout Exchange(扇形)Direct Exchange(直连)opic Exchange(主题)Headers Exchange(头部) 2.Fanout Exchange 2.1 简介 Fanout 扇形的,散开的࿱…...

【C++进阶】pair容器
👦个人主页:Weraphael ✍🏻作者简介:目前学习C和算法 ✈️专栏:C航路 🐋 希望大家多多支持,咱一起进步!😁 如果文章对你有帮助的话 欢迎 评论💬 点赞…...

Linux--进程等待
1.什么是进程等待 1.通过系统调用wait/waitid,来对子进程进行进行检测和回收的功能。 2.为什么有进程等待 1.对于每个进程来说,如果子进程终止,父进程没有停止,就会形成僵尸进程,导致内存泄露,为了防止僵尸进程的形成…...
VMware CentOS 虚拟机扩容
参考文章: VMware中centos磁盘扩容 - 简书 看这篇文章进行操作!扩展根分区报错,xfs_growfs 提示 / is not a mounted XFS filesystem-CSDN博客 [rootnode001 ~]# df 文件系统 1K-块 已用 可用 已用% 挂载点 /dev/…...

CentOS 编译安装 nginx
CentOS 编译安装 nginx 修改 yum 源地址为 阿里云 curl -o /etc/yum.repos.d/CentOS-Base.repo https://mirrors.aliyun.com/repo/Centos-7.repoyum makecache升级内核和软件 yum -y update安装常用软件和依赖 yum -y install gcc gcc-c make cmake zlib zlib-devel openss…...

学习笔记-MongoDB(命令增删改查,聚合,权限管理,索引,java使用)
基础概念 1 什么是mogodb? MongoDB 是一个基于分布式文件/文档存储的数据库,由 C 编写,可以为 Web 应用提供可扩展、高性能、易部署的数据存储解决方案。MongoDB 是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库中功…...

第13期 | GPTSecurity周报
GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区,集成了生成预训练 Transformer(GPT)、人工智能生成内容(AIGC)以及大型语言模型(LLM)等安全领域应用的知识。在这里,您可以…...

OpenCV学习(一)——图像读取
1. 图像入门 读取图像显示图像写入图像 import cv2# 读取图像 img cv2.imread(lena.jpg) print(img.shape)# 显示图像 cv2.imshow(image, img) cv2.waitKey(0) cv2.destroyAllWindows()# 写入图像 cv2.imwrite(image.jpg, img)1.1 读取图像 读取图像cv.imread(filename, fl…...

并发编程- 线程池ForkJoinPool工作原理分析(实践)
数据结构加油站: Comparison Sorting Visualization 并发设计模式 单线程归并排序 public class MergeSort {private final int[] arrayToSort; //要排序的数组private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分public MergeSort…...
小程序原生开发中的onLoad和onShow
在小程序的原生开发中,onLoad和onShow是两个常用的生命周期函数,用于管理页面的加载和显示。 onLoad:该函数会在页面加载时触发。当页面第一次加载时,它会被调用一次,之后切换到其他页面再返回时不会再触发。可以在on…...
springcloud技术栈以及相关组件
常用中间件 注册中心—nacos分布式服务之间的交互工具—Feign服务安全入口中间件—Gateway各个服务的异步通信组件—rabbitmqRabbitMq分布式场景的应用配置微服务的容器部署–docker分布式检索引擎—elasticSearches在分布式场景的应用分布式事务协调中间间— seata分布式服务…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

React19源码系列之 事件插件系统
事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...

华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...

多模态大语言模型arxiv论文略读(108)
CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题:CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者:Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...
Device Mapper 机制
Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...

抽象类和接口(全)
一、抽象类 1.概念:如果⼀个类中没有包含⾜够的信息来描绘⼀个具体的对象,这样的类就是抽象类。 像是没有实际⼯作的⽅法,我们可以把它设计成⼀个抽象⽅法,包含抽象⽅法的类我们称为抽象类。 2.语法 在Java中,⼀个类如果被 abs…...