TensorFlow代码逻辑 vs PyTorch代码逻辑
文章目录
- 一、TensorFlow
 - (一)导入必要的库
 - (二)加载MNIST数据集
 - (三)数据预处理
 - (四)构建神经网络模型
 - (五)编译模型
 - (六)训练模型
 - (七)评估模型
 - (八)将模型的输出转化为概率
 - (九)预测测试集的前5个样本
 
- 二、PyTorch
 - (一)导入必要的库
 - (二)定义神经网络模型
 - (三)数据预处理和加载
 - (四)初始化模型、损失函数和优化器
 - (五)训练模型
 - (六)评估模型
 - (七)设置设备为GPU或CPU
 - (八)运行训练和评估
 - (九)预测测试集的前5个样本
 
- 三、TensorFlow和PyTorch代码逻辑上的对比
 - (一)模型定义
 - (二)数据处理
 - (三)训练过程
 - (四)自动求导
 
- 四、TensorFlow和PyTorch的应用
 - 五、动态图计算
 - (一)TensorFlow(静态图计算):
 - (二)PyTorch(动态图计算):
 
一、TensorFlow
使用TensorFlow构建一个简单的神经网络来对MNIST数据集进行分类
(一)导入必要的库
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
 
(二)加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
 
(三)数据预处理
将图像数据归一化到[0, 1]范围,以提高模型的训练效果
x_train, x_test = x_train / 255.0, x_test / 255.0
 
(四)构建神经网络模型
- Flatten层:将输入的28x28像素图像展平成784个特征的一维向量。
 - Dense层:全连接层,包含128个神经元,使用ReLU激活函数。
 - Dropout层:在训练过程中随机丢弃20%的神经元,防止过拟合。
 - 输出层:包含10个神经元,对应10个类别(数字0-9)。
 
model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),layers.Dense(10)
])
 
(五)编译模型
- optimizer=‘adam’:使用Adam优化器。
 - loss=‘SparseCategoricalCrossentropy’:使用交叉熵损失函数。
 - metrics=[‘accuracy’]:使用准确率作为评估指标。
 
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
 
(六)训练模型
model.fit(x_train, y_train, epochs=5)
 
(七)评估模型
model.evaluate(x_test, y_test, verbose=2)
 
(八)将模型的输出转化为概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()
])
 
(九)预测测试集的前5个样本
predictions = probability_model.predict(x_test[:5])
print(predictions)
 
二、PyTorch
使用PyTorch来构建、训练和评估一个用于MNIST数据集的神经网络模型
(一)导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
 
(二)定义神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.dropout = nn.Dropout(0.2)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x
 
(三)数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
 
(四)初始化模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
 
(五)训练模型
def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
 
(六)评估模型
def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} 'f'({accuracy:.0f}%)\n')
 
(七)设置设备为GPU或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
 
(八)运行训练和评估
for epoch in range(1, 6):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)
 
(九)预测测试集的前5个样本
model.eval()
with torch.no_grad():samples = next(iter(test_loader))[0][:5].to(device)output = model(samples)predictions = F.softmax(output, dim=1)print(predictions)
 
三、TensorFlow和PyTorch代码逻辑上的对比
(一)模型定义
- 在TensorFlow中,通常使用tf.keras模块来定义模型。可以使用Sequential API或Functional API。
 
import tensorflow as tf# Sequential API
model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])# Functional API
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
 
- PyTorch中,定义模型时需要继承nn.Module类并实现forward方法
 
import torch
import torch.nn as nnclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense1 = nn.Linear(784, 64)self.relu = nn.ReLU()self.dense2 = nn.Linear(64, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.relu(self.dense1(x))x = self.softmax(self.dense2(x))return xmodel = Model()
 
(二)数据处理
- TensorFlow有tf.data模块来处理数据管道
 
import tensorflow as tfdef preprocess(data):# 数据预处理逻辑return datadataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
dataset = dataset.map(preprocess).batch(32)
 
- PyTorch使用torch.utils.data.DataLoader和Dataset类来处理数据管道
 
import torch
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data[idx]y = self.labels[idx]return x, ydataset = CustomDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
 
(三)训练过程
- TensorFlow的tf.keras提供了高阶API来进行模型编译和训练
 
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=5, batch_size=32)
 
- PyTorch中,训练过程需要手动编写,包括前向传播、损失计算、反向传播和优化步骤
 
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for data, labels in dataloader:optimizer.zero_grad()outputs = model(data)loss = criterion(outputs, labels)loss.backward()optimizer.step()
 
(四)自动求导
- TensorFlow在后端自动处理梯度计算和应用
 
# 使用model.fit自动处理
 
- PyTorch的自动求导功能非常灵活,可以使用autograd模块
 
# 使用loss.backward()和optimizer.step()手动处理
 
四、TensorFlow和PyTorch的应用
总体来说,PyTorch提供了更多的灵活性和控制,适合需要自定义复杂模型和训练过程的场景。而TensorFlow则更加高级和简洁,适合快速原型和标准模型的开发。
TensorFlow:
- 高阶API:使用tf.keras简化模型定义、训练和评估,适合快速原型开发和生产部署。
 - 性能优化:支持图计算,优化执行速度和资源使用,适合大规模分布式训练。
 - 广泛生态:拥有丰富的工具和库,如TensorBoard用于可视化,TensorFlow Lite用于移动端部署。
 - 企业支持:由Google支持,广泛应用于工业界,提供稳定的长期支持和更新。
 
PyTorch:
- 灵活性:采用动态图计算,代码易于调试和修改,适合研究和实验。
 - 简单直观:符合Python语言习惯,API设计简洁明了,降低学习曲线。
 - 社区活跃:由Facebook支持,拥有活跃的开源社区,快速响应用户需求和改进。
 - 科研应用:广泛应用于学术界,支持多种前沿研究,如自定义损失函数和复杂模型结构。
 
五、动态图计算
动态图计算是PyTorch的一个显著特点,它让模型的计算图在每次前向传播时动态生成,而不是像TensorFlow那样预先定义和编译。
动态图计算的定义与特性:
- 动态生成:每次执行前向传播时,计算图都会根据当前输入数据动态构建。
 - 即时调试:允许在代码执行时使用标准的Python调试工具(如pdb),进行逐步调试和检查。
 - 灵活性高:支持更复杂和动态的模型结构,如条件控制流和递归神经网络,更适合研究实验和快速原型开发。
 
(一)TensorFlow(静态图计算):
在TensorFlow中,计算图是预先定义并编译的。在模型定义和编译之后,图结构固定,随后输入数据进行计算。
import tensorflow as tf# 定义计算图
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))# 创建会话并执行
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(1000):batch_x, batch_y = ...  # 获取训练数据sess.run(loss, feed_dict={x: batch_x, y: batch_y})
 
(二)PyTorch(动态图计算):
在PyTorch中,计算图在每次前向传播时动态构建,代码更接近标准的Python编程风格。
import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense = nn.Linear(784, 10)def forward(self, x):return self.dense(x)model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练过程
for epoch in range(1000):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()
相关文章:
TensorFlow代码逻辑 vs PyTorch代码逻辑
文章目录 一、TensorFlow(一)导入必要的库(二)加载MNIST数据集(三)数据预处理(四)构建神经网络模型(五)编译模型(六)训练模型…...
boost asio异步服务器(4)处理粘包
粘包的产生 当客户端发送多个数据包给服务器时,服务器底层的tcp接收缓冲区收到的数据为粘连在一起的。这种情况的产生通常是服务器端处理数据的速率不如客户端的发送速率的情况。比如:客户端1s内连续发送了两个hello world!,服务器过了2s才接…...
【QT】常用控件|widget|QPushButton|RadioButton|核心属性
目录 编辑 概念 信号与槽机制 控件的多样性和定制性 核心属性 enabled geometry 编辑 windowTiltle windowIcon toolTip styleSheet PushButton RadioButton 概念 QT 控件是构成图形用户界面(GUI)的基础组件,它们是实现与…...
【C++ Primer Plus学习记录】函数参数和按值传递
函数可以有多个参数。在调用函数时,只需使用都逗号将这些参数分开即可: n_chars(R,25); 上述函数调用将两个参数传递给函数n_chars(),我们将稍后定义该函数。 同样,在定义函数时,也在函数头中使用由逗号分隔的参数声…...
MySQL:设计数据库与操作
设计数据库 1. 数据建模1.1 概念模型1.2 逻辑模型1.3 实体模型主键外键外键约束 2. 标准化2.1 第一范式2.2 链接表2.3 第二范式2.4 第三范式 3. 数据库模型修改3.1 模型的正向工程3.2 同步数据库模型3.3 模型的逆向工程3.4 实际应用建议 4. 数据库实体模型4.1 创建和删除数据库…...
OBS 免费的录屏软件
一、下载 obs 【OBS】OBS Studio 的安装、参数设置和录屏、摄像头使用教程-CSDN博客 二、使用 obs & 输出无黑屏 【OBS任意指定区域录屏的方法-哔哩哔哩】 https://b23.tv/aM0hj8A OBS任意指定区域录屏的方法_哔哩哔哩_bilibili 步骤: 1)获取区域…...
uniapp微信小程序使用xr加载模型
1.在根目录与pages同级创建如下目录结构和文件: // index.js Component({properties: {modelPath: { // vue页面传过来的模型type: String,value: }},data: {},methods: {} }) { // index.json"component": true,"renderer": "xr-frame&q…...
机器人运动范围检测 c++
地上有一个m行n列的方格,一个机器人从坐标(0,0)的格子开始移动,它每次可以向上下左右移动一个格子,但不能进入行坐标和列坐标的位数之和大于k的格子,请问机器人能够到达多少个格子 #include &l…...
kettle从入门到精通 第七十四课 ETL之kettle kettle调用https接口教程,忽略SSL校验
场景:kettle调用https接口,跳过校验SSL。(有些公司内部系统之间的https的接口是没有SSL校验这一说,无需使用用证书的) 解决方案:自定义插件或者自定义jar包通过javascript调用https接口。 1、http post 步…...
C++轻量级 线程间异步消息架构(向曾经工作的ROSA-RB以及共事的DOPRA的老兄弟们致敬)
1 啰嗦一番背景 这么多年,换着槽位做牛做马,没有什么钱途 手艺仍然很潮,唯有对于第一线的码农工作,孜孜不倦,其实没有啥进步,就是在不断地重复,刷熟练度,和同期的老兄弟们…...
Kotlin中的类
类初始化顺序 constructor 里的参数列表是首先被执行的,紧接着是 init 块和属性初始化器,最后是次构造函数的函数体。 主构造函数参数列表firstProperty 初始化第一个 init 块secondProperty 初始化第二个 init 块次构造函数函数体 class Example const…...
VSCode中常用的快捷键
通用操作快捷键 显示命令面板:Ctrl Shift P or F1,用于快速访问VSCode的各种命令。 快速打开:Ctrl P,可以快速打开文件、跳转到某个行号或搜索项目内容。 新建窗口/实例:Ctrl Shift N,用于打开一个新的…...
代码随想录-Day45
198. 打家劫舍 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷闯入,系统会自动报警。 给定一个代表每个…...
Rust Eq 和 PartialEq
Eq 和 PartialEq 在 Rust 中,想要重载操作符,你就需要实现对应的特征。 例如 <、<、> 和 > 需要实现 PartialOrd 特征: use std::fmt::Display;struct Pair<T> {x: T,y: T, }impl<T> Pair<T> {fn new(x: T, y: T) ->…...
思考如何学习一门编程语言?
一、什么是编程语言 编程语言是一种用于编写计算机程序的人工语言。通过编程语言,程序员可以向计算机发出指令,控制计算机执行各种任务和操作。编程语言由一组语法规则和语义规则组成,这些规则定义了如何编写代码以及代码的含义。 编程语言…...
顺序串算法库构建
学习贺利坚老师顺序串算法库 数据结构之自建算法库——顺序串_创建顺序串s1,创建顺序串s2-CSDN博客 本人详细解析博客 串的概念及操作_串的基本操作-CSDN博客 版本更新日志 V1.0: 在贺利坚老师算法库指导下, 结合本人详细解析博客思路基础上,进行测试, 加入异常弹出信息 v1.0补…...
[论文阅读笔记33] Matching Anything by Segmenting Anything (CVPR2024 highlight)
这篇文章借助SAM模型强大的泛化性,在任意域上进行任意的多目标跟踪,而无需任何额外的标注。 其核心思想就是在训练的过程中,利用strong augmentation对一张图片进行变换,然后用SAM分割出其中的对象,因此可以找到一组图…...
阿里Nacos下载、安装(保姆篇)
文章目录 Nacos下载版本选择Nacos安装Windows常见问题解决 更多相关内容可查看 Nacos下载 Nacos官方下载地址:https://github.com/alibaba/nacos/releases 码云拉取(如果国外较慢或者拉取超时可以试一下国内地址) //国外 git clone https:…...
四、golang基础之defer
文章目录 一、定义二、作用三、结果四、recover错误拦截 一、定义 defer语句被用于预定对一个函数的调用。可以把这类被defer语句调用的函数称为延迟函数。 二、作用 释放占用的资源捕捉处理异常输出日志 三、结果 如果一个函数中有多个defer语句,它们会以LIFO…...
机器人----四元素
四元素 四元素的大小 [-1,1] 欧拉角转四元素...
利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)
一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...
工业安全零事故的智能守护者:一体化AI智能安防平台
前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...
SpringTask-03.入门案例
一.入门案例 启动类: package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
LabVIEW双光子成像系统技术
双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...
倒装芯片凸点成型工艺
UBM(Under Bump Metallization)与Bump(焊球)形成工艺流程。我们可以将整张流程图分为三大阶段来理解: 🔧 一、UBM(Under Bump Metallization)工艺流程(黄色区域ÿ…...
