详解:Tensorflow、Pytorch、Keras(搭建自己的深度学习网络)
这是一个专门对Tensorflow、Pytorch、Keras三个主流DL框架的一个详解和对比分析
一、何为深度学习框架?
你可以理解为一个工具帮你构建一个深度学习网络,调用里面的各种方法就能自行构建任意层,diy你想要的DNN,而且任意指定学习器和优化器等,非常的方便!
二、Tensorflow
1.发展历史
TensorFlow由Google智能机器研究部门Google Brain团队研发的;TensorFlow编程接口支持Python和C++。随着1.0版本的公布,相继支持了Java、Go、R和Haskell API的alpha版本。
在2017年,Tensorflow独占鳌头,处于深度学习框架的领先地位;但截至目前已经和Pytorch不争上下。
注意,Tensorflow目前主要在工业级领域处于领先地位。参考至博客(38 封私信 / 16 条消息) 为什么说学术上用pytorch,工业上用tensorflow? - 知乎 (zhihu.com)
但说句实话,这个问题过于宏观,每个人都有自己的观点,最好还是自己实际两者都使用之后,再来说最适合自己的是哪一个吧。(并且tensoeflow和pytorch两者都一直在发展,后期有可能就不分伯仲啦!)
三、Pytorch
Pytorch目前是由Facebook人工智能学院提供支持服务的。
Pytorch目前主要在学术研究方向领域处于领先地位。
其优点在于:PyTorch可以使用强大的GPU加速的Tensor计算(比如:Numpy的使用)以及可以构建带有autograd的深度神经网络。
同时,PyTorch 的代码很简洁、易于使用、支持计算过程中的动态图而且内存使用很高效
四、Keras
本来是一个独立的高级API,现在已经成为Tensorflow的一部分
接口简单友好,使用tensorflow作为后端,适合快速实验和原型开发。
五、区别
主要区别:
-
计算图:
- TensorFlow使用静态计算图,需要先定义后运行
- PyTorch使用动态计算图,更灵活,可以边定义边运行
-
易用性:
- Keras通常被认为是最容易上手的
- PyTorch的API设计更加直观
- TensorFlow相对复杂一些,但提供更多底层控制
-
性能和部署:
- TensorFlow在大规模部署和性能优化方面较为成熟
- PyTorch在研究和实验阶段更受欢迎
- Keras作为高级API,性能可能略低,但开发速度快
-
社区和生态系统:
- TensorFlow拥有最大的社区和最广泛的工具支持
- PyTorch在学术界更受欢迎,增长迅速
- Keras作为TensorFlow的一部分,也有很好的社区支持
六、总结
-
TensorFlow:
- 由Google开发
- 使用静态计算图
- 广泛应用于生产环境
- 有较为完善的部署工具
-
PyTorch:
- 由Facebook开发
- 使用动态计算图
- 更加灵活,适合研究和快速原型开发
- 相对更加直观和易于调试
-
Keras:
- 最初是一个独立的高级API,现已成为TensorFlow的一部分
- 提供更简单、更用户友好的接口
- 可以使用TensorFlow或Theano作为后端
- 适合快速实验和原型开发
六、代码实现以及对比(key😍)
选择哪个框架通常取决于项目需求、个人偏好和团队经验。对于初学者,Keras可能是最好的起点;对于需要更多控制和自定义的高级用户,PyTorch或TensorFlow的低级API可能更合适。
使用TensorFlow、PyTorch和Keras分别搭建一个简单的深度神经网络的例子。这些例子都将创建一个简单的前馈神经网络用于MNIST手写数字分类任务。
1.TensorFlow 2.x 示例:(这里用的是低级API,没用tf.keras的高级API)
- 丰富的低级操作:允许对计算过程进行精细控制。
- 强大的性能优化:特别是在分布式和大规模部署方面。
- 全面的工具生态系统:包括TensorBoard、TFLite等工具。
- 灵活的模型部署:支持多种平台和设备。
- 静态图支持:虽然2.x版本默认使用动态图,但仍支持静态图,有利于某些优化。
同时下面这个例子展示了TensorFlow低级API的几个关键特性:
-
手动定义模型参数(W1, b1, W2, b2)作为tf.Variable。
-
使用函数定义模型结构,而不是使用Keras的Sequential或Functional API。
-
自定义损失函数。
-
使用tf.GradientTape来计算梯度。
-
手动应用梯度到优化器。
-
使用@tf.function装饰器来将Python函数转换为TensorFlow图,以提高性能。
-
手动实现训练循环和评估过程。
-
使用tf.data.Dataset API来处理数据。
import tensorflow as tf
import numpy as np# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 将数据转换为适当的形状和类型
x_train = x_train.reshape(-1, 784).astype(np.float32)
x_test = x_test.reshape(-1, 784).astype(np.float32)
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 定义模型参数
W1 = tf.Variable(tf.random.normal([784, 128], stddev=0.1))
b1 = tf.Variable(tf.zeros([128]))
W2 = tf.Variable(tf.random.normal([128, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))# 定义模型函数
def model(x):layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)return tf.matmul(layer1, W2) + b2# 定义损失函数
def loss_fn(predictions, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=predictions))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate=0.001)# 定义训练步骤
@tf.function
def train_step(x, y):with tf.GradientTape() as tape:predictions = model(x)loss = loss_fn(predictions, y)gradients = tape.gradient(loss, [W1, b1, W2, b2])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2]))return loss# 定义测试步骤
@tf.function
def test_step(x, y):predictions = model(x)loss = loss_fn(predictions, y)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(predictions, axis=1), y), tf.float32))return loss, accuracy# 训练循环
epochs = 5
for epoch in range(epochs):total_loss = 0.0num_batches = 0for x_batch, y_batch in train_dataset:loss = train_step(x_batch, y_batch)total_loss += lossnum_batches += 1avg_loss = total_loss / num_batches# 在测试集上评估test_loss = 0.0test_accuracy = 0.0num_test_batches = 0for x_test_batch, y_test_batch in test_dataset:batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)test_loss += batch_losstest_accuracy += batch_accuracynum_test_batches += 1avg_test_loss = test_loss / num_test_batchesavg_test_accuracy = test_accuracy / num_test_batchesprint(f"Epoch {epoch+1}/{epochs}")print(f"Train Loss: {avg_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_accuracy:.4f}")# 最终模型评估
final_test_loss = 0.0
final_test_accuracy = 0.0
num_final_batches = 0
for x_test_batch, y_test_batch in test_dataset:batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)final_test_loss += batch_lossfinal_test_accuracy += batch_accuracynum_final_batches += 1
final_avg_test_loss = final_test_loss / num_final_batches
final_avg_test_accuracy = final_test_accuracy / num_final_batchesprint("\nFinal Test Results:")
print(f"Test Loss: {final_avg_test_loss:.4f}, Test Accuracy: {final_avg_test_accuracy:.4f}")
2. PyTorch示例(复杂但可操作性强)
PyTorch的API设计更加直观,其实就跟python的编程一样,类似嵌套在python里面一样,类似从底层构建的一个深度学习网络,这虽然有一点点复杂,但是这使得整个搭建过程透明化可操作性极强,适合做研究的人,进行细节优化(魔改),进行底层控制。(复杂但可操作性强)
- 类似Python的编程风格:使用动态计算图,编码感觉更像普通Python编程。
- 面向对象的设计:模型定义为类,更符合Python用户的习惯。
- 即时执行:可以立即看到每一步的结果,便于调试。
- 灵活性:易于处理动态网络结构和复杂的研究型模型。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 加载数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)self.dropout = nn.Dropout(0.2)def forward(self, x):x = self.flatten(x)x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xmodel = Net()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())# 训练模型
def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 运行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(1, 6):train(model, device, train_loader, optimizer, epoch)# 评估模型
model.eval()
correct = 0
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()print(f'Accuracy: {correct / len(test_loader.dataset)}')
3. Keras示例(简单但没有什么可操作性)
- 高级API:Keras提供了非常简洁和直观的API,隐藏了许多底层复杂性。
- 模块化设计:可以轻松堆叠层来构建模型,如model.add(layer)。
- 内置常用模型:提供了许多预定义的架构,如Sequential模型。
- 一致的接口:无论后端如何(TensorFlow、Theano等),接口保持一致。
- 详细文档:有优秀的文档和大量教程。
比较简单,加载数据---构建模型(tf.keras自己叠就行)---编译模型(定义模型在训练过程中如何学习,使用什么优化器,使用什么损失函数评估模型性能,以及监控指标等)---训练模型---评估模型。
是不是非常的简单,结构清洗明了,确实在工程上是非常的适合的,搭建快速,便于部署
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 构建模型
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc}')
很明显:
- Keras适合快速原型设计和简单项目,学习曲线最平缓。
- PyTorch在研究和复杂模型开发中很受欢迎,因为它的直观性和灵活性。
- TensorFlow提供了从高级(Keras API)到低级的全方位控制,适合各种规模的项目,尤其是大规模部署。
这些示例展示了使用不同框架构建简单神经网络的基本步骤。每个框架都有其独特的语法和风格,但基本概念是相似的:
- 加载和预处理数据
- 定义模型结构
- 指定损失函数和优化器
- 训练模型
- 评估模型性能
需要注意的是,Keras现在是TensorFlow的一部分,所以TensorFlow和Keras的例子看起来非常相似。结合了keras的TensorFlow搭建DNN非常的简单(哥们当年用的LSTM就是用的TensorFlow搭建的),PyTorch的例子稍微复杂一些,因为它提供了更多的底层控制。(比如那个transformer的搭建就是依赖的PyTorch,比较复杂)
Final:个人eassy
对于我个人来讲,使用基于tensorflow的keras搭建我的网络确实是非常的简单便捷,这也比较适合我所做的研究方向,偏重于深度学习网络的应用研究,而不是执着于优化网络或者优化一些算法!哥们就是单纯一个套用别人开发的网络,然后自定义几层,设置一下优化器,损失函数啥的,调调参数,训练一下,出了效果,就行了。对优化这件事,哥们我是一点也不敢碰瓷的呀兄弟,这都是研究内容外的拓展学习了!!!也是我的小个人发展路径和能力提升的路子吧!
相关文章:

详解:Tensorflow、Pytorch、Keras(搭建自己的深度学习网络)
这是一个专门对Tensorflow、Pytorch、Keras三个主流DL框架的一个详解和对比分析 一、何为深度学习框架? 你可以理解为一个工具帮你构建一个深度学习网络,调用里面的各种方法就能自行构建任意层,diy你想要的DNN,而且任意指定学习…...

【CSS in Depth 2 精译_035】5.5 Grid 网格布局中的子网格布局(全新内容)
当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一章 层叠、优先级与继承(已完结) 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位(已完结) 2.1 相对…...

Java是怎么处理死锁的
文章目录 避免死锁避免嵌套锁资源进行排序超时锁 检测死锁通过Java提供的API检查死锁情况jStack监控工具 Java 本身没有内置的机制自动处理死锁问题,但可以采取一些策略和技术来检测和避免死锁。 避免死锁 避免嵌套锁 尽可能减少嵌套锁操作,避免在一个…...

Effective Java 学习笔记 方法签名设计
目录 谨慎选择方法名称 不要过于追求提供便利的快捷方法 避免过长的参数列表 对于参数类型优先使用接口而不是类 对于boolean参数,要优先使用两个元素的枚举类型 本文接续前一篇文章聚焦Java方法签名的设计,方法签名包括了方法的输入和输出参数以及…...

毛利超70%、超70+智驾客户,这家AI数据训练服务商刚刚止亏
AI训练数据服务第一股海天瑞声终于迎来了“曙光”。 日前,海天瑞声发布2024年半年报显示,上半年其实现营收9242.63万,同比增长24.13%;实现净利润41.64 万元,不过同比去年同期的亏损1724.14万元,扭亏为盈。…...

本地部署高颜值某抑云音乐播放器Splayer并实现无公网IP远程听歌
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

图像压缩编码(4)--H.26x系列视频压缩编码_2
目录 H.261 视频编码标准 H.261的编码与解码 1) 帧内/帧间编码 2)运动补偿 3)量化 4)环路滤波器 5)缓存器 压缩数据的分层 数据复用结构 H.264的编码与解码 H.261 视频编码标准 实际应用时,要求有…...

JS渲染锻炼输入表单
前言 上篇文章为大家展现了好看的信息窗口,接下来我们跟着流程图看下一步 之前我们的带点击事件已经添加完毕,下一步就是当用户点击的时候,渲染锻炼形式,当然这是一个标签,可以提供给用户输入锻炼形式 实例 ● 我…...

proteus仿真学习(1)
一,创建工程 一般选择默认模式,不配置pcb文件 可以选用芯片型号也可以不选 不选则从零开始布局,没有初始最小系统。选用则有初始最小系统以及基础的main函数 本次学习使用从零开始,不配置固件 二,上手软件 1.在元件…...

决策树+随机森林模型实现足球大小球让球预测软件
文章目录 前言一、决策树是什么?二、数据收集与整理1.数据收集2.数据清洗3.特征选择 三、决策树构建3.1绘制训练数据图像3.2 训练决策树模型3.3 依据模型绘制决策树的决策边界3.4 树模型可视化 四、模型预测五、随机森林模型总结 前言 之前搞足球数据分析的时候&…...

31省市农业地图大数据
1.北京市 谷类作物种植结构(万亩) 农作物种植结构(万亩) 2.天津市 谷类作物种植结构(万亩) 农作物种植结构(万亩) 3.黑龙江省 谷类作物种植结构(万亩) 农作物…...

http请求包含什么
HTTP请求通常包含以下几个主要部分: 请求行(Request Line): 包含请求方法(如 GET、POST、PUT、DELETE 等)、请求的目标 URI 和 HTTP 版本。例如:GET /index.html HTTP/1.1 请求头部(…...

【基础算法总结】模拟篇
目录 一,算法介绍二,算法原理和代码实现1576.替换所有的问号495.提莫攻击6.Z字形变换38.外观数列1419.数青蛙 三,算法总结 一,算法介绍 模拟算法本质就是"依葫芦画瓢",就是在题目中已经告诉了我们该如何操作…...

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>
目录 一、回顾神经网络框架 1、单层神经网络 2、多层神经网络 二、手写数字识别 1、续接上节课代码,如下所示 2、建立神经网络模型 输出结果: 3、设置训练集 4、设置测试集 5、创建损失函数、优化器 参数解析: 1)para…...

【笔记】材料分析测试:晶体学
晶体与晶体结构Crystal and Crystal Structure 1.晶体主要特征 固态物质可以分为晶态和非晶态两大类,分别称为晶体和非晶体。 晶体和非晶体在微观结构上的区别在于是否具有长程有序。 晶体(长程有序)非晶(短程有序)…...

飞塔Fortigate7.4.4的DNS劫持功能
基础网络配置、上网策略、与Server的VIP配置(略)。 在FortiGate上配置DNS Translation,将DNS请求结果为202.103.12.2的DNS响应报文中的IP地址修改为Server的内网IP 10.10.2.100。 config firewall dnstranslationedit 1set src 2.13.12.2set…...

Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】
Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 目录 Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 一、简单介绍 二、状态模式(State Pattern) 1、什么时候使用状态模式 2、使用状态模式的…...

【RabbitMQ】RabbitMQ 的概念以及使用RabbitMQ编写生产者消费者代码
目录 1. RabbitMQ 核心概念 1.1生产者和消费者 1.2 Connection和Channel 1.3 Virtual host 1.4 Queue 1.5 Exchange 1.6 RabbitMO工作流程 2. AMQP 3.RabbitMO快速入门 3.1.引入依赖 3.2.编写生产者代码 3.3.编写消费者代码 4.源码 1. RabbitMQ 核心概念 在安装…...

openmv与stm32通信
控制小车视觉循迹使用 OpenMV 往往是不够的。一般使用 OpenMV 对图像进行处理,将处理过后的数据使用串口发送给STM32,使用STM32控制小车行驶。本文主要讲解 OpenMV 模块与 STM32 间的串口通信以及两种循迹方案,分别是划分检测区域和线性回归。…...

C++ STL全面解析:六大核心组件之一----序列式容器(vector和List)(STL进阶学习)
目录 序列式容器 Vector vector概述 vector的迭代器 vector的数据结构 vector的构造和内存管理 vector的元素操作 List List概述 List的设计结构 List的迭代器 List的数据结构 List的内存构造 List的元素操作 C标准模板库(STL)是一组高效的…...

【c数据结构】OJ练习篇 帮你更深层次理解链表!(相交链表、相交链表、环形链表、环形链表之寻找环形入口点、判断链表是否是回文结构、 随机链表的复制)
目录 一. 相交链表 二. 环形链表 三. 环形链表之寻找环形入口点 四. 判断链表是否是回文结构 五. 随机链表的复制 一. 相交链表 最简单粗暴的思路,遍历两个链表,分别寻找是否有相同的对应的结点。 我们对两个链表的每个对应的节点进行判断比较&…...

微软开源GraphRAG的使用教程(最全,非常详细)
GraphRAG的介绍 目前微软已经开源了GraphRAG的完整项目代码。对于某一些LLM的下游任务则可以使用GraphRAG去增强自己业务的RAG的表现。项目给出了两种使用方式: 在打包好的项目状态下运行,可进行尝试使用。在源码基础上运行,适合为了下游任…...

使用Refine构建项目(1)初始化项目
要初始化一个空的Refine项目,你可以使用Refine提供的CLI工具create-refine-app。以下是初始化步骤: 使用npx命令: 在命令行中运行以下命令来创建一个新的Refine项目: npx create-refine-applatest my-refine-project这将引导你通过…...

【Docker】安装及使用
1. 安装Docker Desktop Docker Desktop是官方提供的桌面版Docker客户端,在Mac上使用Docker需要安装这个工具。 访问 Docker官方页面 并下载Docker Desktop for Mac。打开下载的.dmg文件,并拖动Docker图标到应用程序文件夹。安装完成后,打开…...

[大语言模型-论文精读] 以《黑神话:悟空》为研究案例探讨VLMs能否玩动作角色扮演游戏?
1. 论文简介 论文《Can VLMs Play Action Role-Playing Games? Take Black Myth Wukong as a Study Case》是阿里巴巴集团的Peng Chen、Pi Bu、Jun Song和Yuan Gao,在2024.09.19提交到arXiv上的研究论文。 论文: https://arxiv.org/abs/2409.12889代码和数据: h…...

提升动态数据查询效率:应对数据库成为性能瓶颈的优化方案
引言 在现代软件系统中,数据库性能是决定整个系统响应速度和处理能力的关键因素之一。然而,当系统负载增加,特别是在高并发、大数据量场景下,数据库性能往往会成为瓶颈,导致查询响应时间延长,影响用户体验…...

Prometheus+grafana+kafka_exporter监控kafka运行情况
使用Prometheus、Grafana和kafka_exporter来监控Kafka的运行情况是一种常见且有效的方案。以下是详细的步骤和说明: 1. 部署kafka_exporter 步骤: 从GitHub下载kafka_exporter的最新版本:kafka_exporter项目地址(注意ÿ…...

在vue中:style 的几种使用方式
在日常开发中:style的使用也是比较常见的: 亲测有效 1.最通用的写法 <p :style"{fontFamily:arr.conFontFamily,color:arr.conFontColor,backgroundColor:arr.conBgColor}">{{con.title}}</p> 2.三元表达式 <a :style"{height:…...

商城小程序后端开发实践中出现的问题及其解决方法
前言 商城小程序后端开发中,开发者可能会面临多种问题。以下是一些常见的问题及其解决方法: 一、性能优化 问题:随着用户量的增加和功能的扩展,商城小程序可能会出现响应速度慢、处理效率低的问题。 解决方法: 对数…...

阿里Arthas-Java诊断工具,基本操作和命令使用
Arthas 是阿里巴巴开源的一款Java诊断工具,深受开发者喜爱。它可以帮助开发者在不需要修改代码的情况下,对运行中的Java程序进行问题诊断和性能分析。 软件具体使用方法 1 启动 Arthas,此时可能会出现好几个jvm的进程号,输入序号…...