当前位置: 首页 > news >正文

3.PyTorch——常用神经网络层

import numpy as np
import pandas as pd
import torch as t
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImaget.__version__
'2.1.1'

3.1 图像相关层

图像相关层主要包括卷积层(Conv)、池化层(Pool)等,这些层在实际使用中可分为一维(1D)、二维(2D)、三维(3D),池化方式又分为平均池化(AvgPool)、最大值池化(MaxPool)、自适应池化(AdaptiveAvgPool)等。而卷积层除了常用的前向卷积之外,还有逆卷积(TransposeConv)。

除了这里的使用,图像的卷积操作还有各种变体,具体可以参照此处动图[^2]介绍。 [^2]: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

to_tensor = ToTensor()
to_pil = ToPILImage()
lena = Image.open('imgs/lena.png')
lena

在这里插入图片描述

# layer对输入形状都有假设:输入的不是单个数据,而是一个batch。
# 这里输入一个数据,就必须调用tensor.unsqueeze(0)增加一个维度,伪装成batch_size=1的batch
input = to_tensor(lena).unsqueeze(0)# 锐化卷积核
kernel = t.ones(3, 3) / -9
kernel[1][1] = 1
conv = t.nn.Conv2d(1, 1, (3, 3), 1, bias=False)
conv.weight.data = kernel.view(1, 1, 3, 3)out = conv(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

池化层:可视为一种特殊的卷积层,用来下采样。注意池化层是没有可学习参数的,其weight是固定的。

pool = t.nn.AvgPool2d(2, 2)
out = pool(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

除了卷积层和池化层,深度学习中还将常用到以下几个层:

  • Linear:全连接层。
  • BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。
  • Dropout:dropout层,用来防止过拟合,同样分为1D、2D和3D。
    下面通过例子来说明它们的使用。
# 输入batch_size=2, 维度3
input = t.rand(2,3)
linear = t.nn.Linear(3, 4)
h = linear(input)
h
tensor([[-0.2314, -0.2245,  0.0966,  0.7610],[-0.2679, -0.2403,  0.0086,  0.5799]], grad_fn=<AddmmBackward0>)
# 4 channel,初始化标准差4,均值0
bn = t.nn.BatchNorm1d(4)
bn.weight.data = t.ones(4) * 4
bn.bias.data = t.zeros(4)bn_out = bn(h)
print(bn_out)
bn_out.mean(), bn_out.var(0, unbiased=False)       # 由于计算无偏方差分母会减1, 使用unbiased=1分母不减一
tensor([[ 3.9415,  3.7136,  3.9897,  3.9976],[-3.9415, -3.7136, -3.9897, -3.9976]],grad_fn=<NativeBatchNormBackward0>)(tensor(-1.8775e-06, grad_fn=<MeanBackward0>),tensor([15.5355, 13.7908, 15.9179, 15.9805], grad_fn=<VarBackward0>))
# 每个元素以0.5的概率舍弃
dropout = t.nn.Dropout(0.5)
o = dropout(bn_out)
o           
tensor([[ 7.8830,  7.4272,  0.0000,  7.9951],[-7.8830, -7.4272, -7.9794, -7.9951]], grad_fn=<MulBackward0>)

3.2 激活函数

PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档1,这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:
R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0,x) ReLU(x)=max(0,x)

ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如tensor.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。

relu = t.nn.ReLU(inplace=True)
input = t.randn(2, 3)
print(input)
output = relu(input)
print(output)        # 负数都被截断为0
tensor([[-0.4064, -0.1886,  0.4812],[ 0.8996, -0.3606,  0.6127]])
tensor([[0.0000, 0.0000, 0.4812],[0.8996, 0.0000, 0.6127]])

对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

# Sequential
net = t.nn.Sequential(t.nn.Conv2d(3, 3, 3),t.nn.BatchNorm2d(3),t.nn.ReLU()
)
print('net:', net)
net: Sequential((0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU()
)
# 可根据名字或序号取出module
net[2]
ReLU()
input = t.randn(1, 3, 4, 4)
output = net(input)
output
tensor([[[[1.2239, 0.0000],[0.0000, 0.6354]],[[0.1855, 0.0000],[0.7218, 0.7777]],[[1.3686, 0.0000],[0.4861, 0.0000]]]], grad_fn=<ReluBackward0>)
# modellist
modellist = t.nn.ModuleList([t.nn.Linear(3, 4), t.nn.ReLU(), t.nn.Linear(4, 2)])
input = t.randn(1, 3)for model in modellist:input = model(input)print(input)
tensor([[-0.1817,  0.3852,  1.3656, -0.5643]], grad_fn=<AddmmBackward0>)
tensor([[0.0000, 0.3852, 1.3656, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[-0.0151, -0.0309]], grad_fn=<AddmmBackward0>)

3.3 RNN循环神经网络

关于RNN的基础知识,推荐阅读colah的文章2入门。PyTorch中实现了如今最常用的三种RNN:RNN(vanilla RNN)、LSTM和GRU。此外还有对应的三种RNNCell。

RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。实际上RNN层的一种后端实现方式就是调用RNNCell来实现的。

t.manual_seed(1000)
# 输入:batch_size=3, 序列长度为2,序列中每个元素占4维
input = t.randn(2, 3, 4)
# lstm输入向量4维,隐藏元3. 1层
lstm = t.nn.LSTM(4, 3, 1)
# 初始状态:1层,batch_size=3, 3个隐藏元
h0 = t.randn(1, 3, 3)
c0 = t.randn(1, 3, 3)
out, hn = lstm(input, (h0, c0))
out
tensor([[[-0.3610, -0.1643,  0.1631],[-0.0613, -0.4937, -0.1642],[ 0.5080, -0.4175,  0.2502]],[[-0.0703, -0.0393, -0.0429],[ 0.2085, -0.3005, -0.2686],[ 0.1482, -0.4728,  0.1425]]], grad_fn=<MkldnnRnnLayerBackward0>)
t.manual_seed(1000)
input = t.randn(2, 3, 4)
# 一个LSTMCell对应的层数只能是一层
lstm = t.nn.LSTMCell(4, 3)
hx = t.randn(3, 3)
cx = t.randn(3, 3)
out = []
for i_ in input:hx, cx=lstm(i_, (hx, cx))out.append(hx)
t.stack(out)
tensor([[[-0.3610, -0.1643,  0.1631],[-0.0613, -0.4937, -0.1642],[ 0.5080, -0.4175,  0.2502]],[[-0.0703, -0.0393, -0.0429],[ 0.2085, -0.3005, -0.2686],[ 0.1482, -0.4728,  0.1425]]], grad_fn=<StackBackward0>)

3.4 损失函数

损失函数可看作是一种特殊的layer,PyTorch也将这些损失函数实现为nn.Module的子类。然而在实际使用中通常将这些loss function专门提取出来,和主模型互相独立。详细的loss使用请参照文档3,这里以分类中最常用的交叉熵损失CrossEntropyloss为例说明。

# batch_size = 3, 计算对应每个类别的分数(只有两个类别)
score = t.randn(3, 2)
# 三个样本分别属于1, 0, 1类,label必须是LongTensor
label = t.Tensor([1, 0, 1]).long()# loss与普通的layer无差异
criterion = t.nn.CrossEntropyLoss()
loss = criterion(score, label)
loss
tensor(1.8772)

  1. http://pytorch.org/docs/nn.html#non-linear-activations ↩︎

  2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ↩︎

  3. http://pytorch.org/docs/nn.html#loss-functions ↩︎

相关文章:

3.PyTorch——常用神经网络层

import numpy as np import pandas as pd import torch as t from PIL import Image from torchvision.transforms import ToTensor, ToPILImaget.__version__2.1.13.1 图像相关层 图像相关层主要包括卷积层&#xff08;Conv&#xff09;、池化层&#xff08;Pool&#xff09;…...

状态机的练习:按键控制led灯

设计思路&#xff1a; 三个按键控制led输出。 三个按键经过滤波(消抖)&#xff0c;产生三个按键标志信号。 三个led数据的产生模块&#xff08;流水&#xff0c;跑马&#xff0c;闪烁模块&#xff09;&#xff0c;分别产生led信号。 这六路信号&#xff08;三路按键信号&am…...

看图学源码之 CopyOnWriteArraySet源码分析

基本介绍 使用内部CopyOnWriteArrayList进行所有操作的Set 特点 它最适合以下应用程序&#xff1a;集合大小通常较小、只读操作的数量远远多于可变操作&#xff0c;并且您需要在遍历期间防止线程之间的干扰。它是线程安全的。突变操作&#xff08; add 、 set 、 remove等&…...

almaLinux centos8 下载ffmpeg离线安装包、离线安装

脚本 # 添加RPMfusion仓库 sudo yum install https://download1.rpmfusion.org/free/el/rpmfusion-free-release-8.noarch.rpm wget -ymkdir -p /root/ffmpeg cd /root/ffmpegwget http://rpmfind.net/linux/epel/7/x86_64/Packages/s/SDL2-2.0.14-2.el7.x86_64.rpmyum instal…...

CSS3 属性: transition过渡 与 transform动画

CSS3 提供了很多强大的功能&#xff0c;使开发人员可以创建更加吸引人的视觉效果&#xff0c;而不需要依赖于 JavaScript 或 Flash。其中&#xff0c;transition 和 transform 是两个常用的属性&#xff0c;它们分别用于创建平滑的过渡效果和元素的变形效果。下面我们将详细介绍…...

TCP通讯

第二十一章 网络通信 本章节主要讲解的是TCP和UDP两种通信方式它们都有着自己的优点和缺点 这两种通讯方式不通的地方就是TCP是一对一通信 UDP是一对多的通信方式 接下来会一一讲解 TCP通信 TCP通信方式呢 主要的通讯方式是一对一的通讯方式&#xff0c;也有着优点和缺点 …...

(NeRF学习)3D Gaussian Splatting Instant-NGP

学习参考&#xff1a; 3D Gaussian Splatting入门指南【五分钟学会渲染自己的NeRF模型&#xff0c;有手就行&#xff01;】 三维重建instant-ngp环境部署与colmap、ffmpeg的脚本参数使用 一、3D Gaussian Splatting &#xff08;一&#xff09;3D Gaussian Splatting环境配置…...

uni-app 微信小程序之好看的ui登录页面(三)

文章目录 1. 页面效果2. 页面样式代码 更多登录ui页面 uni-app 微信小程序之好看的ui登录页面&#xff08;一&#xff09; uni-app 微信小程序之好看的ui登录页面&#xff08;二&#xff09; uni-app 微信小程序之好看的ui登录页面&#xff08;三&#xff09; uni-app 微信小程…...

Android 默认打开应用的权限

有项目需要客户要安装第三方软件&#xff0c;但是要手动点击打开权限&#xff0c;就想不动手就打开。 //安装第三方软件&#xff0c;修改方式 frameworks\base\services\core\java\com\android\server\pm\PackageManagerService.java //找到如下源码&#xff1a; //有三种方…...

2023年广东工业大学腾讯杯新生程序设计竞赛

E.不知道叫什么名字 题意&#xff1a;找一段连续的区间&#xff0c;使得区间和为0且区间长度最大&#xff0c;输出区间长度。 思路&#xff1a;考虑前缀和&#xff0c;然后使用map去记录每个前缀和第一次出现的位置&#xff0c;然后对数组进行扫描即可。原理&#xff1a;若 s …...

FFmpeg开发笔记(六)如何访问Github下载FFmpeg源码

学习FFmpeg的时候&#xff0c;经常要到GitHub下载各种开源代码&#xff0c;比如FFmpeg的源码页面位于https://github.com/FFmpeg/FFmpeg。然而国内访问GitHub很不稳定&#xff0c;经常打不开该网站&#xff0c;比如在命令行执行下面的ping命令。 ping github.com 上面的ping结…...

SpringCloud | Dubbo 微服务实战——注册中心详解

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 |Eureka,Nacos,Consul,Zookeeper在Spring Cloud和Dubbo中实战 引言 在项目开发过程中&#xff0c;随着项目不断扩大&#xff0c;也就是业务的不断增多&#xff0c;我们将采用集群&#xf…...

PostGIS学习教程十一:投影数据

PostGIS学习教程十一&#xff1a;投影数据 地球不是平的&#xff0c;也没有简单的方法把它放在一张平面纸地图上&#xff08;或电脑屏幕上&#xff09;&#xff0c;所以人们想出了各种巧妙的解决方案&#xff08;投影&#xff09;。 每种投影方案都有优点和缺点&#xff0c;一…...

jQuery ajax读取本地json文件 三级联动下拉框

步骤 1&#xff1a;创建本地JSON文件 {"departments": [{"name": "会计学院","code": "052"},{"name": "金融学院","code": "053"},{"name": "财税学院",&qu…...

Kubernetes(K8s 1.27.x) 快速上手+实践,无废话纯享版(视频笔记)

视频源&#xff1a;1.03-k8s是什么&#xff1f;_哔哩哔哩_bilibili 1 基础知识 1.1 K8s 有用么&#xff1f; K8s有没有用 K8s要不要学&#xff1f; 参考资料: https://www.infoq.com/articles/devops-and-cloud-trends-2022/?itm_sourcearticles_about_InfoQ-trends-report…...

深度学习实战66-基于计算机视觉的自动驾驶技术,利用YOLOP模型实现车辆区域检测框、可行驶区域和车道线分割图

大家好,我是微学AI,今天给大家介绍一下深度学习实战66-基于计算机视觉的自动驾驶技术,利用YOLOP模型实现车辆区域检测框、可行驶区域和车道线分割图。本文我将介绍自动驾驶技术及其应用场景,并重点阐述了基于计算机视觉技术下的自动驾驶。自动驾驶技术是一种利用人工智能和…...

Stable Diffusion 系列教程 - 1 基础准备(针对新手)

使用SD有两种方式&#xff1a; 本地&#xff1a; 显卡要求&#xff1a;硬件环境推荐NVIDIA的具有8G显存的独立显卡&#xff0c;这个显存勉勉强强能摸到门槛。再往下的4G可能面临各种炸显存、炼丹失败、无法生成图片等各种问题。对于8G显存&#xff0c;1.0模型就不行&#xff0…...

听GPT 讲Rust源代码--src/tools(8)

File: rust/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs 在Rust源代码中&#xff0c;rust-analyzer是一个Rust编程语言的语言服务器。它提供了代码补全、代码重构和代码导航等功能来帮助开发者提高编码效率。 在rust-analyzer的代码目…...

Linux硬链接和软连接是什么?

在Linux操作系统中&#xff0c;文件管理是一个基本且重要的概念。其中&#xff0c;软链接&#xff08;Symbolic Link&#xff09;和硬链接&#xff08;Hard Link&#xff09;是文件系统中两种不同类型的链接方式&#xff0c;它们在文件管理和操作中扮演着重要的角色。软链接 软…...

LangChain 23 Agents中的Tools用于增强和扩展智能代理agent的功能

LangChain系列文章 LangChain 实现给动物取名字&#xff0c;LangChain 2模块化prompt template并用streamlit生成网站 实现给动物取名字LangChain 3使用Agent访问Wikipedia和llm-math计算狗的平均年龄LangChain 4用向量数据库Faiss存储&#xff0c;读取YouTube的视频文本搜索I…...

C++笔记 缺省值 函数重载 名字空间域(基础核心)

本文为C基础核心知识点笔记&#xff0c;聚焦「缺省值」「函数重载&#xff08;概念&#xff09;」「名字空间域」三大高频基础考点&#xff0c;语言通俗、重点突出&#xff0c;兼顾入门理解和考试记忆&#xff0c;适合新手入门、作业复习及GitHub归档。一、缺省值&#xff08;默…...

OpenClaw常见错误排查:nanobot连接问题解决方案

OpenClaw常见错误排查&#xff1a;nanobot连接问题解决方案 1. 问题背景与排查思路 上周我在本地部署OpenClaw对接nanobot镜像时&#xff0c;遇到了几个典型的连接问题。作为一个开源自动化框架&#xff0c;OpenClaw在实际使用中经常会遇到各种"水土不服"的情况。特…...

Flutter控制麦克风的方法

Flutter本身不直接提供麦克风控制的原生API&#xff0c;需借助第三方插件实现&#xff0c;核心围绕「权限申请」「麦克风开启/关闭」「音频采样/录音」「资源释放」四大场景。以下是最常用、兼容性最强的实现方案&#xff0c;覆盖多平台适配&#xff0c;附完整代码示例。 一、核…...

Mac视频预览增强工具:解决MKV文件无法预览问题的全方位方案

Mac视频预览增强工具&#xff1a;解决MKV文件无法预览问题的全方位方案 【免费下载链接】QuickLookVideo This package allows macOS Finder to display thumbnails, static QuickLook previews, cover art and metadata for most types of video files. 项目地址: https://g…...

springboot-vue+nodejs 的酒店客房预定管理系统的设计与实现

目录技术栈选择系统模块划分后端实现前端实现中间层实现数据库设计支付集成测试与部署项目技术支持源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作技术栈选择 Spring Boot 作为后端框架&#xff0c;提供 RESTful API 接口&#xff1b;Vue.…...

标签噪声鲁棒训练:从理论到实践,构建深度学习模型的抗噪防线

1. 标签噪声&#xff1a;深度学习中的隐形杀手 第一次用MNIST数据集跑分类模型时&#xff0c;我发现哪怕故意把20%的标签打乱&#xff0c;模型在测试集上依然能达到85%以上的准确率。这个结果让我误以为深度神经网络对标签噪声天然具有免疫力——直到后来在医疗影像分类项目里…...

WebPlotDigitizer实战指南:从科研图表中智能提取数据的完整方案

WebPlotDigitizer实战指南&#xff1a;从科研图表中智能提取数据的完整方案 【免费下载链接】WebPlotDigitizer WebPlotDigitizer: 一个基于 Web 的工具&#xff0c;用于从图形图像中提取数值数据&#xff0c;支持 XY、极地、三角图和地图。 项目地址: https://gitcode.com/g…...

小白也能懂的EmbeddingGemma-300m:用Ollama一键部署嵌入模型

小白也能懂的EmbeddingGemma-300m&#xff1a;用Ollama一键部署嵌入模型 1. 什么是EmbeddingGemma-300m&#xff1f; EmbeddingGemma-300m是谷歌推出的开源文本嵌入模型&#xff0c;它能够将任何文本转换为300维的数字向量。这些向量有一个神奇的特性&#xff1a;语义相似的文…...

日本留学中介避坑指南:免费申请与实体保障,哪种模式更适合你?

摘要随着赴日留学热度持续攀升&#xff0c;市面上的日本留学中介机构也如雨后春笋般涌现。对于计划通过语言学校过渡并升学的学生及家庭而言&#xff0c;如何在‘免费申请’与‘传统收费’、‘线上服务’与‘实体保障’之间做出抉择&#xff0c;往往充满困惑与信息不对称。本文…...

5步打造高效音乐体验:Listen1扩展的智能选择与效率提升指南

5步打造高效音乐体验&#xff1a;Listen1扩展的智能选择与效率提升指南 【免费下载链接】listen1_chrome_extension one for all free music in china (chrome extension, also works for firefox) 项目地址: https://gitcode.com/gh_mirrors/li/listen1_chrome_extension …...