基于 ResNet18 架构使用 deformable convolution的车道线检测
下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。
首先,需要导入必要的库和模块:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameterfrom torchvision.models.resnet import resnet18
然后,定义一个基于 ResNet18 架构的车道线检测网络模型:
class LaneDetectionNet(nn.Module):def __init__(self, num_classes=1, deformable_groups=2):super(LaneDetectionNet, self).__init__()# load ResNet18self.resnet = resnet18(pretrained=True)# replace the first conv layerself.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)# add deformable convolutionsself.resnet.layer1[0].conv1 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer1[0].conv2 = DeformConv2d(64, 64, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv1 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer2[0].conv2 = DeformConv2d(128, 128, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv1 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer3[0].conv2 = DeformConv2d(256, 256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv1 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)self.resnet.layer4[0].conv2 = DeformConv2d(512, 512, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False, deformable_groups=deformable_groups)# add the output layersself.fc1 = nn.Linear(512, 512)self.fc2 = nn.Linear(512, num_classes)def forward(self, x):x = self.resnet(x)x = F.relu(self.fc1(x))x = self.fc2(x)return x
其中,DeformConv2d 是一个 deformable convolution 的实现类。其代码如下:
class DeformConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, deformable_groups=1):super(DeformConv2d, self).__init__()self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)self.weight = Parameter(torch.Tensor(out_channels, in_channels, kernel_size[0], kernel_size[1]))if bias:self.bias = Parameter(torch.Tensor(out_channels))else:self.register_parameter('bias', None)self.reset_parameters()self.stride = _pair(stride)self.padding = _pair(padding)self.dilation = _pair(dilation)self.groups = groupsself.deformable_groups = deformable_groupsdef reset_parameters(self):nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)nn.init.uniform_(self.bias, -bound, bound)def forward(self, x):offset = self.offset_conv(x)output = deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups)return output
最后,定义一个 deformable convolution 的实现函数 deform_conv2d,代码如下:
def deform_conv2d(input, offset, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1):# get shapes and parametersbatch_size, in_channels, in_h, in_w = input.size()out_channels, _, kernel_h, kernel_w = weight.size()stride_h, stride_w = _pair(stride)pad_h, pad_w = _pair(padding)dilation_h, dilation_w = _pair(dilation)input_padded = F.pad(input, (pad_w, pad_w, pad_h, pad_h))# calculate output shapeout_h = (in_h + 2*pad_h - dilation_h*(kernel_h-1) - 1) // stride_h + 1out_w = (in_w + 2*pad_w - dilation_w*(kernel_w-1) - 1) // stride_w + 1# unfold input and offsetoffset = offset.view(batch_size, deformable_groups, 2 * kernel_h * kernel_w, out_h, out_w)input_unfolded = F.unfold(input_padded, (kernel_h, kernel_w), dilation=dilation, stride=stride)# calculate outputoutput = torch.zeros(batch_size, out_channels, out_h, out_w).to(input.device)weight = weight.view(1, out_channels, in_channels // groups, kernel_h, kernel_w).repeat(batch_size, 1, 1, 1, 1)for h in range(out_h):for w in range(out_w):input_region = input_unfolded[:, :, h, w].view(batch_size, -1, 1, 1)offset_region = offset[:, :, :, h, w]weight_region = weightoutput_region = F.conv2d(input_region, weight_region, bias=None, stride=1, padding=0, dilation=1, groups=deformable_groups)output_region = deformable_conv2d_compute(output_region, offset_region)output[:, :, h, w] = output_region.squeeze()if bias is not None:output += bias.view(1, -1, 1, 1)return output
其中,deformable_conv2d_compute 函数是 deformable convolution 的计算函数。它的代码如下:
def deformable_conv2d_compute(input, offset):# get shapes and parametersbatch_size, out_channels, out_h, out_w = input.size()in_channels = offset.size(1) // 2# sample input according to offsetgrid_h = torch.linspace(-1, 1, out_h).view(1, 1, out_h, 1).to(input.device)grid_w = torch.linspace(-1, 1, out_w).view(1, 1, 1, out_w).to(input.device)offset_h = offset[:, :in_channels, :, :]offset_w = offset[:, in_channels:, :, :]sample_h = torch.add(grid_h, offset_h)sample_w = torch.add(grid_w, offset_w)sample_h = sample_h.clamp(-1, 1)sample_w = sample_w.clamp(-1, 1)sample_h = ((sample_h + 1) / 2) * (out_h - 1)sample_w = ((sample_w + 1) / 2) * (out_w - 1)sample_h_floor = sample_h.floor().long()sample_w_floor = sample_w.floor().long()sample_h_ceil = sample_h.ceil().long()sample_w_ceil = sample_w.ceil().long()sample_h_floor = sample_h_floor.clamp(0, out_h - 1)sample_w_floor = sample_w_floor.clamp(0, out_w - 1)sample_h_ceil = sample_h_ceil.clamp(0, out_h - 1)sample_w_ceil = sample_w_ceil.clamp(0, out_w - 1)# gather input values according to sampled indicesinput_flat = input.view(batch_size, in_channels, out_h * out_w)index_base = torch.arange(0, batch_size, device=input.device).view(batch_size, 1, 1) * out_h * out_windex_base = index_base.expand(batch_size, in_channels, out_h * out_w)index_offset = torch.arange(0, out_h * out_w, device=input.device).view(1, 1, -1)index_offset = index_offset.expand(batch_size, in_channels, out_h * out_w)indices_a = (sample_h_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_b = (sample_w_floor + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_c = (sample_h_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)indices_d = (sample_w_ceil + index_base + index_offset).view(batch_size, in_channels * out_h * out_w)value_a = input_flat.gather(2, indices_a.unsqueeze(1).repeat(1, out_channels, 1))value_b = input_flat.gather(2, indices_b.unsqueeze(1).repeat(1, out_channels, 1))value_c = input_flat.gather(2, indices_c.unsqueeze(1).repeat(1, out_channels, 1))value_d = input_flat.gather(2, indices_d.unsqueeze(1).repeat(1, out_channels, 1))# calculate interpolation weights and outputw_a = ((sample_w_ceil - sample_w) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_b = ((sample_w - sample_w_floor) * (sample_h_ceil - sample_h)).view(batch_size, 1, out_h, out_w)w_c = ((sample_w_ceil - sample_w) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)w_d = ((sample_w - sample_w_floor) * (sample_h - sample_h_floor)).view(batch_size, 1, out_h, out_w)output = w_a * value_a + w_b * value_b + w_c * value_c + w_d * value_dreturn output
最后,可以使用以下代码进行网络的测试:
net = LaneDetectionNet(num_classes=1, deformable_groups=2) # create the network
input = torch.randn(1, 3, 100, 100) # create a random input tensor
output = net(input) # feed it through the network
print(output.shape) # print the output shape
输出的结果应该为 (1, 1, 1, 1)。这说明网络已经成功地将 100*100 的像素图压缩成了一个标量。可以根据实际情况进行调整和优化,来达到更好的性能。
相关文章:
基于 ResNet18 架构使用 deformable convolution的车道线检测
下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。 首先,需要导入必要的库和模块: import torch import torch.nn as nn…...

C++in/out输入输出流[IO流]
文章目录 1. C语言的输入与输出2.C的IO流2.1流的概念2.2CIO流2.3刷题常见while(cin >> str)重载强制类型转换运算符模拟while(cin >> str) 2.4C标准IO流2.5C文件IO流1.ifstream 1. C语言的输入与输出 C语言用到最频繁的输入输出方式就是scanf ()与printf()。 scanf…...

MongoDB的安装
MongoDB的安装 1、Windows下MongoDB的安装及配置 1.1 下载Mongodb安装包 下载地址: https://www.mongodb.com/try/download http://www.mongodb.org/dl/win32 MongoDB Windows系统64位下载地址:http://www.mongodb.org/dl/win32/x86_64 MongoDB W…...

SQL查询优化---如何查询截取分析
慢查询日志 1、慢查询日志是什么 MySQL的慢查询日志是MySQL提供的一种日志记录,它用来记录在MySQL中响应时间超过阀值的语句,具体指运行时间超过long_query_time值的SQL,则会被记录到慢查询日志中。 具体指运行时间超过long_query_time值的…...

vue3基础流程
目录 1. 安装和创建项目 2. 项目结构 3. 主要文件解析 3.1 main.js 3.2 App.vue 4. 组件和Props 5. 事件处理 6. 生命周期钩子 7. Vue 3的Composition API 8. 总结和结论 响应式系统: 组件化: 易于学习: 灵活性: 社…...

Vue 数据绑定 和 数据渲染
目录 一、Vue快速入门 1.简介 : 2.MVVM : 3.准备工作 : 二、数据绑定 1.实例 : 2.验证 : 三、数据渲染 1.单向渲染 : 2.双向渲染 : 一、Vue快速入门 1.简介 : (1) Vue[/vju/],是Vue.js的简称,是一个前端框架,常用于构建前端用户…...

【原创】解决Kotlin无法使用@Slf4j注解的问题
前言 主要还是辟谣之前的网上的用法,当然也会给出最终的使用方法。这可是Kotlin,关Slf4j何事!? 辟谣内容:创建注解来解决这个问题 例如: Target(AnnotationTarget.CLASS) Retention(AnnotationRetentio…...
CDN是如何实现全球节点同步的
当谈到内容交付网络(Content Delivery Network,CDN)加速时,我们必须了解CDN是如何实现全球节点同步的。CDN是一种网络架构,通过将内容分发到全球各地的服务器节点,以降低用户访问网站或应用程序时的延迟和提…...
Centos7 Linux系统下生成https的crt和key证书
linux下生成https的crt和key证书 步骤如下: x509证书一般会用到三类文,key,csr,crt Key 是私用密钥openssl格,通常是rsa算法。 Csr 是证书请求文件,用于申请证书。在制作csr文件的时,必须使…...

性能测试工具——Jmeter的安装【超详细】
目录 1、性能测试工具:JMeter和LoadRunner对比 2、为什么学习JMeter? 3、JMeter环境搭建 3.1、安装JDK 3.2、下载安装JMeter 3.3、配置环境变量 2.4、启动验证JMeter是否安装成功 4、认识JMeter的目录结构 1)bin目录:存放…...
系列三十、Spring AOP vs AspectJ AOP
一、关系 (1)当在Spring中要使用Aspect、Before、After等注解时,需要添加AspectJ的相关依赖,如下 <dependency><groupId>cglib</groupId><artifactId>cglib</artifactId><version>3.1</…...
面向对象设计模式——策略模式
策略设计模式(Strategy Pattern)是一种行为型设计模式,它允许在运行时选择算法的行为。该模式定义了一系列算法,将每个算法封装到一个独立的类中,使它们可以相互替换。策略模式使算法独立于客户端而变化,客…...

Kubernetes - Ingress HTTP 负载搭建部署解决方案(新版本v1.21+)
在看这一篇之前,如果不了解 Ingress 在 K8s 当中的职责,建议看之前的一篇针对旧版本 Ingress 的部署搭建,在开头会提到它的一些简介Kubernetes - Ingress HTTP 负载搭建部署解决方案_放羊的牧码的博客-CSDN博客 开始表演 1、kubeasz 一键安装…...

刚刚:腾讯云3年轻量2核2G4M服务器优惠价格366元三年
腾讯云3年轻量2核2G4M服务器,2023双十一优惠价格366元三年,自带4M公网带宽,下载速度可达512KB/秒,300GB月流量,50GB SSD盘系统盘,腾讯云百科txybk.com分享腾讯云轻量2核2G4M服务器性能、优惠活动、购买条件…...
`include指令【FPGA】
案例: 在Verilog中,include指令可以将一个文件的内容插入到当前文件中。 这个指令通常用于将一些常用的代码片段或者模块定义放在单独的文件中, 然后在需要使用的地方通过include指令将其插入到当前文件中。 这样可以提高代码的复用性和可维…...

iphone备份后怎么转到新手机,iphone备份在哪里查看
iphone备份会备份哪些东西?iphone可根据需要备份设备数据、应用数据、苹果系统等。根据不同的备份数据,可备份的数据类型不同,有些工具可整机备份,有些工具可单项数据备份。本文会详细讲解苹果手机备份可以备份哪些东西。 一、ip…...

JAVA毕业设计106—基于Java+Springboot的外卖系统(源码+数据库)
基于JavaSpringboot的外卖系统(源码数据库)106 一、系统介绍 本系统分为用户端和管理端角色 前台用户功能: 登录、菜品浏览,口味选择,加入购物车,地址管理,提交订单。 管理后台: 登录,员工管…...

SpringCore完整学习教程4,入门级别
本章从第4章开始 4. Logging Spring Boot使用Commons Logging进行所有内部日志记录,但保留底层日志实现开放。为Java Util Logging、Log4J2和Logback提供了默认配置。在每种情况下,记录器都预先配置为使用控制台输出和可选的文件输出。 默认情况下&…...

如何能在项目具体编码实现之前能尽可能早的发现问题并解决问题
在项目的具体编码实现之前尽可能早地发现并解决问题,可以大大节省时间和资源,提高项目的成功率。以下是一些策略和方法: 1. 明确需求和预期: 确保所有的项目需求都是清晰和明确的。需求模糊不清是项目失败的常见原因之一。与利益…...
Windows server服务器允许多用户远程的设置
在Windows Server上允许多用户同时进行远程桌面连接,您需要配置远程桌面服务以支持多用户并确保许可证和授权允许多用户连接。以下是在Windows Server上允许多用户远程桌面连接的步骤: 注意:这些步骤适用于 Windows Server 2012、Windows Ser…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析
今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...

docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...

Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...
Java编程之桥接模式
定义 桥接模式(Bridge Pattern)属于结构型设计模式,它的核心意图是将抽象部分与实现部分分离,使它们可以独立地变化。这种模式通过组合关系来替代继承关系,从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...
MinIO Docker 部署:仅开放一个端口
MinIO Docker 部署:仅开放一个端口 在实际的服务器部署中,出于安全和管理的考虑,我们可能只能开放一个端口。MinIO 是一个高性能的对象存储服务,支持 Docker 部署,但默认情况下它需要两个端口:一个是 API 端口(用于存储和访问数据),另一个是控制台端口(用于管理界面…...