【人工智能-初级】第15章 TensorFlow 和 PyTorch 的入门:深度学习的利器
文章目录
- 一、引言
- 二、TensorFlow 简介
- 2.1 什么是 TensorFlow?
- 2.2 TensorFlow 安装
- 2.3 TensorFlow 构建简单的神经网络
- 2.4 TensorBoard 可视化
- 三、PyTorch 简介
- 3.1 什么是 PyTorch?
- 3.2 PyTorch 安装
- 3.3 PyTorch 构建简单的神经网络
- 四、TensorFlow 与 PyTorch 的对比
- 4.1 灵活性
- 4.2 易用性
- 4.3 部署能力
- 五、总结
- 5.1 学习要点
- 5.2 练习题
一、引言
在深度学习领域,TensorFlow 和 PyTorch 是最流行的两个框架。它们为构建、训练和部署深度学习模型提供了强大的工具,使得研究人员和开发者能够快速开发复杂的神经网络应用。这两个框架各有优劣,TensorFlow 以其强大的生产部署能力而闻名,而 PyTorch 则以其易用性和灵活性深受研究人员的喜爱。
本篇文章将介绍 TensorFlow 和 PyTorch 的基础概念,通过实例代码展示如何使用这两个框架构建简单的深度学习模型,帮助读者快速上手这两款深度学习利器。
二、TensorFlow 简介
2.1 什么是 TensorFlow?
TensorFlow 是由 Google 开发的一个开源深度学习框架,最初用于大规模机器学习任务的分布式训练。它的主要特点包括:
- 灵活性:支持从机器学习到深度学习的多种任务。
- 易于部署:可以轻松地将模型部署到不同平台(如服务器、移动设备和浏览器)。
- 强大的可视化工具:TensorBoard 是 TensorFlow 内置的可视化工具,用于追踪和监控训练过程。
TensorFlow 使用计算图(Computation Graph)来构建和执行模型,用户通过定义图中的节点和边来描述神经网络的结构。
2.2 TensorFlow 安装
要安装 TensorFlow,只需使用 Python 的 pip 工具:
pip install tensorflow
安装完成后,我们可以在 Python 环境中导入 TensorFlow:
import tensorflow as tf
2.3 TensorFlow 构建简单的神经网络
接下来,我们使用 TensorFlow 来实现一个简单的两层神经网络,用于对 MNIST 数据集进行分类。MNIST 数据集包含手写数字的图片,每个图片是 28x28 像素,分为 10 个类别(0-9)。
import tensorflow as tf
from tensorflow.keras import layers, models# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0# 构建神经网络模型
model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
- layers.Flatten:将输入的 28x28 的图像展开为一维数组,作为全连接层的输入。
- layers.Dense:定义全连接层,其中 128 个神经元使用 ReLU 激活函数,10 个输出神经元使用 softmax 激活函数。
- model.compile:定义优化器、损失函数和评估指标。
- model.fit:使用训练数据训练模型,执行 5 个训练周期(epoch)。
2.4 TensorBoard 可视化
TensorBoard 是 TensorFlow 提供的可视化工具,可以帮助我们查看模型训练的过程和参数变化。
使用 TensorBoard 的步骤如下:
- 在编译模型时添加日志记录。
- 启动 TensorBoard 服务器,查看日志。
import datetime# 定义日志目录
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)# 训练模型并记录日志
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test), callbacks=[tensorboard_callback])
接着,可以在命令行中启动 TensorBoard:
tensorboard --logdir=logs/fit
打开浏览器并访问 http://localhost:6006 即可查看训练的可视化结果。
三、PyTorch 简介
3.1 什么是 PyTorch?
PyTorch 是由 Facebook 开发的一个开源深度学习框架,以其动态计算图的特性和高度的灵活性而受到研究人员的广泛喜爱。PyTorch 提供了类似 NumPy 的张量操作,并集成了自动求导功能,使得用户能够更方便地构建和调试神经网络。
PyTorch 的主要特点包括:
- 动态图:在运行时动态构建计算图,使得调试和开发更加灵活。
- 简洁的 API:与 Python 生态系统紧密集成,易于学习和使用。
- 支持 GPU 加速:简单的 API 使得用户能够轻松将模型部署到 GPU 上。
3.2 PyTorch 安装
要安装 PyTorch,也可以使用 pip 工具:
pip install torch torchvision
安装完成后,我们可以在 Python 环境中导入 PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
3.3 PyTorch 构建简单的神经网络
接下来,我们使用 PyTorch 来实现一个简单的两层神经网络,来完成 MNIST 数据集的分类任务。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)# 定义神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28*28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return self.softmax(x)model = SimpleNN()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播与优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print('Finished Training')
- transforms.Normalize:对图像进行归一化处理。
- nn.Linear:定义全连接层,包含输入和输出特征数。
- model.train():将模型设置为训练模式。
- optimizer.zero_grad():清除梯度缓存,防止梯度累积。
- loss.backward():计算梯度,进行反向传播。
- optimizer.step():更新模型的参数。
四、TensorFlow 与 PyTorch 的对比
4.1 灵活性
- TensorFlow:TensorFlow 在构建和运行模型时,采用静态计算图的方式,这意味着计算图在执行前就已经构建完成。这种方式有利于优化计算过程,适合在生产环境中部署模型。
- PyTorch:PyTorch 则采用动态图的方式,在运行时动态构建计算图,这使得调试和模型修改更加方便,深受研究人员和开发者的喜爱。
4.2 易用性
- TensorFlow:TensorFlow 的早期版本使用静态图,较为复杂,但自 2.0 版本以来,使用类似 Keras 的接口,变得更加易于上手。
- PyTorch:PyTorch 的设计与 Python 编程语言紧密结合,代码简洁、易于理解,尤其适合新手学习深度学习和进行研究。
4.3 部署能力
- TensorFlow:TensorFlow 提供了强大的生产部署工具,如 TensorFlow Serving,用于在生产环境中部署深度学习模型。
- PyTorch:虽然 PyTorch 的部署能力相对较弱,但最近推出的 TorchServe 也为部署 PyTorch 模型提供了便利。
五、总结
TensorFlow 和 PyTorch 是当前深度学习领域最流行的两个框架,各有其优缺点。TensorFlow 更加适合生产部署,而 PyTorch 则因其灵活性和易用性在研究领域更受欢迎。在本文中,我们介绍了这两个框架的基础知识,并通过实例代码展示了如何使用 TensorFlow 和 PyTorch 构建简单的神经网络模型。希望这篇文章能够帮助您理解并快速上手这两款深度学习工具。
5.1 学习要点
- TensorFlow 的静态图和 PyTorch 的动态图:理解两者的主要区别以及对模型构建的影响。
- 深度学习模型的构建和训练:掌握如何使用 TensorFlow 和 PyTorch 构建和训练简单的神经网络模型。
- 模型部署:了解两者在生产环境中的应用及部署能力。
5.2 练习题
- 使用 TensorFlow 和 PyTorch 分别构建一个三层神经网络模型,比较其代码实现和训练效果。
- 在 PyTorch 中使用不同的优化器(如 SGD、Adam),观察对训练速度和精度的影响。
- 使用 TensorFlow 的 TensorBoard 可视化工具,监控模型的训练过程,并理解各项指标的含义。
希望本文能帮助您
相关文章:
【人工智能-初级】第15章 TensorFlow 和 PyTorch 的入门:深度学习的利器
文章目录 一、引言二、TensorFlow 简介2.1 什么是 TensorFlow?2.2 TensorFlow 安装2.3 TensorFlow 构建简单的神经网络2.4 TensorBoard 可视化 三、PyTorch 简介3.1 什么是 PyTorch?3.2 PyTorch 安装3.3 PyTorch 构建简单的神经网络 四、TensorFlow 与 P…...
git禁用 SSL 证书验证
命令 git config --global http.sslVerify false注意:禁用 SSL 证书验证是不安全的,可能会使你的 Git 操作面临中间人攻击的风险。因此,只有在你确信网络环境是安全的,且了解禁用 SSL 验证的后果时,才应该使用这个配置…...
C++之《剑指offer》学习记录(2):sizeof
笔者最近在找工作时,无意间读到了一本名为《剑指offer》的书,粗略翻阅了一下,感觉这将会是一本能让我不再苦恼于笔试和面试“手搓代码”的书。故笔者写下该系列博客记录自己的学习历程,希望能和这本书的读者朋友们一起交流学习心得…...
linux线程 | 同步与互斥 | 线程池以及知识点补充
前言:本节内容是linux的线程的相关知识。本篇首先会实现一个简易的线程池, 然后再将线程池利用单例的懒汉模式改编一下。 然后再谈一些小的知识点,比如自旋锁, 读者写者问题等等。 那么, 现在开始我们的学习吧。 ps:本…...
ArkTS 如何实现表单,地区选择效果
速览 ArkTS实现表单和地区选择效果,可通过Picker组件实现地区选择下拉列表,结合表单组件如Input等构建完整表单。使用ArkTS提供的UI组件库和状态管理机制,可以方便地构建复杂且交云互动的表单界面。 1. ArkTS 表单基础 在ArkTS中,构建表单通常涉及多个UI组件的组合,如I…...
Vite 项目的核心配置- vite.config.ts 和 tsconfig.json 全解析
一、vite.config.ts 详细说明 vite.config.ts 是 Vite 项目的核心配置文件。它允许你自定义 Vite 的行为,以适应你的项目需求。 让我们来看看其中一些重要的配置选项: import { fileURLToPath, URL } from node:url// 使用 defineConfig 帮手函数,这样不用 jsdoc …...
如何使用JMeter进行性能测试的保姆级教程
性能测试是确保网站在用户访问高峰时保持稳定和快速响应的关键环节。作为初学者,选择合适的工具尤为重要。JMeter 是一个强大的开源性能测试工具,可以帮助我们轻松模拟多用户场景,测试网站的稳定性与性能。本教程将引导你通过一个简单的登录场…...
Qt 实战(11)样式表 | 11.1、样式表简介
文章目录 一、样式表简介1、简介2、样式表语法2.1、样式规则2.2、选择器类型2.3、伪状态2.4、设置子控件状态 3、样式表继承与优先级3.1、样式表继承3.2、样式表优先级3.3、解决冲突3.4、样式表层叠 4、总结 前言: 在开发图形用户界面(GUI)应…...
WebGl 多缓冲区和数据偏移
1.多缓冲区 多缓冲区技术通常涉及到创建多个缓冲区对象,并将它们用于不同的数据集。这种做法可以提高数据处理效率,尤其是在处理大量数据或需要频繁更新数据时。通过预先分配和配置多个缓冲区,可以在不影响渲染性能的情况下,快速…...
基于SSM的甜品店销售管理系统
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…...
Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis
Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis 摘要 动态场景的新视角合成一直是一个引人入胜但充满挑战的问题。尽管最近取得了很多进展,但如何同时实现高分辨率的真实感渲染、实时渲染和紧凑的存储,依然是一个巨大的…...
PCL 基于FPFH特征描述子获取点云对应关系
目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 2.1关键函数 2.1.1 FPFH特征计算函数 2.1.2 获取点云之间的对应点对函数 2.1.3 可视化函数 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接: PCL点云算法与项目实战案例汇总…...
项目实战:Qt+OpenCV仿射变换工具v1.1.0(支持打开图片、输出棋盘角点、调整偏移点、导出变换后的图等等)
若该文为原创文章,转载请注明出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/143105881 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、Op…...
OpenCV坐标系统与图像处理案例
在图像处理中,理解图像的坐标系统是至关重要的。OpenCV,作为一个强大的计算机视觉库,提供了丰富的功能来操作图像。本文将介绍OpenCV中的坐标系统,并提供一个简单的案例来展示如何使用这些坐标来修改图像的特定区域。 OpenCV坐标…...
Unity之如何使用Unity Cloud Build云构建
文章目录 前言什么是 UnityCloudBuild?如何使用Unity云构建Unity 团队中的人员不属于 Unity Team 的人员UnityCloudBuild2.0价格表如何使用Unity云构建配置CloudBuild前言 Unity Cloud Build作为Unity平台的一项强大工具,它允许开发团队通过云端自动构建项目,节省了繁琐的手…...
Halcon开启多线程
并行运算(提升检测时间) 支持主线程中的子线程并行执行程序和调用算子。 一旦启动,子线程由线程 ID 标识,该线程 ID 是一个取决于操作系统的整数进程号。 子线程的执行独立于它们启动的线程。 因此,无法预测子线程执行…...
Echarts 点击事件无法使用 this 或者 this绑定的数据无法获取
这里写自定义目录标题 现象解决方案 现象 给echarts绑定自定义点击事件时,无法使用this,并且无法获取到this绑定的数据。 解决方案 增加:const _this this; 代码块如下: const _this this; let myChart echarts.init(docum…...
PCL 基于距离阈值去除错误对应关系(永久免费版)
目录 一、概述1.1 原理1.2 实现步骤1.3应用场景 二、关键函数2.1 获取初始点对2.2 基于距离的对应关系筛选函数2.3 可视化 三、完整代码四、结果展示 即日起,付费专栏所有内容将以永久免费形式陆续进行发表!!! 一、概述 在3D点云的…...
DirectX 11 和 Direct3D 11 的关系
以下是对两者的详细比较: DirectX 11 DirectX 11是微软的一项技术,为高性能游戏和复杂图形程序制定了标准。它是DirectX系列的一个版本,引入了多项创新功能,如硬件加速的Tessellation(细分曲面技术)、多线…...
什么是SCRM?为什么企业要做SCRM?
很多人都知道CRM是客户关系管理系统,而SCRM又是什么呢? 今天我就给大家用一文讲清SCRM的那些事,本文包括:SCRM 的定义与内涵,与传统 CRM 的区别;通过案例阐述其重要性及作用,如适应消费模式转变…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
React Native 导航系统实战(React Navigation)
导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
掌握 HTTP 请求:理解 cURL GET 语法
cURL 是一个强大的命令行工具,用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中,cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...
提升移动端网页调试效率:WebDebugX 与常见工具组合实践
在日常移动端开发中,网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时,开发者迫切需要一套高效、可靠且跨平台的调试方案。过去,我们或多或少使用过 Chrome DevTools、Remote Debug…...
springboot 日志类切面,接口成功记录日志,失败不记录
springboot 日志类切面,接口成功记录日志,失败不记录 自定义一个注解方法 import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;/***…...
0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化
是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可,…...
使用SSE解决获取状态不一致问题
使用SSE解决获取状态不一致问题 1. 问题描述2. SSE介绍2.1 SSE 的工作原理2.2 SSE 的事件格式规范2.3 SSE与其他技术对比2.4 SSE 的优缺点 3. 实战代码 1. 问题描述 目前做的一个功能是上传多个文件,这个上传文件是整体功能的一部分,文件在上传的过程中…...
