当前位置: 首页 > news >正文

政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):def __init__(self, units, activation=None, name=None):super().__init__(name=name)self.units = unitsself.activation = keras.activations.get(activation)def build(self, input_shape):input_dim = input_shape[-1]self.w = self.add_weight(shape=(input_dim, self.units),initializer=keras.initializers.GlorotNormal(),name="kernel",trainable=True,)self.b = self.add_weight(shape=(self.units,),initializer=keras.initializers.Zeros(),name="bias",trainable=True,)def call(self, inputs):# Use Keras ops to create backend-agnostic layers/metrics/etc.x = keras.ops.matmul(inputs, self.w) + self.breturn self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):def __init__(self, rate, name=None):super().__init__(name=name)self.rate = rate# Use seed_generator for managing RNG state.# It is a state element and its seed variable is# tracked as part of `layer.variables`.self.seed_generator = keras.random.SeedGenerator(1337)def call(self, inputs):# Use `keras.random` for random ops.return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):def __init__(self, num_classes):super().__init__()self.conv_base = keras.Sequential([keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.GlobalAveragePooling2D(),])self.dp = MyDropout(0.5)self.dense = MyDense(num_classes, activation="softmax")def call(self, x):x = self.conv_base(x)x = self.dp(x)return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)model.fit(x_train,y_train,batch_size=batch_size,epochs=1,  # For speedvalidation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)
)# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tftrain_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


相关文章:

政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras实战演绎机器学习 希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正! 介绍 通过 Keras,您可以编写自定…...

数据库系统概论(超详解!!!) 第四节 关系数据库标准语言SQL(Ⅲ)

1.连接查询 连接查询&#xff1a;同时涉及多个表的查询 连接条件或连接谓词&#xff1a;用来连接两个表的条件 一般格式&#xff1a; [<表名1>.]<列名1> <比较运算符> [<表名2>.]<列名2> [<表名1>.]<列名1> BETWEEN [&l…...

如何使用Python进行网络安全与密码学【第149篇—密码学】

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 用Python进行网络安全与密码学&#xff1a;技术实践指南 随着互联网的普及&#xff0c;网络…...

应急响应-Web2

应急响应-Web2 1.攻击者的IP地址&#xff08;两个&#xff09;&#xff1f; 192.168.126.135 192.168.126.129 通过phpstudy查看日志&#xff0c;发现192.168.126.135这个IP一直在404访问 &#xff0c; 并且在日志的最后几条一直在访问system.php &#xff0c;从这可以推断 …...

复试专业前沿问题问答合集8-1——CNN、Transformer、TensorFlow、GPT

复试专业前沿问题问答合集8-1——CNN、Transformer、TensorFlow、GPT 深度学习中的CNN、Transformer、TensorFlow、GPT大语言模型的原理关系问答: Transformer与ChatGPT的关系 Transformer 是一种基于自注意力机制的深度学习模型,最初在论文《Attention is All You Need》…...

用Python做一个植物大战僵尸

植物大战僵尸是一个相对复杂的游戏&#xff0c;涉及到图形界面、动画、游戏逻辑等多个方面。用Python实现一个完整的植物大战僵尸游戏是一个大工程&#xff0c;但我们可以简化一些内容&#xff0c;做一个基础版本。 以下是一个简化版的植物大战僵尸游戏的Python实现思路&#…...

Win11文件右键菜单栏完整显示教程

近日公司电脑升级了win11&#xff0c;发现了一个小麻烦事&#xff0c;如下图&#xff1a; 当我想使用svn或git的时候必须要多点一下&#xff0c;这忍不了&#xff0c;无形之中加大了工作量&#xff01; 于是&#xff0c;菜单全显示教程如下&#xff1a; 第一步&#xff1a;管…...

【Python实用标准库】argparser使用教程

argparser使用教程 1.介绍2.基本使用3.add_argument() 参数设置4.参考 1.介绍 &#xff08;一&#xff09;argparse 模块是 Python 内置的用于命令项选项与参数解析的模块&#xff0c;其用主要在两个方面&#xff1a; 一方面在python文件中可以将算法参数集中放到一起&#x…...

伦敦金与纸黄金有什么区别?怎么选?

伦敦金与纸黄金都是与黄金相关的投资品种&#xff0c;近期黄金市场的上涨吸引了投资者的关注&#xff0c;那投资者想开户入场成为黄金投资者应该选择纸黄金还是伦敦金呢&#xff1f;两者有何区别呢&#xff1f;下面我们就来讨论一下。 伦敦金是一种起源于伦敦的标准化黄金交易合…...

化工企业能源在线监测管理系统,智能节能助力生产

化工企业能源消耗量极大&#xff0c;其节能的空间也相对较大&#xff0c;所以需要控制能耗强度&#xff0c;保持更高的能源利用率。 化工企业能源消耗现状 1、能源管理方面 计量能源消耗时&#xff0c;计量器具存在问题&#xff0c;未能对能耗情况实施完全计量&#xff0c;有…...

C/C++ 一些使用网站收集...

C/C 标准 各种语言协议标准文档 open-std.org 网站 C语言标准文档 open-std.org C基金会网站 C/C 标准库网站 c/c 标准库 cplusplus.com 网站 c/c标准库 cppreference.com 网站 C Core Guidelines【isocpp.org】 C/C 发展 C现有状态及未来规划【isocpp.org】...

2024可以搜索夸克网盘的方法

截止2024可以搜索夸克网盘的方法 6miu盘搜 6miu盘搜是一个强大的网盘搜索工具,它汇集了多个网盘平台的资源,包括百度网盘、163网盘、金山快盘等,可以帮助用户快速找到所需的资料。6miu盘搜的一个显著特点是它的资源更新速度快,可以搜索到最新的资源。此外,6miu盘搜的界面清爽…...

2024年最新阿里云服务器价格表_CPU内存+磁盘+带宽价格

2024年阿里云服务器租用费用&#xff0c;云服务器ECS经济型e实例2核2G、3M固定带宽99元一年&#xff0c;轻量应用服务器2核2G3M带宽轻量服务器一年61元&#xff0c;ECS u1服务器2核4G5M固定带宽199元一年&#xff0c;2核4G4M带宽轻量服务器一年165元12个月&#xff0c;2核4G服务…...

300.【华为OD机试】跳房子I(时间字符串排序—JavaPythonC++JS实现)

本文收录于专栏:算法之翼 本专栏所有题目均包含优质解题思路,高质量解题代码(Java&Python&C++&JS分别实现),详细代码讲解,助你深入学习,深度掌握! 文章目录 一. 题目二.解题思路三.题解代码Python题解代码JAVA题解代码C/C++题解代码JS题解代码四.代码讲解(Ja…...

linux ln Linux 系统中用于创建链接(link)的命令

linux 命令基础汇总 命令&基础描述地址linux curl命令行直接发送 http 请求Linux curl 类似 postman 直接发送 get/post 请求linux ln创建链接&#xff08;link&#xff09;的命令创建链接&#xff08;link&#xff09;的命令linux linklinux 软链接介绍linux 软链接介绍l…...

mysql按照查询条件进行排序和统计一个字段中每个不同数值出现的次数

1.比如学生表 如何显示查询结果的顺序根据放置的顺序查询 <select id"selectNames" resultType"Student">select * from student_table where 11<if test"studentList! null">and name in<foreach item"item" ind…...

深度学习基础知识

本文内容来自https://zhuanlan.zhihu.com/p/106763782 此文章是用于学习上述链接&#xff0c;方便自己理解的笔记 1.深度学习的网络结构 深度学习是神经网络的一种特殊形式&#xff0c;典型的神经网络如下图所示。 神经元&#xff1a;表示输入、中间数值、输出数值点。例如&…...

UE4_旋转节点总结一

一、Roll、Pitch、Yaw Roll 围绕X轴旋转 飞机的翻滚角 Pitch 围绕Y轴旋转 飞机的俯仰角 Yaw 围绕Z轴旋转 飞机的航向角 二、Get Forward Vector理解 测试&#xff1a; 运行&#xff1a; 三、Get Actor Rotation理解 运行效果&#xff1a; 拆分旋转体测试一&a…...

Dockerfile将jar部署成docker容器

将jar包copy到linux&#xff0c;新建Dockerfile文件 -rw-r--r-- 1 root root 52209844 Mar 25 22:55 data-sharing-0.0.1-SNAPSHOT.jar -rwxrwxrwx 1 root root 227 Mar 25 22:57 Dockerfile [rootlocalhost mnt]# pwd /mntDockerfile内容 # 指定基础镜像 FROM java:8-a…...

Android14音频进阶:AudioFlinger向HAL输出数据过程(六十四)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只…...

业务系统对接大模型的基础方案:架构设计与关键步骤

业务系统对接大模型&#xff1a;架构设计与关键步骤 在当今数字化转型的浪潮中&#xff0c;大语言模型&#xff08;LLM&#xff09;已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中&#xff0c;不仅可以优化用户体验&#xff0c;还能为业务决策提供…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展&#xff1a;显示创建时间8. 功能扩展&#xff1a;记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

基于FPGA的PID算法学习———实现PID比例控制算法

基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容&#xff1a;参考网站&#xff1a; PID算法控制 PID即&#xff1a;Proportional&#xff08;比例&#xff09;、Integral&#xff08;积分&…...

【Oracle APEX开发小技巧12】

有如下需求&#xff1a; 有一个问题反馈页面&#xff0c;要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据&#xff0c;方便管理员及时处理反馈。 我的方法&#xff1a;直接将逻辑写在SQL中&#xff0c;这样可以直接在页面展示 完整代码&#xff1a; SELECTSF.FE…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了&#xff1a;一行…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

10-Oracle 23 ai Vector Search 概述和参数

一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI&#xff0c;使用客户端或是内部自己搭建集成大模型的终端&#xff0c;加速与大型语言模型&#xff08;LLM&#xff09;的结合&#xff0c;同时使用检索增强生成&#xff08;Retrieval Augmented Generation &#…...

LeetCode - 199. 二叉树的右视图

题目 199. 二叉树的右视图 - 力扣&#xff08;LeetCode&#xff09; 思路 右视图是指从树的右侧看&#xff0c;对于每一层&#xff0c;只能看到该层最右边的节点。实现思路是&#xff1a; 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...

RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill

视觉语言模型&#xff08;Vision-Language Models, VLMs&#xff09;&#xff0c;为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展&#xff0c;机器人仍难以胜任复杂的长时程任务&#xff08;如家具装配&#xff09;&#xff0c;主要受限于人…...

Chrome 浏览器前端与客户端双向通信实战

Chrome 前端&#xff08;即页面 JS / Web UI&#xff09;与客户端&#xff08;C 后端&#xff09;的交互机制&#xff0c;是 Chromium 架构中非常核心的一环。下面我将按常见场景&#xff0c;从通道、流程、技术栈几个角度做一套完整的分析&#xff0c;特别适合你这种在分析和改…...