基于chatGPT设计卷积神经网络
1. 简介
本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。
提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。
模型 | flops | params | train acc1/5 | test acc1/5 |
---|---|---|---|---|
ResNet18(timm) | 1.82 | 11.18 | ~98%/100% | 63.11%/86.11% |
CNN-GPT | 1.51 | 6.66 | ~98%/100% | 68.44%/89.14% |
此外,GPT生成的模型训练速度更快,ResNet18才训练170个epoch,CNN-GPT已经完整地训练完了,具体的模型结构在文末,欢迎尝试使用~
个人认为比较有意思的一点是,GPT生成的模型并没有在stem block中就进行降采样,而是卷了一会之后才开始,这个与之前在文本任务中的经验结果类似,特征图越大,会避免信息损失。
基于chatGPT生成模型结构的方式或许在某些情况下可以解决我们的问题,但是基础的神经网络知识还是需要的,否则它给出的答案有些小bug,我们也不知道怎么修改。
模型训练结果与超参设置强相关,在这里仅设置默认的训练策略作为对比,根据训练结果看出,两个模型均过拟合严重。而且,不管是ResNet18或者CNN-GPT,均与Cifar100数据集的SOTA精度相差很远,此处需要注意。
2. 具体步骤
Step1 骁龙855芯片上可以正常运行的算子
基于上述问题,chatGPT给出的回答如下所示。
以下是在骁龙855芯片上比较友好的神经网络算子:1. Conv2d
2. Depthwise Conv2d
3. Maxpool2d
4. Avgpool2d
5. BatchNorm2d
6. ReLU
7. Flatten
8. Linear
整体看起来还算比较靠谱。
Step2 构建一个模型大小在10M左右的模型
针对上面的问题,chatGPT首先给出了思路。
如果想设计一个在骁龙855芯片上运行速度很快,模型大小在10M以内的模型,可以考虑使用以下方法来进行优化:1. 使用深度可分离卷积代替标准卷积,以获得更快的速度和更小的模型尺寸。
2. 使用深度方向的滤波器来替换全连接层,以减少参数数量并减少内存使用。
3. 对于每个卷积层,使用Batch Norm层来加速收敛。
4. 使用ReLU代替传统的Sigmoid或Tanh激活函数,因为ReLU的计算效率更高。
然后给出了一个模型示例。
以下是一个可以作为参考的模型:Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()DepthwiseConv2d(32, kernel_size=3, stride=1, padding=1)
BatchNorm2d(32)
ReLU()Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
BatchNorm2d(64)
ReLU()DepthwiseConv2d(64, kernel_size=3, stride=1, padding=1)
BatchNorm2d(64)
ReLU()Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
BatchNorm2d(128)
ReLU()DepthwiseConv2d(128, kernel_size=3, stride=1, padding=1)
BatchNorm2d(128)
ReLU()Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
BatchNorm2d(256)
ReLU()DepthwiseConv2d(256, kernel_size=3, stride=1, padding=1)
BatchNorm2d(256)
ReLU()Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
BatchNorm2d(512)
ReLU()DepthwiseConv2d(512, kernel_size=3, stride=1, padding=1)
BatchNorm2d(512)
ReLU()Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
BatchNorm2d(1024)
ReLU()Avgpool2d(kernel_size=4)
Flatten()
Linear(1024, 10)
上面生成的模型有几个小问题。
- DepthwiseConv2d没有被定义
- DepthwiseConv2d的参数列表中没有out channel的参数,需要手动补充。
Step3 生成一个DepthwiseConv2d结构
chatGPT回答如下
以下是一个非常简单的torch的DepthwiseConv2d实现:
import torch.nn as nnclass DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):super(DepthwiseConv2d, self).__init__()if groups is None:groups = in_channelsself.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return x
加入到模型结构定义文件中,同时修改类别数,就可以进行模型训练了。
设置训练超参数
为保证对比的公平性,下面是训练过程中的一些超参数,
parser.add_argument("--epoch", type=int, default=240)parser.add_argument("--batch_size", type=int, default=64)parser.add_argument("--lr", type=float, default=0.05)parser.add_argument("--lr_decay_rate", type=float, default=0.1)parser.add_argument("--lr_decay_stages", type=str, default="150,180,210")parser.add_argument("--wd", type=float, default=0.0005)
3. 完整的模型结构
完整的模型结构如下,也欢迎大家在自己的数据集上尝试。
import torch
import torch.nn as nnclass DepthwiseConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None):super(DepthwiseConv2d, self).__init__()if groups is None:groups = in_channelsself.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups)self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1)def forward(self, x):x = self.depthwise_conv(x)x = self.pointwise_conv(x)return xclass CNNGPT(nn.Module):def __init__(self) -> None:super().__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(),DepthwiseConv2d(32, 32, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(),DepthwiseConv2d(64, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),DepthwiseConv2d(128, 128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(),DepthwiseConv2d(256, 256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),DepthwiseConv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),nn.BatchNorm2d(1024),nn.AdaptiveAvgPool2d(1),nn.Flatten(),nn.Linear(1024, 100),)def forward(self, x):y = self.model(x)return ydef get_flops_params(model):from thop import profilemodel.eval()flops, params = profile(model,inputs=[torch.randn([1, 3, 224, 224]),],)print(f"flops: {flops/1000**3} G, params: {params/1000**2} M")return flops, paramsif __name__ == "__main__":model = CNNGPT()get_flops_params(model)
相关文章:
基于chatGPT设计卷积神经网络
1. 简介 本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。 提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。 模型flopsparamstrain acc1/5test acc1/5ResNet18(timm)1.8211.18~98…...
java.sql.Date和java.util.Date的区别
参考答案 java.sql.Date 是 java.util.Date 的子类java.util.Date 是 JDK 中的日期类,精确到时、分、秒、毫秒java.sql.Date 与数据库 Date 相对应的一个类型,只有日期部分,时分秒都会设置为 0,如:2019-10-23 00:00:0…...

动态规划---线性dp和区间dp
动态规划(三) 目录动态规划(三)一:线性DP1.数字三角形1.1数字三角形题目1.2代码思路1.3代码实现(正序and倒序)2.最长上升子序列2.1最长上升子序列题目2.2代码思路2.3代码实现3.最长公共子序列3.1最长公共子序列题目3.2代码思路3.3代码实现4.石子合并4.1题目如下4.2代…...
常见的2D与3D碰撞检测算法
分离轴分离轴定理(Separating Axis Theorem)是用于解决2D或3D物体碰撞检测问题的一种方法。其基本思想是,如果两个物体未发生碰撞,那么可以找到一条分离轴(即一条直线或平面),两个物体在该轴上的…...

STM32 10个工程篇:1.IAP远程升级(二)
一直提醒自己要更新CSDN博客,但是确实这段时间到了一个项目的关键节点,杂七杂八的事情突然就一涌而至。STM32、FPGA下位机代码和对应Labview的IAP升级助手、波形设置助手上位机代码笔者已经调试通过,因为不想去水博客、凑数量,复制…...

Unity+ChatGpt的联动 AICommand
果然爱是会消失的,对吗 chatGpt没出现之前起码还看人家的文章,现在都是随便你。 本着师夷长技以制夷的思路,既然打不过,那么我就加入 github地址:https://github.com/keijiro/AICommand 文档用chatGpt翻译如下&#…...

STM-32:按键控制LED灯 程序详解
目录一、基本原理二、接线图三、程序思路3.1库函数3.2程序代码注:一、基本原理 左边是STM322里电路每一个端口均可以配置的电路部分,右边部分是外接设备 电路图。 配置为 上拉输入模式的意思就是,VDD开关闭合,VSS开关断开。 浮空…...

北邮22信通:(8)实验1 题目五:大整数加减法(搬运官方代码)
北邮22信通一枚~ 跟随课程进度每周更新数据结构与算法的代码和文章 持续关注作者 解锁更多邮苑信通专属代码~ 上一篇文章: 北邮22信通:(7)实验1 题目四:一元多项式(节省内存版)_青山如…...

Fiddler抓取https史上最强教程
有任何疑问建议观看下面视频 2023最新Fiddler抓包工具实战,2小时精通十年技术!!!对于想抓取HTTPS的测试初学者来说,常用的工具就是fiddler。 但是初学时,大家对于fiddler如何抓取HTTPS难免走歪路ÿ…...

STM32开发基础知识入门
C语言基础 位操作 对基本类型变量可以在位级别进行操作。 1) 不改变其他位的值的状况下,对某几个位进行设值。 先对需要设置的位用&操作符进行清零操作,然后用|操作符设值。 2) 移位操作提高代码的可读性。 3) ~取反操作使用技巧 可用于对某…...

学习操作系统的必备教科书《操作系统:原理与实现》| 文末赠书4本
使用了6年的实时操作系统,是时候梳理一下它的知识点了 摘要: 本文简单介绍了博主学习操作系统的心路历程,同时还给大家总结了一下当下流行的几种实时操作系统,以及在工程中OSAL应该如何设计。希望对大家有所启发和帮助。 文章目录…...
大数据的常用算法(分类、回归分析、聚类、关联规则、神经网络方法、web数据挖掘)
在大数据时代,数据挖掘是最关键的工作。大数据的挖掘是从海量、不完全的、有噪声的、模糊的、随机的大型数据库中发现隐含在其中有价值的、潜在有用的信息和知识的过程,也是一种决策支持过程。其主要基于人工智能,机器学习,模式学…...

【数据结构】详解二叉树与堆与堆排序的关系
🌇个人主页:平凡的小苏 📚学习格言:别人可以拷贝我的模式,但不能拷贝我不断往前的激情 🛸C语言专栏:https://blog.csdn.net/vhhhbb/category_12174730.html 🚀数据结构专栏ÿ…...

【Pandas】数据分析入门
文章目录前言一、Pandas简介1.1 什么是Pandas1.2 Pandas应用二、Series结构2.1 Series简介2.2 基本使用三、DataFrame结构3.1 DataFrame简介3.2 基本使用四、Pandas-CSV4.1 CSV简介4.2 读取CSV文件4.3 数据处理五、数据清洗5.1 数据清洗的方法5.2 清洗案例总结前言 大家好&…...

【c++】:list模拟实现“任意位置插入删除我最强ƪ(˘⌣˘)ʃ“
文章目录 前言一.list的基本功能的使用二.list的模拟实现总结前言 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。2. list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中࿰…...

QT表格控件实例(Table Widget 、Table View)
欢迎小伙伴的点评✨✨,相互学习🚀🚀🚀 博主🧑🧑 本着开源的精神交流Qt开发的经验、将持续更新续章,为社区贡献博主自身的开源精神👩🚀 文章目录前言一、图示实例二、列…...

第二章Vue组件化编程
文章目录模块与组件、模块化与组件化模块组件模块化组件化Vue中的组件含义非单文件组件基本使用组件注意事项使用 kebab-case使用 PascalCase组件的嵌套模板templateVueComponent一个重要的内置功能单文件组件Vue脚手架使用Vue CLI脚手架先配置环境初始化脚手架分析脚手架结构实…...

面试官:vue2和vue3的区别有哪些
目录 多根节点,fragment(碎片) Composition API reactive 函数是用来创建响应式对象 Ref toRef toRefs 去除了管道 v-model的prop 和 event 默认名称会更改 vue2写法 Vue 3写法 vue3组件需要使用v-model时的写法 其他语法 1. 创…...
【TopK问题】——用堆实现
文章目录一、TopK问题是什么二、解决方法三、时间复杂度一、TopK问题是什么 TopK问题就是从1000个数中找出前K个最大的数或者最小的数这样的类似问题。 不过并不要求这k个数字必须是有序的,如果题目有要求,则进行堆排序即可。 还有比如求出全国玩韩信…...

【Spring从成神到升仙系列 四】从源码分析 Spring 事务的来龙去脉
👏作者简介:大家好,我是爱敲代码的小黄,独角兽企业的Java开发工程师,CSDN博客专家,阿里云专家博主📕系列专栏:Java设计模式、数据结构和算法、Kafka从入门到成神、Kafka从成神到升仙…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
HTML前端开发:JavaScript 常用事件详解
作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

解读《网络安全法》最新修订,把握网络安全新趋势
《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...

群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...
PostgreSQL——环境搭建
一、Linux # 安装 PostgreSQL 15 仓库 sudo dnf install -y https://download.postgresql.org/pub/repos/yum/reporpms/EL-$(rpm -E %{rhel})-x86_64/pgdg-redhat-repo-latest.noarch.rpm# 安装之前先确认是否已经存在PostgreSQL rpm -qa | grep postgres# 如果存在࿰…...

c++第七天 继承与派生2
这一篇文章主要内容是 派生类构造函数与析构函数 在派生类中重写基类成员 以及多继承 第一部分:派生类构造函数与析构函数 当创建一个派生类对象时,基类成员是如何初始化的? 1.当派生类对象创建的时候,基类成员的初始化顺序 …...
用鸿蒙HarmonyOS5实现中国象棋小游戏的过程
下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...

MySQL的pymysql操作
本章是MySQL的最后一章,MySQL到此完结,下一站Hadoop!!! 这章很简单,完整代码在最后,详细讲解之前python课程里面也有,感兴趣的可以往前找一下 一、查询操作 我们需要打开pycharm …...