深度学习框架:Pytorch与Keras的区别与使用方法

☁️主页 Nowl
🔥专栏《机器学习实战》 《机器学习》
📑君子坐而论道,少年起而行之

文章目录
Pytorch与Keras介绍
Pytorch
模型定义
模型编译
模型训练
输入格式
完整代码
Keras
模型定义
模型编译
模型训练
输入格式
完整代码
区别与使用场景
结语
Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧
Pytorch

模型定义
我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数
import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x
模型编译
我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些
import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率
模型训练
# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1)) # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1)) # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()
以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复
输入格式
关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢
data = torch.Tensor([[1], [2], [3]])
很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征
我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错
完整代码
import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1)) # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1)) # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)

可以看到模型输出了三个预测值
注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法
Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍
模型定义
from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])
注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的
模型编译
那么在Keras中模型又是怎么编译的呢
model.compile(loss='mse', optimizer='sgd')
非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降
模型训练
模型的训练也非常简单
# 训练模型
model.fit(input_data, target_data, epochs=100)
因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了
输入格式
对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错
data = np.array([[1], [2], [3]])
完整代码
from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1) # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1) # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)
可以看到,同样的任务,Keras的代码量小很多
区别与使用场景
Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计
而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解
结语
Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

感谢阅读,觉得有用的话就订阅下本专栏吧
相关文章:
深度学习框架:Pytorch与Keras的区别与使用方法
☁️主页 Nowl 🔥专栏《机器学习实战》 《机器学习》 📑君子坐而论道,少年起而行之 文章目录 Pytorch与Keras介绍 Pytorch 模型定义 模型编译 模型训练 输入格式 完整代码 Keras 模型定义 模型编译 模型训练 输入格式 完整代…...
1145. 北极通讯网络(Kruskal,并查集维护)
北极的某区域共有 n 座村庄,每座村庄的坐标用一对整数 (x,y) 表示。 为了加强联系,决定在村庄之间建立通讯网络,使每两座村庄之间都可以直接或间接通讯。 通讯工具可以是无线电收发机,也可以是卫星设备。 无线电收发机有多种不…...
【23-24 秋学期】NNDL 作业9 RNN - SRN
简单循环网络(Simple Recurrent Network,SRN)只有一个隐藏层的神经网络. 目录 1. 实现SRN (1)使用Numpy (2)在1的基础上,增加激活函数tanh (3࿰…...
Docker + Jenkins + Nginx实现前端自动化部署
目录 前言一、前期准备工作1、示例环境2、安装docker3、安装Docker Compose4、安装Git5、安装Nginx和Jenkinsnginx.confdocker-compose.yml 6、启动环境7、验证Nginx8、验证Jenkins 二、Jenkins 自动化部署配置1、设置中文2、安装Publish Over SSH、NodeJS(1&#x…...
文生视频的发展史及其原理解析:从Gen2、Emu Video到PixelDance、SVD、Pika 1.0
前言 考虑到文生视频开始爆发,比如11月份就是文生视频最火爆的一个月 11月3日,Runway的Gen-2发布里程碑式更新,支持4K超逼真的清晰度作品(runway是Stable Diffusion最早版本的开发商,Stability AI则开发的SD后续版本)11月16日&a…...
【python+Excel】读取和存储测试数据完成接口自动化测试
http_request2.py用于发起http请求 #读取多条测试用例 #1、导入requests模块 import requests #从 class_12_19.do_excel1导入read_data函数 from do_excel2 import read_data from do_excel2 import write_data from do_excel2 import count_case #定义http请求函数COOKIENon…...
WordPress插件大全-免费的WordPress插件汇总
随着互联网的不断发展,网站建设变得日益普及。对于大多数人而言,WordPress是一个熟悉且易于使用的网站建设平台。然而,有时候我们可能会觉得WordPress的功能还不够满足我们的需求,这时候,插件就成了解决问题的得力工具…...
STM32通讯设计
STM32通讯设计 通讯流程STM32程序 通讯流程 1.使用HT2202芯片配置为主机接收(轮询模式)。 2.将STM32芯片配置为从机发送,中断模式下发送固定数据。 3.如果HT2202芯片能够收到STM32发送的数据则通讯成功,否则通讯失败。 STM32程序…...
外汇天眼:在QOINTEC投资需缴纳分成费才给出金?这合理么?
一般来说,在正规的平台上申请出金是不需要缴纳什么费用的,除非有一些特殊情况,像低额出金、没有交易就申请出金等情况下,或许会让你缴纳一定的手续费或者隔夜利息费等(不同的平台有不同的规则),…...
C_8练习题
一、单项选择题(本大题共20小题,每小题2分,共40分。在每小题给出的四个备选项中选出一个正确的答案,并将所选项前的字母填写在答题纸的相应位置上。) 1,在每个C语言程序中都必须包含有这样一个函数,该函数的函数名为() A. main B. MAIN C.name D. function 以下正确…...
HuggingFace学习笔记--Tokenizer的使用
1--AutoTokenizer的使用 官方文档 AutoTokenizer() 常用于分词,其可调用现成的模型来对输入句子进行分词。 1-1--简单Demo 测试代码: # 分词器测试Demo from transformers import AutoTokenizerif __name__ "__main__":checkpoint "…...
解决苹果手机iphone手机强制重启
强制关机: 方法1.同时按住左侧的,- 键中的一个和右侧的电源键 方法2.点击桌面的悬浮键–设备–更多–重新启动...
10分钟的时间,带你彻底搞懂JavaScript数据类型转换
前言 📫 大家好,我是南木元元,热衷分享有趣实用的文章,希望大家多多支持,一起进步! 🍅 个人主页:南木元元 目录 JS数据类型 3种转换类型 ToBoolean ToString ToNumber 对象转原…...
好用的chatgpt工具用过这个比较快
chatgpthttps://www.askchat.ai?r237422 chatGPT能做什么 1. 对话和聊天:我可以与您进行对话和聊天,回答您的问题、提供信息和建议。 2. 问题回答:无论是关于事实、历史、科学、文化、地理还是其他领域的问题,我都可以尽力回答…...
系统设计概念:生产 Web 应用的架构
在你使用的每个完美应用程序背后,都有一整套的架构、测试、监控和安全措施。今天,让我们来看看一个生产就绪应用程序的非常高层次的架构。 CI/CD 管道 我们的第一个关键领域是持续集成和持续部署——CI/CD 管道。 这确保我们的代码从存储库经过一系列测试…...
基于docker的onlyoffice使用--运行JavaSpringExample
背景 我之前看到有开源项目很好地集成了onlyoffice,效果要比kkfilepreview好(应当说应用场景不太一样)。本文是在window10环境,安装完Docker Desktop的基础上运行onlyoffice,并利用官网JavaSpringExample进行了集成。 …...
SQL server-excel数据追加到表
参考文章:SQL server 2019 从Excel导入数据_mssql2019 导入excel数据-CSDN博客 将excel数据导入到SQL server数据库的详细过程 注意:第一行数据默认为数据库表中的字段,所以这个必须要有,否则无法映射导入 问题1:ADD…...
深度学习-模型调试经验总结
1、 这句话的意思是:期望张量的后端处理是在cpu上,但是实际是在cuda上。排查代码发现,数据还在cpu上,但是模型已经转到cuda上,所以可以通过把数据转到cuda上解决。 解决代码: tensor.to("cuda")…...
Redis打包事务,分批提交
一、需求背景 接手一个老项目,在项目启动的时候,需要将xxx省整个省的所有区域数据数据、以及系统字典配置逐条保存在Redis缓存里面,这样查询的时候会更快; 区域数据字典数据一共大概20000多条,,前同事直接使用 list.forEach…...
深度学习毕设项目 深度学习 python opencv 动物识别与检测
文章目录 0 前言1 深度学习实现动物识别与检测2 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 3 YOLOV53.1 网络架构图3.2 输入端3.3 基准网络3.4 Neck网络3.5 Head输出层 4 数据集准备4.1 数据标注简介4.2 数据保存…...
Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...
从WWDC看苹果产品发展的规律
WWDC 是苹果公司一年一度面向全球开发者的盛会,其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具,对过去十年 WWDC 主题演讲内容进行了系统化分析,形成了这份…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
Spring Boot面试题精选汇总
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...
k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...
LabVIEW双光子成像系统技术
双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...
在 Spring Boot 中使用 JSP
jsp? 好多年没用了。重新整一下 还费了点时间,记录一下。 项目结构: pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...
