当前位置: 首页 > article >正文

TensorFlow 2基本功能和示例代码

TensorFlow 2.x 是 Google 开源的一个深度学习框架,广泛用于构建和训练机器学习模型。

一、核心特点

1. Keras API 集成

TensorFlow 2.x 将 Keras 作为其核心 API,简化了模型的构建和训练流程。Keras 提供了高层次的 API,易于使用和理解。

import tensorflow as tf
from tensorflow.keras import layers# 使用 Keras Sequential API 构建模型
model = tf.keras.Sequential([layers.Dense(64, activation='relu', input_shape=(784,)),layers.Dense(10, activation='softmax')
])model.summary()
2. 函数式 API 和子类化 API

除了 Keras 的序列化模型 API,TensorFlow 2.x 还支持函数式 API 和子类化 API,允许用户构建复杂的模型结构。

函数式 API 示例:

inputs = tf.keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)model.summary()

子类化 API 示例:

class MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = layers.Dense(64, activation='relu')self.dense2 = layers.Dense(10, activation='softmax')def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()
model(tf.zeros((1, 784)))
3. 即时执行模式

TensorFlow 2.x 默认启用 Eager Execution,允许用户逐行运行代码和立即查看结果,使得调试和模型开发更加直观和灵活。

# 启用 Eager Execution
tf.config.run_functions_eagerly(True)# 示例
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.constant([[5.0, 6.0], [7.0, 8.0]])
z = tf.matmul(x, y)
print(z)
4. 兼容性工具

TensorFlow 2.x 提供了兼容性工具,如 tf.compat.v1,帮助用户迁移现有的 TensorFlow 1.x 代码到 TensorFlow 2.x。

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()# TensorFlow 1.x 代码
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
5. 分布式训练

TensorFlow 2.x 提供了简化的分布式训练 API,如 tf.distribute.Strategy,支持在多 GPU、多 TPU 和分布式环境下训练模型。

strategy = tf.distribute.MirroredStrategy()with strategy.scope():model = tf.keras.Sequential([layers.Dense(64, activation='relu', input_shape=(784,)),layers.Dense(10, activation='softmax')])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
6. TensorFlow Hub 和 TensorFlow Datasets

提供了预训练模型和数据集库,帮助用户更快速地构建和训练模型。

import tensorflow_hub as hub
import tensorflow_datasets as tfds# 使用 TensorFlow Hub 加载预训练模型
model = tf.keras.Sequential([hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", input_shape=(224, 224, 3)),layers.Dense(10, activation='softmax')
])# 使用 TensorFlow Datasets 加载数据集
dataset, info = tfds.load('mnist', with_info=True, as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
7. XLA 编译器

TensorFlow 2.x 支持 XLA(Accelerated Linear Algebra)编译器,优化计算图,提高性能。

# 启用 XLA 编译器
tf.config.optimizer.set_jit(True)
8. 硬件加速

支持 GPU 和 TPU 加速,提升训练和推理效率。

# 检查 GPU 是否可用
if tf.config.list_physical_devices('GPU'):print("GPU is available")
else:print("GPU is not available")

二、模型构建

1. Keras Sequential API

用于构建顺序模型,适合堆叠层的模型结构。

model = tf.keras.Sequential([layers.Dense(64, activation='relu', input_shape=(784,)),layers.Dense(10, activation='softmax')
])
2. Keras Functional API

用于构建复杂的模型结构,如多输入、多输出模型。

inputs = tf.keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
3. 子类化 API

允许用户定义自定义层和模型。

class MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = layers.Dense(64, activation='relu')self.dense2 = layers.Dense(10, activation='softmax')def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()

三、训练与评估

1. 训练模型

使用 model.compile 配置训练参数,使用 model.fit 训练模型。

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_dataset, epochs=5)
2. 评估模型

使用 model.evaluate 评估模型性能。

loss, accuracy = model.evaluate(test_dataset)
print(f"Loss: {loss}, Accuracy: {accuracy}")

四、其他功能

1. TensorFlow Lite

TensorFlow 的轻量级版本,适用于移动和嵌入式设备。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()with open('model.tflite', 'wb') as f:f.write(tflite_model)
2. TensorFlow Hub

一个库,旨在促进机器学习模型的可重用模块的发布、发现和使用。

model = tf.keras.Sequential([hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4", input_shape=(224, 224, 3)),layers.Dense(10, activation='softmax')
])
3. TensorFlow Extended(TFX)

一个基于 TensorFlow 的通用机器学习平台,包括 TensorFlow Transform、TensorFlow Model Analysis 和 TensorFlow Serving 等开源库。

# 示例代码需要结合 TFX 库使用
4. TensorBoard

一套可视化工具,支持对 TensorFlow 程序的理解、调试和优化。

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(train_dataset, epochs=5, callbacks=[tensorboard_callback])

五、综合应用示例

1. 模型构建

问题: 如何使用TensorFlow 2.x构建一个简单的全连接神经网络(MLP)?

代码示例:

import tensorflow as tf
from tensorflow.keras import layers, models# 构建模型
model = models.Sequential([layers.Dense(64, activation='relu', input_shape=(784,)),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 打印模型结构
model.summary()
2. 数据预处理

问题: 如何使用TensorFlow 2.x对MNIST数据集进行预处理?

代码示例:

import tensorflow as tf# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()# 归一化数据
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0# 将标签转换为整数
y_train = y_train.astype('int32')
y_test = y_test.astype('int32')
3. 模型训练

问题: 如何使用TensorFlow 2.x训练一个模型?

代码示例:

# 训练模型
history = model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")
4. 模型保存与加载

问题: 如何保存和加载TensorFlow 2.x模型?

代码示例:

# 保存模型
model.save('my_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('my_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(x_test)
5. 自定义损失函数

问题: 如何在TensorFlow 2.x中自定义损失函数?

代码示例:

import tensorflow as tf# 自定义损失函数
def custom_loss(y_true, y_pred):return tf.reduce_mean(tf.square(y_true - y_pred))# 编译模型时使用自定义损失函数
model.compile(optimizer='adam', loss=custom_loss)
6. 使用回调函数

问题: 如何在TensorFlow 2.x中使用回调函数?

代码示例:

# 定义回调函数
callbacks = [tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),tf.keras.callbacks.ModelCheckpoint(filepath='best_model.h5', save_best_only=True)
]# 训练模型时使用回调函数
model.fit(x_train, y_train, epochs=10, validation_split=0.2, callbacks=callbacks)
7. 使用TensorBoard

问题: 如何在TensorFlow 2.x中使用TensorBoard进行可视化?

代码示例:

# 定义TensorBoard回调
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs')# 训练模型时使用TensorBoard回调
model.fit(x_train, y_train, epochs=5, validation_split=0.2, callbacks=[tensorboard_callback])
8. 使用GPU加速

问题: 如何在TensorFlow 2.x中使用GPU加速训练?

代码示例:

# 检查是否有GPU可用
if tf.config.list_physical_devices('GPU'):print("GPU is available")
else:print("GPU is not available")# 使用GPU进行训练
with tf.device('/GPU:0'):model.fit(x_train, y_train, epochs=5, batch_size=32)
9. 模型微调

问题: 如何在TensorFlow 2.x中对预训练模型进行微调?

代码示例:

# 加载预训练模型
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')# 冻结预训练模型的层
base_model.trainable = False# 添加自定义层
model = tf.keras.Sequential([base_model,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32)
10. 分布式训练

问题: 如何在TensorFlow 2.x中进行分布式训练?

代码示例:

# 设置分布式策略
strategy = tf.distribute.MirroredStrategy()# 在策略范围内构建和编译模型
with strategy.scope():model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32)

相关文章:

TensorFlow 2基本功能和示例代码

TensorFlow 2.x 是 Google 开源的一个深度学习框架,广泛用于构建和训练机器学习模型。 一、核心特点 1. Keras API 集成 TensorFlow 2.x 将 Keras 作为其核心 API,简化了模型的构建和训练流程。Keras 提供了高层次的 API,易于使用和理解。…...

ZZNUOJ(C/C++)基础练习1011——1020(详解版)

1011 : 圆柱体表面积 题目描述 输入圆柱体的底面半径r和高h,计算圆柱体的表面积并输出到屏幕上。要求定义圆周率为如下宏常量 #define PI 3.14159 输入 输入两个实数,表示圆柱体的底面半径r和高h。 输出 输出一个实数,即圆柱体的表面积&…...

Python 字典:快速掌握高效的数据存储方式

文章目录 一、什么是字典?字典的定义二、字典的基本操作1. 访问字典的值2. 修改字典中的值3. 添加新的键值对4. 删除键值对5. 获取字典长度三、字典的遍历1. 遍历键2. 遍历值3. 遍历键值对四、字典的常用方法1. `keys()`:获取所有键2. `values()`:获取所有值3. `items()`:获…...

Baklib探索内容中台的核心价值与实施策略

内容概要 在数字化转型的背景下,内容中台逐渐成为企业数字化策略中的关键组成部分。内容中台是一个集成的内容管理体系,旨在打破信息孤岛,使内容能够在各个业务部门和平台之间高效流通。这种管理体系不仅能够提升内容的生产效率,…...

网络安全攻防实战:从基础防护到高级对抗

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 引言 在信息化时代,网络安全已经成为企业、政府和个人必须重视的问题。从数据泄露到勒索软件攻击,每一次…...

论文阅读(十三):复杂表型关联的贝叶斯、基于系统的多层次分析:从解释到决策

1.论文链接:Bayesian, Systems-based, Multilevel Analysis of Associations for Complex Phenotypes: from Interpretation to Decision 摘要: 遗传关联研究(GAS)报告的结果相对稀缺,促使许多研究方向。尽管关联概念…...

13.zookeeper开机自启动配置

要在Linux(RHEL7.7)系统中设置zookeeper开机自启动,可以创建一个系统服务单元文件。以下是为详细配置部署,假设你已经安装了zookeeper并且可以通过zkServer.sh命令启动它。 1.进入/lib/systemd/system目录 命令: cd /lib/systemd/system [root@rhel77 system]# cd /lib/…...

“““【运用 R 语言里的“predict”函数针对 Cox 模型展开新数据的预测以及推理。】“““

主题与背景 本文主要介绍了如何在R语言中使用predict函数对已拟合的Cox比例风险模型进行新数据的预测和推理。Cox模型是一种常用的生存分析方法,用于评估多个因素对事件发生时间的影响。文章通过具体的代码示例展示了如何使用predict函数的不同参数来获取生存概率和…...

Oracle Primavera P6 最新版 v24.12 更新 1/2

目录 引言 P6 PPM 更新内容 1. 在提交更新基线前预览调整 2. 快速轻松地取消链接活动 3. 选择是否从 XER 文件导入责任经理 4. 提高全局变更报告的清晰度 5. 将整个分层代码值路径导出到 CPP 6. 里程碑活动支持所有关系类型 6. 时间表批准 7. 性能改进 8. 安装改进 …...

AI大模型开发原理篇-2:语言模型雏形之词袋模型

基本概念 词袋模型(Bag of Words,简称 BOW)是自然语言处理和信息检索等领域中一种简单而常用的文本表示方法,它将文本看作是一组单词的集合,并忽略文本中的语法、词序等信息,仅关注每个词的出现频率。 文本…...

JavaWeb学习-SpringBotWeb开发入门(HTTP协议)

(一)SpringBotWeb开发步骤 (1)创建springboot工程,并勾选开发相关依赖 (2)定义HelloController类,添加方法hello,并添加注解 (3)运行测试 (二)HTTP入门概述 创建请求页面 package com.itheima.demo3; /*请求处理类,加上注解标识为请求处理类*/import org.spr…...

网站结构优化:加速搜索引擎收录的关键

本文来自:百万收录网 原文链接:https://www.baiwanshoulu.com/9.html 网站结构优化对于加速搜索引擎收录至关重要。以下是一些关键策略,旨在通过优化网站结构来提高搜索引擎的抓取效率和收录速度: 一、合理规划网站架构 采用扁…...

本地部署deepseek模型步骤

文章目录 0.deepseek简介1.安装ollama软件2.配置合适的deepseek模型3.安装chatbox可视化 0.deepseek简介 DeepSeek 是一家专注于人工智能技术研发的公司,致力于打造高性能、低成本的 AI 模型,其目标是让 AI 技术更加普惠,让更多人能够用上强…...

【deepseek】deepseek-r1本地部署-第二步:huggingface.co替换为hf-mirror.com国内镜像

一、背景 由于国际镜像国内无法直接访问,会导致搜索模型时加载失败,如下: 因此需将国际地址替换为国内镜像地址。 二、操作 1、使用vscode打开下载路径 2、全局地址替换 关键字 huggingface.co 替换为 hf-mirror.com 注意:务…...

sunrays-framework配置重构

文章目录 1.common-log4j2-starter1.目录结构2.Log4j2Properties.java 新增两个属性3.Log4j2AutoConfiguration.java 条件注入LogAspect4.ApplicationEnvironmentPreparedListener.java 从Log4j2Properties.java中定义的配置读取信息 2.common-minio-starter1.MinioProperties.…...

Spark Streaming的背压机制的原理与实现代码及分析

Spark Streaming的背压机制是一种根据JobScheduler反馈的作业执行信息来动态调整Receiver数据接收率的机制。 在Spark 1.5.0及以上版本中,可以通过设置spark.streaming.backpressure.enabled为true来启用背压机制。当启用背压机制时,Spark Streaming会自…...

刷题记录 贪心算法-2:455. 分发饼干

题目:455. 分发饼干 难度:简单 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。 对每个孩子 i,都有一个胃口值 g[i],这是能让孩子们满足胃口的饼干的最小尺寸&a…...

360大数据面试题及参考答案

数据清理有哪些方法? 数据清理是指发现并纠正数据文件中可识别的错误,包括检查数据一致性,处理无效值和缺失值等。常见的数据清理方法有以下几种: 去重处理:数据中可能存在重复的记录,这不仅会占用存储空间,还可能影响分析结果。通过对比每条记录的关键属性,若所有关键…...

【大模型】Ollama+AnythingLLM搭建RAG大模型私有知识库

文章目录 一、AnythingLLM简介二、搭建本地智能知识库2.1 安装Ollama2.2 安装AnythingLLM 参考资料 一、AnythingLLM简介 AnythingLLM是由Mintplex Labs Inc.开发的一个全栈应用程序,是一款高效、可定制、开源的企业级文档聊天机器人解决方案。AnythingLLM能够将任…...

深入MapReduce——从MRv1到Yarn

引入 我们前面篇章有提到,和MapReduce的论文不太一样。在Hadoop1.0实现里,每一个MapReduce的任务并没有一个独立的master进程,而是直接让调度系统承担了所有的worker 的master 的角色,这就是Hadoop1.0里的 JobTracker。在Hadoop1…...

arkui-x 前端布局编码模板

build() {Column() {Row() {// 上侧页面布局实现}// 下侧页面布局实现}.width(Const.THOUSANDTH_1000).height(Const.THOUSANDTH_1000).justifyContent(FlexAlign.SpaceBetween).backgroundImage($r(app.media.background_xxx)).backgroundImageSize(ImageSize.Cover).backgrou…...

代理模式 -- 学习笔记

代理模式学习笔记 什么是代理? 代理是一种设计模式,用户可以通过代理操作,而真正去进行处理的是我们的目标对象,代理可以在方法增强(如:记录日志,添加事务,监控等) 拿一…...

sem_init的概念和使用案例

sem_init 是 POSIX 线程库中用于初始化未命名信号量&#xff08;unnamed semaphore&#xff09;的函数&#xff0c;常用于多线程或多进程间的同步。以下是其概念和使用案例的详细说明&#xff1a; 概念 函数原型&#xff1a; #include <semaphore.h>int sem_init(sem_t …...

JVM_类的加载、链接、初始化、卸载、主动使用、被动使用

①. 说说类加载分几步&#xff1f; ①. 按照Java虚拟机规范,从class文件到加载到内存中的类,到类卸载出内存为止,它的整个生命周期包括如下7个阶段: 第一过程的加载(loading)也称为装载验证、准备、解析3个部分统称为链接(Linking)在Java中数据类型分为基本数据类型和引用数据…...

ProfibusDP主机与从机交互

ProfibusDP 主机SD2索要数据下发&#xff1a;68 08 F7 68 01 02 03 21 05 06 07 08 1C 1668&#xff1a;SD2 08&#xff1a;LE F7&#xff1a;LEr 68&#xff1a;SD2 01:目的地址 02&#xff1a;源地址 03:FC_CYCLIC_DATA_EXCHANGE功能码 21&#xff1a;数据地址 05,06,07,08&a…...

Java设计模式:结构型模式→组合模式

Java 组合模式详解 1. 定义 组合模式&#xff08;Composite Pattern&#xff09;是一种结构型设计模式&#xff0c;它允许将对象组合成树形结构以表示“部分-整体”的层次。组合模式使得客户端能够以统一的方式对待单个对象和对象集合的一致性&#xff0c;有助于处理树形结构…...

【福州市AOI小区面】shp数据学校大厦商场等占地范围面数据内容测评

AOI城区小区面样图和数据范围查看&#xff1a; — 字段里面有name字段。分类比较多tpye&#xff1a;每个值代表一个类型。比如字段type中1549代表小区住宅&#xff0c;1563代表学校。小区、学校等占地面积范围数据 —— 小区范围占地面积面数据shp格式 无偏移坐标&#xff0c;只…...

【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR 1 算法原理 Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023. 本文提出了一种名为 UNSIR&#xff08;Un…...

基于SpringBoot的阳光幼儿园管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…...

【逻辑学导论第15版】A. 推理

识别下列语段中的前提与结论。有些前提确实支持结论&#xff0c;有些并不支持。请注意&#xff0c;前提可能直接或间接地支持结论&#xff0c;而简单的语段也可能包含不止一个论证。 例题&#xff1a; 1.管理得当的民兵组织对于一个自由国家的安全是必需的&#xff0c;因而人民…...