【三维几何学习】从零开始网格上的深度学习-3:Transformer篇(Pytorch)
本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052
从零开始网格上的深度学习-3:Transformer篇
- 引言
- 一、概述
- 二、核心代码
- 2.1 位置编码
- 2.2 网络框架
- 三、基于Transformer的网格分类
- 3.1 分类结果
- 3.2 全部代码
引言
本文主要内容如下:
- 简述网格上的位置编码
- 参考点云上的Transformer-1:PCT:Point cloud transformer,构造网格分类网络
一、概述
个人认为
对于三角形网格来说,想要将Transformer应用到其上较为重要的一步是位置编码
。三角网格在3D空间中如何编码每一个元素的位置,能尽可能保证的泛化性能? 以xyz坐标为例,最好是模型经过对齐的预处理,使朝向一致。或者保证网格水密的情况下使用谱域特征,如热核特征。或者探索其他位置编码等等… 上图为一个外星人x坐标的位置编码可视化
- 使用简化网格每一个面直接作为一个Token即可,高分辨率的网格(考虑输入特征计算、训练数据对齐等)并不适合深度学习(
个人认为
) - 直接应用现有的Tranformer网络框架、自注意力模块等,
细节或参数需要微调
二、核心代码
2.1 位置编码
使用每一个网格面的中心坐标作为位置编码,计算代码在DataLoader中
:
- 需要平移到坐标轴原点,并进行尺度归一化
# xyz
xyz_min = np.min(vs[:, 0:3], axis=0)
xyz_max = np.max(vs[:, 0:3], axis=0)
xyz_move = xyz_min + (xyz_max - xyz_min) / 2
vs[:, 0:3] = vs[:, 0:3] - xyz_move
# scale
scale = np.max(vs[:, 0:3])
vs[:, 0:3] = vs[:, 0:3] / scale
# 面中心坐标
xyz = []
for i in range(3):xyz.append(vs[faces[:, i]])
xyz = np.array(xyz) # 转为np
mean_xyz = xyz.sum(axis=0) / 3
2.2 网络框架
- 参考上图PCT框架,修改了部分细节,如减少了Attention模块数量等
- 参考上图自注意力模块,
个人感觉
图中应该有误. 从一个共享权重的Linear里出来了Q、K、VQ、K、VQ、K、V三个矩阵,但VVV的维度和Q、KQ、KQ、K不一致,少画了一个Linear?
class SA(nn.Module):def __init__(self, channels):super().__init__()self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.q_conv.weight = self.k_conv.weightself.v_conv = nn.Conv1d(channels, channels, 1, bias=False)self.trans_conv = nn.Conv1d(channels, channels, 1)self.after_norm = nn.BatchNorm1d(channels)self.act = nn.GELU()self.softmax = nn.Softmax(dim=-1)def forward(self, x):x_q = self.q_conv(x).permute(0, 2, 1)x_k = self.k_conv(x)x_v = self.v_conv(x)energy = x_q @ x_kattention = self.softmax(energy)attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))x_r = x_v @ attentionx_r = self.act(self.after_norm(self.trans_conv(x - x_r)))x = x + x_rreturn xclass TriTransNet(nn.Module):def __init__(self, dim_in, classes_n=30):super().__init__()self.conv_fea = FaceConv(6, 128, 4)self.conv_pos = FaceConv(3, 128, 4)self.bn_fea = nn.BatchNorm1d(128)self.bn_pos = nn.BatchNorm1d(128)self.sa1 = SA(128)self.sa2 = SA(128)self.gp = nn.AdaptiveAvgPool1d(1)self.linear1 = nn.Linear(256, 128, bias=False)self.bn1 = nn.BatchNorm1d(128)self.linear2 = nn.Linear(128, classes_n)self.act = nn.GELU()def forward(self, x, mesh):x = x.permute(0, 2, 1).contiguous()# 位置编码 放到DataLoader中比较好pos = [m.xyz for m in mesh]pos = np.array(pos)pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)batch_size, _, N = x.size()x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))x1 = self.sa1(x + pos)x2 = self.sa2(x1 + pos)x = torch.cat((x1, x2), dim=1)x = self.gp(x)x = x.view(batch_size, -1)x = self.act(self.bn1(self.linear1(x)))x = self.linear2(x)return x
三、基于Transformer的网格分类
数据集是SHREC’11
可参考三角网格(Triangular Mesh)分类数据集 或 MeshCNN
3.1 分类结果
准确率太低… 可以尝试改进的点:
- 尝试不同的位置编码(
谱域特征
),不同的位置嵌入方式 (sum可改为concat
) 数据集较小
的情况下Transformer略难收敛,加入更多CNN可加速且提升明显 (或者加入降采样
)- 打印loss进行分析,是否
欠拟合
,尝试增加网络参数?
基于Transformer的网络在网格分割上的表现会很好,仅用少量参数即可媲美甚至超过基于面卷积的分割结果,个人感觉得益于其近乎全局的感受野…
3.2 全部代码
DataLoader代码请参考2:从零开始网格上的深度学习-1:输入篇(Pytorch)
FaceConv代码请参考3:从零开始网格上的深度学习-2:卷积网络CNN篇
import torch
import torch.nn as nn
import numpy as np
from CNN import FaceConv
from DataLoader_shrec11 import DataLoader
from DataLoader_shrec11 import Meshclass SA(nn.Module):def __init__(self, channels):super().__init__()self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)self.q_conv.weight = self.k_conv.weightself.v_conv = nn.Conv1d(channels, channels, 1, bias=False)self.trans_conv = nn.Conv1d(channels, channels, 1)self.after_norm = nn.BatchNorm1d(channels)self.act = nn.GELU()self.softmax = nn.Softmax(dim=-1)def forward(self, x):x_q = self.q_conv(x).permute(0, 2, 1)x_k = self.k_conv(x)x_v = self.v_conv(x)energy = x_q @ x_kattention = self.softmax(energy)attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))x_r = x_v @ attentionx_r = self.act(self.after_norm(self.trans_conv(x - x_r)))x = x + x_rreturn xclass TriTransNet(nn.Module):def __init__(self, dim_in, classes_n=30):super().__init__()self.conv_fea = FaceConv(6, 128, 4)self.conv_pos = FaceConv(3, 128, 4)self.bn_fea = nn.BatchNorm1d(128)self.bn_pos = nn.BatchNorm1d(128)self.sa1 = SA(128)self.sa2 = SA(128)self.gp = nn.AdaptiveAvgPool1d(1)self.linear1 = nn.Linear(256, 128, bias=False)self.bn1 = nn.BatchNorm1d(128)self.linear2 = nn.Linear(128, classes_n)self.act = nn.GELU()def forward(self, x, mesh):x = x.permute(0, 2, 1).contiguous()# 位置编码 放到DataLoader中比较好pos = [m.xyz for m in mesh]pos = np.array(pos)pos = torch.from_numpy(pos).float().to(x.device).requires_grad_(True)batch_size, _, N = x.size()x = self.act(self.bn_fea(self.conv_fea(x, mesh).squeeze(-1)))pos = self.act(self.bn_pos(self.conv_pos(pos, mesh).squeeze(-1)))x1 = self.sa1(x + pos)x2 = self.sa2(x1 + pos)x = torch.cat((x1, x2), dim=1)x = self.gp(x)x = x.view(batch_size, -1)x = self.act(self.bn1(self.linear1(x)))x = self.linear2(x)return xif __name__ == '__main__':# 输入data_train = DataLoader(phase='train') # 训练集data_test = DataLoader(phase='test') # 测试集print('#train meshes = %d' % len(data_train)) # 输出训练模型个数print('#test meshes = %d' % len(data_test)) # 输出测试模型个数# 网络net = TriTransNet(data_train.input_n, data_train.class_n) # 创建网络 以及 优化器optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))net = net.cuda(0)loss_fun = torch.nn.CrossEntropyLoss(ignore_index=-1)num_params = 0for param in net.parameters():num_params += param.numel()print('[Net] Total number of parameters : %.3f M' % (num_params / 1e6))print('-----------------------------------------------')# 迭代训练for epoch in range(1, 201):print('---------------- Epoch: %d -------------' % epoch)for i, data in enumerate(data_train):# 前向传播net.train(True) # 训练模式optimizer.zero_grad() # 梯度清零face_features = torch.from_numpy(data['face_features']).float()face_features = face_features.to(data_train.device).requires_grad_(True)labels = torch.from_numpy(data['label']).long().to(data_train.device)out = net(face_features, data['mesh']) # 输入到网络# 反向传播loss = loss_fun(out, labels)loss.backward()optimizer.step() # 参数更新# 测试net.eval()acc = 0for i, data in enumerate(data_test):with torch.no_grad():# 前向传播face_features = torch.from_numpy(data['face_features']).float()face_features = face_features.to(data_test.device).requires_grad_(False)labels = torch.from_numpy(data['label']).long().to(data_test.device)out = net(face_features, data['mesh'])# 计算准确率pred_class = out.data.max(1)[1]correct = pred_class.eq(labels).sum().float()acc += correctacc = acc / len(data_test)print('epoch: %d, TEST ACC: %0.2f' % (epoch, acc * 100))
PCT:Point cloud transformer ↩︎
从零开始网格上的深度学习-1:输入篇(Pytorch) ↩︎
从零开始网格上的深度学习-2:卷积网络CNN篇 ↩︎
相关文章:

【三维几何学习】从零开始网格上的深度学习-3:Transformer篇(Pytorch)
本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052 从零开始网格上的深度学习-3:Transformer篇引言一、概述二、核心代码2.1 位置编码2.2 网络框架三、基于Transformer的网格分类3.1 分类结果3.2 全部代码引言 本文主要内容如下&#…...

一、基础算法3:二分 模板题+算法模板(数的范围,数的三次方根)
文章目录算法模板整数二分算法模板浮点数二分算法模板模板题数的范围原题链接题目题解数的三次方根原题链接题目题解算法模板 整数二分算法模板 bool check(int x) {/* ... */} // 检查x是否满足某种性质// 区间[l, r]被划分成[l, mid]和[mid 1, r]时使用: int b…...

Spring 源码解析 - Bean创建过程 以及 解决循环依赖
一、Spring Bean创建过程以及循环依赖 上篇文章对 Spring Bean资源的加载注册过程进行了源码梳理和解析,我们可以得到结论,资源文件中的 bean 定义信息,被组装成了 BeanDefinition 存放进了 beanDefinitionMap 容器中,那 Bean 是…...

移除元素(双指针)
给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并原地修改输入数组。 元素的顺序可以改变。你不需要考虑数组中超出新长度后面的…...

76.qt qml-QianWindow开源炫酷界面框架(支持白色暗黑渐变自定义控件均以适配)
界面介绍界面支持: 透明 白色 黑色 渐变 单色 静态图 动态图侧边栏支持:抽屉、带折叠、多模式场景控件已集成: 暗黑风格 高亮风格、并附带个人自定义控件及开源demo白色场景如下所示:单色暗黑风格如下所示:用户自定义皮肤如下所示:皮肤预览如下所示:b站入口:https://www.bilibi…...

Python生日蛋糕
目录 前言 底盘 蛋糕 蜡烛 祝福 前言 Hello,小伙伴们晚上好吖!前两天博主满20岁啦(要开始奔三辽呜呜呜),这几天收到了不少小伙伴们的祝福,浪漫的小博主想送给大家一份不一样的生日蛋糕,…...

QT 如何提高 Qt Creator 的编译速度
如何提高编译速度,貌似是一个老生常谈的话题。对于Qter而言,如何提高QT Creator 的编辑速度是一直都是大家所期盼的。本文也是查阅了各路大神的方法后整理出来的,希望对各位有所帮助。 1、在*.pro文件添加预编译机制 QT官方给出的示例&…...

STM32之震动传感器、继电器介绍及实战
目录 一、震动传感器介绍及实战 二、编程代码实现 1、gpio.c---------初始化GPIO口引脚函数 2、调用中断服务函数 3、中断服务函数 4、中断服务回调函数 5、把上述的中断服务回调函数,放入main主函数里 6、结果演示 三、继电器介绍及实战 一、震动传感器介…...

RK3588平台开发系列讲解(显示篇)RK3588 平台 的DP介绍
平台内核版本安卓版本RK3588Linux 5.10Android 12文章目录 一、功能特性二、 DP 输⼊三、DP 输出四、 代码路径沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将介绍 RK3588 平台 DP 的使⽤与调试⽅法。 一、功能特性 RK3588 的 DP ⽀持 1.4a 版本的 DP 协议,最…...

【Java】i++和++i的实现原理
文章目录 i++案例反编译分析扩展 x = x++我们接下来从字节码层面分析: 不了解字节码的可以参考这篇:【精通JVM】字节码指令全解 i++案例 package org.example;public class Main {public static void main...

第十四届蓝桥杯三月真题刷题训练——第 18 天
目录 第 1 题:排列字母 问题描述 运行限制 代码: 第 2 题:GCD_数论 问题描述 输入格式 输出格式 样例输入 样例输出 评测用例规模与约定 运行限制 第 3 题:选数异或 第 4 题:背包与魔法 第 1 题&#x…...

软件测试拿了几个20K offer,分享一波面经
1、你的测试职业发展是什么? 测试经验越多,测试能力越高。所以我的职业发展是需要时间积累的,一步步向着高级测试工程师奔去。而且我也有初步的职业规划,前3年积累测试经验,按如何做好测试工程师的要点去要求自己,不断…...

spring2
1.Spring配置数据源1.1 数据源(连接池)的作用 数据源(连接池)是提高程序性能如出现的事先实例化数据源,初始化部分连接资源使用连接资源时从数据源中获取使用完毕后将连接资源归还给数据源常见的数据源(连接池):DBCP、C3P0、BoneC…...

【Linux】网络编程套接字(中)
🎇Linux: 博客主页:一起去看日落吗分享博主的在Linux中学习到的知识和遇到的问题博主的能力有限,出现错误希望大家不吝赐教分享给大家一句我很喜欢的话: 看似不起波澜的日复一日,一定会在某一天让你看见坚持…...

手撕数据结构—队列
队列队列的话只允许在一端插入,在另外一端删除。插入数据的那一段叫做队尾,出数据的那一段叫做队头(从尾巴插入)。因此的话队列是先进先出的。入的顺序与出的顺序的话是一样的。这个与栈是不一样的,因为栈的话就是说如…...

gdb调试工具和makemakefile工具
gdb调试工具和make/makefile工具 文章目录gdb调试工具和make/makefile工具一、gdb调试工具1.debug/release2.使用二、make/makefile1.什么是make/makefile2.编写一、gdb调试工具 1.debug/release 程序有两种默认的发布方式debug和release。release是无法进行调试的。Linux中g…...

【进阶数据结构】平衡搜索二叉树 —— AVL树
🌈感谢阅读East-sunrise学习分享——[进阶数据结构]AVL树 博主水平有限,如有差错,欢迎斧正🙏感谢有你 码字不易,若有收获,期待你的点赞关注💙我们一起进步🚀 🌈我们上一篇…...

ROS使用(5)action学习
action消息的构建 首先进行功能包的创建 mkdir -p ros2_ws/src cd ros2_ws/src ros2 pkg create action_tutorials_interfaces action消息的类型 # Request --- # Result --- # Feedback 动作定义由三个消息定义组成,以---分隔。 从动作客户机向动作服务器发送…...

2023前端面试题集(含答案)之HTML+CSS篇(一)
在又到了金三银四的招聘季,不管你是刚入行的小白,亦或是混迹职场的老鸟,还在为面试前端工程师时不知道面试官要问什么怎么回答而苦恼吗?为了帮助你获得面试官的青睐,顺利通过面试,跳槽进入大厂,…...

设计模式2 - 观察者模式
定义: 观察者模式又叫发布订阅模式,它定义了对象之间的一对多依赖,这样一来,当一个对象改变状态时,它的所有依赖者都会收到通知并自动更新。 组成: Subject(通知者/被观察者)&#…...

ini配置文件
ini配置文件 ini文件是initialization file的缩写,即初始化文件,是widows系统配置文件所采用的存储格式。 文件扩展名: .ini ini配置文件的后缀名也不一定必须是.ini, 也可以是.cfg, .conf或者是.txt ini文件格式 ini配置文件由参数, 节, 注解组成 参…...

蓝桥杯备赛经验 pythonA组(非科班选手)
个人2022 CA组江苏省一等奖,决赛成绩不理想,没有拿到一二等奖,但是因为自己是非科班的学生,所以能拿到这样的成绩自己其实也应该知足了 题外话: 很多ACMer嘲笑蓝桥杯非常水,但是据我观察CA组决赛一等奖获奖…...

C++实现通讯录管理系统
通讯录是一个可以记录亲人、好友信息的工具,本博客借助黑马程序员的项目进行修改,利用C实现一个通讯录管理系统,旨在复习C的语法。 一、系统需求 系统需要实现的功能如下: 添加联系人∶向通讯录中添加新人,信息包括…...

开关电源Y电容放置的位置
Y电容,是我们工程师做开关电源设计时都要接触到的一个非常关键的元器件,它对EMI的贡献是相当的大的,但是它是一个较难把控的元器件,原理上并没有那么直观易懂,在EMI传播路径中需要联系到很多的寄生参数才能够去分析。 …...

二叉树的最小深度——递归法、迭代法
1题目给定一个二叉树,找出其最小深度。最小深度是从根节点到最近叶子节点的最短路径上的节点数量。说明:叶子节点是指没有子节点的节点。示例 1:输入:root [3,9,20,null,null,15,7]输出:2示例 2:输入&…...

Vue中常使用的三种刷新页面的方式
一、通过js原始方法刷新 缺点: 出现闪白 目录 一、通过js原始方法刷新 二、通过Vue自带的路由进行跳转 三、通过在APP页面进行demo进行刷新(推荐) 1.vue2写法 2. vue3.2写法 <template><div><div class"header"><button clic…...

【Shell】脚本
Shell脚本脚本格式第一个Shell脚本:hello.sh脚本常用执行方式1. bash或sh脚本的相对路径或绝对路径2. 输入脚本的绝对路径或相对路径3. 在脚本的路径前加上.或者source脚本格式 脚本以#!/bin/bash开头(指定解析器) #! 是一个约定的标记&…...

Mybatis的多表操作
1.Mybatis多表查询 1.1一对一查询 1.一对一查询的模型 用户表和订单表的关系为,一个用户有多个订单,一个订单只从属于一个用户 一对一查询的需求:查询一个订单,与此同时查询出该订单所属的用户2.创建Order和User实体public class…...

【JVM】字节码指令全解
文章目录 入门案例原始 java 代码编译后的字节码文件常量池载入运行时常量池方法字节码载入方法区main 线程开始运行,分配栈帧内存执行引擎开始执行字节码bipush 10istore_1ldc #3istore_2iload_1iload_2iaddistore_3getstatic #4iload_3invokevirtual #5return条件判断指令循…...

【精品】华为认证数通HCIA+HCIP题库分享(含答案解析)
嗨~大家好久不见,我是薄荷学姐,随着华为业务也全球领域的迅猛发展,越来越多人开始重视华为认证的重要性。今天给大家分享一下去年8月份的题库,基本都是一样,希望可以帮助到大家哈想要通过华为认证,除了进行…...