深度学习框架:Pytorch与Keras的区别与使用方法
☁️主页 Nowl
🔥专栏《机器学习实战》 《机器学习》
📑君子坐而论道,少年起而行之
文章目录
Pytorch与Keras介绍
Pytorch
模型定义
模型编译
模型训练
输入格式
完整代码
Keras
模型定义
模型编译
模型训练
输入格式
完整代码
区别与使用场景
结语
Pytorch与Keras介绍
pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧
Pytorch
模型定义
我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数
import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x
模型编译
我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些
import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率
模型训练
# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1)) # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1)) # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()
以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复
输入格式
关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢
data = torch.Tensor([[1], [2], [3]])
很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征
我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错
完整代码
import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1)) # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1)) # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)
可以看到模型输出了三个预测值
注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法
Keras
我们在这里把和上面相同的神经网络结构使用keras框架实现一遍
模型定义
from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])
注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的
模型编译
那么在Keras中模型又是怎么编译的呢
model.compile(loss='mse', optimizer='sgd')
非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降
模型训练
模型的训练也非常简单
# 训练模型
model.fit(input_data, target_data, epochs=100)
因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了
输入格式
对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错
data = np.array([[1], [2], [3]])
完整代码
from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1) # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1) # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)
可以看到,同样的任务,Keras的代码量小很多
区别与使用场景
Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计
而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解
结语
Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景
感谢阅读,觉得有用的话就订阅下本专栏吧
相关文章:

深度学习框架:Pytorch与Keras的区别与使用方法
☁️主页 Nowl 🔥专栏《机器学习实战》 《机器学习》 📑君子坐而论道,少年起而行之 文章目录 Pytorch与Keras介绍 Pytorch 模型定义 模型编译 模型训练 输入格式 完整代码 Keras 模型定义 模型编译 模型训练 输入格式 完整代…...

1145. 北极通讯网络(Kruskal,并查集维护)
北极的某区域共有 n 座村庄,每座村庄的坐标用一对整数 (x,y) 表示。 为了加强联系,决定在村庄之间建立通讯网络,使每两座村庄之间都可以直接或间接通讯。 通讯工具可以是无线电收发机,也可以是卫星设备。 无线电收发机有多种不…...

【23-24 秋学期】NNDL 作业9 RNN - SRN
简单循环网络(Simple Recurrent Network,SRN)只有一个隐藏层的神经网络. 目录 1. 实现SRN (1)使用Numpy (2)在1的基础上,增加激活函数tanh (3࿰…...

Docker + Jenkins + Nginx实现前端自动化部署
目录 前言一、前期准备工作1、示例环境2、安装docker3、安装Docker Compose4、安装Git5、安装Nginx和Jenkinsnginx.confdocker-compose.yml 6、启动环境7、验证Nginx8、验证Jenkins 二、Jenkins 自动化部署配置1、设置中文2、安装Publish Over SSH、NodeJS(1&#x…...

文生视频的发展史及其原理解析:从Gen2、Emu Video到PixelDance、SVD、Pika 1.0
前言 考虑到文生视频开始爆发,比如11月份就是文生视频最火爆的一个月 11月3日,Runway的Gen-2发布里程碑式更新,支持4K超逼真的清晰度作品(runway是Stable Diffusion最早版本的开发商,Stability AI则开发的SD后续版本)11月16日&a…...

【python+Excel】读取和存储测试数据完成接口自动化测试
http_request2.py用于发起http请求 #读取多条测试用例 #1、导入requests模块 import requests #从 class_12_19.do_excel1导入read_data函数 from do_excel2 import read_data from do_excel2 import write_data from do_excel2 import count_case #定义http请求函数COOKIENon…...

WordPress插件大全-免费的WordPress插件汇总
随着互联网的不断发展,网站建设变得日益普及。对于大多数人而言,WordPress是一个熟悉且易于使用的网站建设平台。然而,有时候我们可能会觉得WordPress的功能还不够满足我们的需求,这时候,插件就成了解决问题的得力工具…...

STM32通讯设计
STM32通讯设计 通讯流程STM32程序 通讯流程 1.使用HT2202芯片配置为主机接收(轮询模式)。 2.将STM32芯片配置为从机发送,中断模式下发送固定数据。 3.如果HT2202芯片能够收到STM32发送的数据则通讯成功,否则通讯失败。 STM32程序…...

外汇天眼:在QOINTEC投资需缴纳分成费才给出金?这合理么?
一般来说,在正规的平台上申请出金是不需要缴纳什么费用的,除非有一些特殊情况,像低额出金、没有交易就申请出金等情况下,或许会让你缴纳一定的手续费或者隔夜利息费等(不同的平台有不同的规则),…...

C_8练习题
一、单项选择题(本大题共20小题,每小题2分,共40分。在每小题给出的四个备选项中选出一个正确的答案,并将所选项前的字母填写在答题纸的相应位置上。) 1,在每个C语言程序中都必须包含有这样一个函数,该函数的函数名为() A. main B. MAIN C.name D. function 以下正确…...
HuggingFace学习笔记--Tokenizer的使用
1--AutoTokenizer的使用 官方文档 AutoTokenizer() 常用于分词,其可调用现成的模型来对输入句子进行分词。 1-1--简单Demo 测试代码: # 分词器测试Demo from transformers import AutoTokenizerif __name__ "__main__":checkpoint "…...

解决苹果手机iphone手机强制重启
强制关机: 方法1.同时按住左侧的,- 键中的一个和右侧的电源键 方法2.点击桌面的悬浮键–设备–更多–重新启动...

10分钟的时间,带你彻底搞懂JavaScript数据类型转换
前言 📫 大家好,我是南木元元,热衷分享有趣实用的文章,希望大家多多支持,一起进步! 🍅 个人主页:南木元元 目录 JS数据类型 3种转换类型 ToBoolean ToString ToNumber 对象转原…...

好用的chatgpt工具用过这个比较快
chatgpthttps://www.askchat.ai?r237422 chatGPT能做什么 1. 对话和聊天:我可以与您进行对话和聊天,回答您的问题、提供信息和建议。 2. 问题回答:无论是关于事实、历史、科学、文化、地理还是其他领域的问题,我都可以尽力回答…...

系统设计概念:生产 Web 应用的架构
在你使用的每个完美应用程序背后,都有一整套的架构、测试、监控和安全措施。今天,让我们来看看一个生产就绪应用程序的非常高层次的架构。 CI/CD 管道 我们的第一个关键领域是持续集成和持续部署——CI/CD 管道。 这确保我们的代码从存储库经过一系列测试…...

基于docker的onlyoffice使用--运行JavaSpringExample
背景 我之前看到有开源项目很好地集成了onlyoffice,效果要比kkfilepreview好(应当说应用场景不太一样)。本文是在window10环境,安装完Docker Desktop的基础上运行onlyoffice,并利用官网JavaSpringExample进行了集成。 …...

SQL server-excel数据追加到表
参考文章:SQL server 2019 从Excel导入数据_mssql2019 导入excel数据-CSDN博客 将excel数据导入到SQL server数据库的详细过程 注意:第一行数据默认为数据库表中的字段,所以这个必须要有,否则无法映射导入 问题1:ADD…...

深度学习-模型调试经验总结
1、 这句话的意思是:期望张量的后端处理是在cpu上,但是实际是在cuda上。排查代码发现,数据还在cpu上,但是模型已经转到cuda上,所以可以通过把数据转到cuda上解决。 解决代码: tensor.to("cuda")…...

Redis打包事务,分批提交
一、需求背景 接手一个老项目,在项目启动的时候,需要将xxx省整个省的所有区域数据数据、以及系统字典配置逐条保存在Redis缓存里面,这样查询的时候会更快; 区域数据字典数据一共大概20000多条,,前同事直接使用 list.forEach…...

深度学习毕设项目 深度学习 python opencv 动物识别与检测
文章目录 0 前言1 深度学习实现动物识别与检测2 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 3 YOLOV53.1 网络架构图3.2 输入端3.3 基准网络3.4 Neck网络3.5 Head输出层 4 数据集准备4.1 数据标注简介4.2 数据保存…...

黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 
江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命
在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

使用 SymPy 进行向量和矩阵的高级操作
在科学计算和工程领域,向量和矩阵操作是解决问题的核心技能之一。Python 的 SymPy 库提供了强大的符号计算功能,能够高效地处理向量和矩阵的各种操作。本文将深入探讨如何使用 SymPy 进行向量和矩阵的创建、合并以及维度拓展等操作,并通过具体…...

如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...

AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...

C++ 设计模式 《小明的奶茶加料风波》
👨🎓 模式名称:装饰器模式(Decorator Pattern) 👦 小明最近上线了校园奶茶配送功能,业务火爆,大家都在加料: 有的同学要加波霸 🟤,有的要加椰果…...

springboot 日志类切面,接口成功记录日志,失败不记录
springboot 日志类切面,接口成功记录日志,失败不记录 自定义一个注解方法 import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;/***…...

《Docker》架构
文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器,docker,镜像,k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...