详解:Tensorflow、Pytorch、Keras(搭建自己的深度学习网络)
这是一个专门对Tensorflow、Pytorch、Keras三个主流DL框架的一个详解和对比分析
一、何为深度学习框架?
你可以理解为一个工具帮你构建一个深度学习网络,调用里面的各种方法就能自行构建任意层,diy你想要的DNN,而且任意指定学习器和优化器等,非常的方便!
二、Tensorflow
1.发展历史
TensorFlow由Google智能机器研究部门Google Brain团队研发的;TensorFlow编程接口支持Python和C++。随着1.0版本的公布,相继支持了Java、Go、R和Haskell API的alpha版本。
在2017年,Tensorflow独占鳌头,处于深度学习框架的领先地位;但截至目前已经和Pytorch不争上下。
注意,Tensorflow目前主要在工业级领域处于领先地位。参考至博客(38 封私信 / 16 条消息) 为什么说学术上用pytorch,工业上用tensorflow? - 知乎 (zhihu.com)
但说句实话,这个问题过于宏观,每个人都有自己的观点,最好还是自己实际两者都使用之后,再来说最适合自己的是哪一个吧。(并且tensoeflow和pytorch两者都一直在发展,后期有可能就不分伯仲啦!)
三、Pytorch
Pytorch目前是由Facebook人工智能学院提供支持服务的。
Pytorch目前主要在学术研究方向领域处于领先地位。
其优点在于:PyTorch可以使用强大的GPU加速的Tensor计算(比如:Numpy的使用)以及可以构建带有autograd的深度神经网络。
同时,PyTorch 的代码很简洁、易于使用、支持计算过程中的动态图而且内存使用很高效
四、Keras
本来是一个独立的高级API,现在已经成为Tensorflow的一部分
接口简单友好,使用tensorflow作为后端,适合快速实验和原型开发。
五、区别
主要区别:
-
计算图:
- TensorFlow使用静态计算图,需要先定义后运行
- PyTorch使用动态计算图,更灵活,可以边定义边运行
-
易用性:
- Keras通常被认为是最容易上手的
- PyTorch的API设计更加直观
- TensorFlow相对复杂一些,但提供更多底层控制
-
性能和部署:
- TensorFlow在大规模部署和性能优化方面较为成熟
- PyTorch在研究和实验阶段更受欢迎
- Keras作为高级API,性能可能略低,但开发速度快
-
社区和生态系统:
- TensorFlow拥有最大的社区和最广泛的工具支持
- PyTorch在学术界更受欢迎,增长迅速
- Keras作为TensorFlow的一部分,也有很好的社区支持
六、总结
-
TensorFlow:
- 由Google开发
- 使用静态计算图
- 广泛应用于生产环境
- 有较为完善的部署工具
-
PyTorch:
- 由Facebook开发
- 使用动态计算图
- 更加灵活,适合研究和快速原型开发
- 相对更加直观和易于调试
-
Keras:
- 最初是一个独立的高级API,现已成为TensorFlow的一部分
- 提供更简单、更用户友好的接口
- 可以使用TensorFlow或Theano作为后端
- 适合快速实验和原型开发
六、代码实现以及对比(key😍)
选择哪个框架通常取决于项目需求、个人偏好和团队经验。对于初学者,Keras可能是最好的起点;对于需要更多控制和自定义的高级用户,PyTorch或TensorFlow的低级API可能更合适。
使用TensorFlow、PyTorch和Keras分别搭建一个简单的深度神经网络的例子。这些例子都将创建一个简单的前馈神经网络用于MNIST手写数字分类任务。
1.TensorFlow 2.x 示例:(这里用的是低级API,没用tf.keras的高级API)
- 丰富的低级操作:允许对计算过程进行精细控制。
- 强大的性能优化:特别是在分布式和大规模部署方面。
- 全面的工具生态系统:包括TensorBoard、TFLite等工具。
- 灵活的模型部署:支持多种平台和设备。
- 静态图支持:虽然2.x版本默认使用动态图,但仍支持静态图,有利于某些优化。
同时下面这个例子展示了TensorFlow低级API的几个关键特性:
-
手动定义模型参数(W1, b1, W2, b2)作为tf.Variable。
-
使用函数定义模型结构,而不是使用Keras的Sequential或Functional API。
-
自定义损失函数。
-
使用tf.GradientTape来计算梯度。
-
手动应用梯度到优化器。
-
使用@tf.function装饰器来将Python函数转换为TensorFlow图,以提高性能。
-
手动实现训练循环和评估过程。
-
使用tf.data.Dataset API来处理数据。
import tensorflow as tf
import numpy as np# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 将数据转换为适当的形状和类型
x_train = x_train.reshape(-1, 784).astype(np.float32)
x_test = x_test.reshape(-1, 784).astype(np.float32)
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 定义模型参数
W1 = tf.Variable(tf.random.normal([784, 128], stddev=0.1))
b1 = tf.Variable(tf.zeros([128]))
W2 = tf.Variable(tf.random.normal([128, 10], stddev=0.1))
b2 = tf.Variable(tf.zeros([10]))# 定义模型函数
def model(x):layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)return tf.matmul(layer1, W2) + b2# 定义损失函数
def loss_fn(predictions, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=predictions))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate=0.001)# 定义训练步骤
@tf.function
def train_step(x, y):with tf.GradientTape() as tape:predictions = model(x)loss = loss_fn(predictions, y)gradients = tape.gradient(loss, [W1, b1, W2, b2])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2]))return loss# 定义测试步骤
@tf.function
def test_step(x, y):predictions = model(x)loss = loss_fn(predictions, y)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(predictions, axis=1), y), tf.float32))return loss, accuracy# 训练循环
epochs = 5
for epoch in range(epochs):total_loss = 0.0num_batches = 0for x_batch, y_batch in train_dataset:loss = train_step(x_batch, y_batch)total_loss += lossnum_batches += 1avg_loss = total_loss / num_batches# 在测试集上评估test_loss = 0.0test_accuracy = 0.0num_test_batches = 0for x_test_batch, y_test_batch in test_dataset:batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)test_loss += batch_losstest_accuracy += batch_accuracynum_test_batches += 1avg_test_loss = test_loss / num_test_batchesavg_test_accuracy = test_accuracy / num_test_batchesprint(f"Epoch {epoch+1}/{epochs}")print(f"Train Loss: {avg_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_accuracy:.4f}")# 最终模型评估
final_test_loss = 0.0
final_test_accuracy = 0.0
num_final_batches = 0
for x_test_batch, y_test_batch in test_dataset:batch_loss, batch_accuracy = test_step(x_test_batch, y_test_batch)final_test_loss += batch_lossfinal_test_accuracy += batch_accuracynum_final_batches += 1
final_avg_test_loss = final_test_loss / num_final_batches
final_avg_test_accuracy = final_test_accuracy / num_final_batchesprint("\nFinal Test Results:")
print(f"Test Loss: {final_avg_test_loss:.4f}, Test Accuracy: {final_avg_test_accuracy:.4f}")
2. PyTorch示例(复杂但可操作性强)
PyTorch的API设计更加直观,其实就跟python的编程一样,类似嵌套在python里面一样,类似从底层构建的一个深度学习网络,这虽然有一点点复杂,但是这使得整个搭建过程透明化可操作性极强,适合做研究的人,进行细节优化(魔改),进行底层控制。(复杂但可操作性强)
- 类似Python的编程风格:使用动态计算图,编码感觉更像普通Python编程。
- 面向对象的设计:模型定义为类,更符合Python用户的习惯。
- 即时执行:可以立即看到每一步的结果,便于调试。
- 灵活性:易于处理动态网络结构和复杂的研究型模型。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 加载数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.fc2 = nn.Linear(128, 10)self.dropout = nn.Dropout(0.2)def forward(self, x):x = self.flatten(x)x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xmodel = Net()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())# 训练模型
def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 运行训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(1, 6):train(model, device, train_loader, optimizer, epoch)# 评估模型
model.eval()
correct = 0
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()print(f'Accuracy: {correct / len(test_loader.dataset)}')
3. Keras示例(简单但没有什么可操作性)
- 高级API:Keras提供了非常简洁和直观的API,隐藏了许多底层复杂性。
- 模块化设计:可以轻松堆叠层来构建模型,如model.add(layer)。
- 内置常用模型:提供了许多预定义的架构,如Sequential模型。
- 一致的接口:无论后端如何(TensorFlow、Theano等),接口保持一致。
- 详细文档:有优秀的文档和大量教程。
比较简单,加载数据---构建模型(tf.keras自己叠就行)---编译模型(定义模型在训练过程中如何学习,使用什么优化器,使用什么损失函数评估模型性能,以及监控指标等)---训练模型---评估模型。
是不是非常的简单,结构清洗明了,确实在工程上是非常的适合的,搭建快速,便于部署
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 构建模型
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc}')
很明显:
- Keras适合快速原型设计和简单项目,学习曲线最平缓。
- PyTorch在研究和复杂模型开发中很受欢迎,因为它的直观性和灵活性。
- TensorFlow提供了从高级(Keras API)到低级的全方位控制,适合各种规模的项目,尤其是大规模部署。
这些示例展示了使用不同框架构建简单神经网络的基本步骤。每个框架都有其独特的语法和风格,但基本概念是相似的:
- 加载和预处理数据
- 定义模型结构
- 指定损失函数和优化器
- 训练模型
- 评估模型性能
需要注意的是,Keras现在是TensorFlow的一部分,所以TensorFlow和Keras的例子看起来非常相似。结合了keras的TensorFlow搭建DNN非常的简单(哥们当年用的LSTM就是用的TensorFlow搭建的),PyTorch的例子稍微复杂一些,因为它提供了更多的底层控制。(比如那个transformer的搭建就是依赖的PyTorch,比较复杂)
Final:个人eassy
对于我个人来讲,使用基于tensorflow的keras搭建我的网络确实是非常的简单便捷,这也比较适合我所做的研究方向,偏重于深度学习网络的应用研究,而不是执着于优化网络或者优化一些算法!哥们就是单纯一个套用别人开发的网络,然后自定义几层,设置一下优化器,损失函数啥的,调调参数,训练一下,出了效果,就行了。对优化这件事,哥们我是一点也不敢碰瓷的呀兄弟,这都是研究内容外的拓展学习了!!!也是我的小个人发展路径和能力提升的路子吧!
相关文章:

详解:Tensorflow、Pytorch、Keras(搭建自己的深度学习网络)
这是一个专门对Tensorflow、Pytorch、Keras三个主流DL框架的一个详解和对比分析 一、何为深度学习框架? 你可以理解为一个工具帮你构建一个深度学习网络,调用里面的各种方法就能自行构建任意层,diy你想要的DNN,而且任意指定学习…...

【CSS in Depth 2 精译_035】5.5 Grid 网格布局中的子网格布局(全新内容)
当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一章 层叠、优先级与继承(已完结) 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位(已完结) 2.1 相对…...
Java是怎么处理死锁的
文章目录 避免死锁避免嵌套锁资源进行排序超时锁 检测死锁通过Java提供的API检查死锁情况jStack监控工具 Java 本身没有内置的机制自动处理死锁问题,但可以采取一些策略和技术来检测和避免死锁。 避免死锁 避免嵌套锁 尽可能减少嵌套锁操作,避免在一个…...
Effective Java 学习笔记 方法签名设计
目录 谨慎选择方法名称 不要过于追求提供便利的快捷方法 避免过长的参数列表 对于参数类型优先使用接口而不是类 对于boolean参数,要优先使用两个元素的枚举类型 本文接续前一篇文章聚焦Java方法签名的设计,方法签名包括了方法的输入和输出参数以及…...
毛利超70%、超70+智驾客户,这家AI数据训练服务商刚刚止亏
AI训练数据服务第一股海天瑞声终于迎来了“曙光”。 日前,海天瑞声发布2024年半年报显示,上半年其实现营收9242.63万,同比增长24.13%;实现净利润41.64 万元,不过同比去年同期的亏损1724.14万元,扭亏为盈。…...

本地部署高颜值某抑云音乐播放器Splayer并实现无公网IP远程听歌
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

图像压缩编码(4)--H.26x系列视频压缩编码_2
目录 H.261 视频编码标准 H.261的编码与解码 1) 帧内/帧间编码 2)运动补偿 3)量化 4)环路滤波器 5)缓存器 压缩数据的分层 数据复用结构 H.264的编码与解码 H.261 视频编码标准 实际应用时,要求有…...

JS渲染锻炼输入表单
前言 上篇文章为大家展现了好看的信息窗口,接下来我们跟着流程图看下一步 之前我们的带点击事件已经添加完毕,下一步就是当用户点击的时候,渲染锻炼形式,当然这是一个标签,可以提供给用户输入锻炼形式 实例 ● 我…...

proteus仿真学习(1)
一,创建工程 一般选择默认模式,不配置pcb文件 可以选用芯片型号也可以不选 不选则从零开始布局,没有初始最小系统。选用则有初始最小系统以及基础的main函数 本次学习使用从零开始,不配置固件 二,上手软件 1.在元件…...

决策树+随机森林模型实现足球大小球让球预测软件
文章目录 前言一、决策树是什么?二、数据收集与整理1.数据收集2.数据清洗3.特征选择 三、决策树构建3.1绘制训练数据图像3.2 训练决策树模型3.3 依据模型绘制决策树的决策边界3.4 树模型可视化 四、模型预测五、随机森林模型总结 前言 之前搞足球数据分析的时候&…...

31省市农业地图大数据
1.北京市 谷类作物种植结构(万亩) 农作物种植结构(万亩) 2.天津市 谷类作物种植结构(万亩) 农作物种植结构(万亩) 3.黑龙江省 谷类作物种植结构(万亩) 农作物…...
http请求包含什么
HTTP请求通常包含以下几个主要部分: 请求行(Request Line): 包含请求方法(如 GET、POST、PUT、DELETE 等)、请求的目标 URI 和 HTTP 版本。例如:GET /index.html HTTP/1.1 请求头部(…...

【基础算法总结】模拟篇
目录 一,算法介绍二,算法原理和代码实现1576.替换所有的问号495.提莫攻击6.Z字形变换38.外观数列1419.数青蛙 三,算法总结 一,算法介绍 模拟算法本质就是"依葫芦画瓢",就是在题目中已经告诉了我们该如何操作…...

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>
目录 一、回顾神经网络框架 1、单层神经网络 2、多层神经网络 二、手写数字识别 1、续接上节课代码,如下所示 2、建立神经网络模型 输出结果: 3、设置训练集 4、设置测试集 5、创建损失函数、优化器 参数解析: 1)para…...

【笔记】材料分析测试:晶体学
晶体与晶体结构Crystal and Crystal Structure 1.晶体主要特征 固态物质可以分为晶态和非晶态两大类,分别称为晶体和非晶体。 晶体和非晶体在微观结构上的区别在于是否具有长程有序。 晶体(长程有序)非晶(短程有序)…...
飞塔Fortigate7.4.4的DNS劫持功能
基础网络配置、上网策略、与Server的VIP配置(略)。 在FortiGate上配置DNS Translation,将DNS请求结果为202.103.12.2的DNS响应报文中的IP地址修改为Server的内网IP 10.10.2.100。 config firewall dnstranslationedit 1set src 2.13.12.2set…...

Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】
Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 目录 Unity 设计模式 之 行为型模式 -【状态模式】【观察者模式】【备忘录模式】 一、简单介绍 二、状态模式(State Pattern) 1、什么时候使用状态模式 2、使用状态模式的…...

【RabbitMQ】RabbitMQ 的概念以及使用RabbitMQ编写生产者消费者代码
目录 1. RabbitMQ 核心概念 1.1生产者和消费者 1.2 Connection和Channel 1.3 Virtual host 1.4 Queue 1.5 Exchange 1.6 RabbitMO工作流程 2. AMQP 3.RabbitMO快速入门 3.1.引入依赖 3.2.编写生产者代码 3.3.编写消费者代码 4.源码 1. RabbitMQ 核心概念 在安装…...

openmv与stm32通信
控制小车视觉循迹使用 OpenMV 往往是不够的。一般使用 OpenMV 对图像进行处理,将处理过后的数据使用串口发送给STM32,使用STM32控制小车行驶。本文主要讲解 OpenMV 模块与 STM32 间的串口通信以及两种循迹方案,分别是划分检测区域和线性回归。…...

C++ STL全面解析:六大核心组件之一----序列式容器(vector和List)(STL进阶学习)
目录 序列式容器 Vector vector概述 vector的迭代器 vector的数据结构 vector的构造和内存管理 vector的元素操作 List List概述 List的设计结构 List的迭代器 List的数据结构 List的内存构造 List的元素操作 C标准模板库(STL)是一组高效的…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
ssc377d修改flash分区大小
1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
镜像里切换为普通用户
如果你登录远程虚拟机默认就是 root 用户,但你不希望用 root 权限运行 ns-3(这是对的,ns3 工具会拒绝 root),你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案:创建非 roo…...
相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...

回溯算法学习
一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...