RK3568笔记三:基于ResNet18的Cifar-10分类识别训练部署
若该文为原创文章,转载请注明原文出处。
本篇文章参考的是野火-lubancat的rk3568教程,本篇记录了在正点原子的ATK-DLK3568部署。
一、介绍
ResNet18 是一种卷积神经网络,它有 18 层深度,其中包括带有权重的卷积层和全连接层。它是ResNet 系列网络的一个变体,使用了残差连接(residual connection)来解决深度网络的退化问题。
ResNet(Residual Neural Network)由微软研究院的 Kaiming He 等人在 2015 年提出,ResNet 的结 构可以极快的加速神经网络的训练,模型的准确率也有比较大的提升。 ResNet 是一种残差网络,可以把它理解为一个子网络,这个子网络经过堆叠可以构成一个很深的 网络。ResNet 系列有多种变体,如 ResNet18,ResNet34,ResNet50,ResNet101 和 ResNet152,其 网络结构如下:
这里我们主要看下 ResNet18,ResNet18 基本含义是网络的基本架构是 ResNet,网络的深度是 18层,是带有权重的 18 层,不包括 BN 层,池化层。ResNet18 使用的基本残差单元,每个单元由两 个 3x3 卷积层组成,中间有一个 BN 层和一个 ReLU 激活函数。
PyTorch 中的 ResNet18 源码实现:https://github.com/pytorch/vision/blob/main/torchvision/models/ resnet.py
二、环境安装
环境分为两部分:一是训练的环境;二是rknn环境;
rknn环境前面有介绍,自行安装;训练的环境是windows电脑无gpu,使用的是CPU
1、创建虚拟环境
conda create -n ResNet18_env python=3.8 -y
2、激活环境
conda activate ResNet18_env
3、安装环境
pip install torchvision
pip install onnxruntime
三、训练
自定义一个 ResNet18 网络结构,并使用 CIFAR-10 数据集进行简单测试。CIFAR-10 数据集由 10 个类别的 60000 张 32x32 彩色图像组成,每个类别有 6000 张图像,总共分为 50000 张训练图像和 10000 张测试图像。
resnet18.py
import os
import torchvision
import torch
import torch.nn as nn#device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device=torch.device("cpu")# Transform configuration and data augmentation
transform_train=torchvision.transforms.Compose([torchvision.transforms.Pad(4),torchvision.transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转torchvision.transforms.RandomCrop(32), #图像随机裁剪成32*32# torchvision.transforms.RandomVerticalFlip(),# torchvision.transforms.RandomRotation(15),torchvision.transforms.ToTensor(), #转为Tensor ,归一化torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])#torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))])
transform_test=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])#torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
# epoch时才对数据集进行以上数据增强操作num_classes=10
batch_size=128
learning_rate=0.001
num_epoches=100
classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)# Data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# Define 3*3 convolutional neural network
def conv3x3(in_channels, out_channels, stride=1):return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = conv3x3(in_channels, out_channels, stride)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(out_channels, out_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampledef forward(self, x):residual=xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if(self.downsample):residual = self.downsample(x)out += residualout = self.relu(out)return out# 自定义一个神经网络,使用nn.model,,通过__init__初始化每一层神经网络。
# 使用forward连接数据
class ResNet(nn.Module):def __init__(self, block, layers, num_classes):super(ResNet, self).__init__()self.in_channels = 16self.conv = conv3x3(3, 16)self.bn = torch.nn.BatchNorm2d(16)self.relu = torch.nn.ReLU(inplace=True)self.layer1 = self._make_layers(block, 16, layers[0])self.layer2 = self._make_layers(block, 32, layers[1], 2)self.layer3 = self._make_layers(block, 64, layers[2], 2)self.layer4 = self._make_layers(block, 128, layers[3], 2)self.avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(128, num_classes)def _make_layers(self, block, out_channels, blocks, stride=1):downsample = Noneif (stride != 1) or (self.in_channels != out_channels):downsample = torch.nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),torch.nn.BatchNorm2d(out_channels))layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channelsfor i in range(1, blocks):layers.append(block(out_channels, out_channels))return torch.nn.Sequential(*layers)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# Make model,使用cpu
model=ResNet(ResidualBlock, [2,2,2,2], num_classes).to(device=device)# 打印model结构
# print(f"Model structure: {model}\n\n")# 优化器和损失函数
criterion = nn.CrossEntropyLoss() #交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #优化器随机梯度下降if __name__ == "__main__":# Train the modeltotal_step = len(train_loader)for epoch in range(0,num_epoches):for i, (images, labels) in enumerate(train_loader):images = images.to(device=device)labels = labels.to(device=device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()#sum_loss += loss.item()#_, predicted = torch.max(outputs.data, dim=1)#total += labels.size(0)#correct += predicted.eq(labels.data).cpu().sum()if (i+1) % total_step == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epoches, i+1, total_step, loss.item()))print("Finished Tranining")# 保存权重文件#torch.save(model.state_dict(), 'model_weights.pth')#torch.save(model, 'model.pt')print('\nTest the model')model.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device=device)labels = labels.to(device=device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('在10000张测试集图片上的准确率:{:.4f} %'.format(100 * correct / total))# 导出onnx模型x = torch.randn((1, 3, 32, 32))torch.onnx.export(model, x, './resnet18.onnx', opset_version=12, input_names=['input'], output_names=['output'])
这里有个要注意的,数据集已经提前下载好了,所以没有在线下载
数据集下载是通过下面代码,数据集放在data目录下:
# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)
类型分类为10类
classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")
等待大概1小时,训练结束后会在当前目录下生成resnet18.onnx模型
四、测试onnx模型
测试代码如下:
test_resnet18_onnx.py
import os, syssys.path.append(os.getcwd())
import onnxruntime
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Imagedef to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 自定义的数据增强
def get_test_transform(): return transforms.Compose([transforms.Resize([32, 32]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# 推理的图片路径
image = Image.open('./horse5.jpg').convert('RGB')img = get_test_transform()(image)
img = img.unsqueeze_(0) # -> NCHW, 1,3,224,224
# 模型加载
onnx_model_path = "resnet18.onnx"
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
outs = resnet_session.run(None, inputs)[0]print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])
测试预测结果7,对应的是马。
五、RKNN模型转换
打开ATK搭建好的虚拟机,进入环境rknn2_env,RKNN环境要确保安装好。
转换代码在rknn-toolkit2目录下的example的pytorch里也有,参考代码如下:
rknn_transfer.py
import numpy as np
import cv2
from rknn.api import RKNN
import torchvision.models as models
import torch
import osdef softmax(x):return np.exp(x)/sum(np.exp(x))def torch_version():import torchtorch_ver = torch.__version__.split('.')torch_ver[2] = torch_ver[2].split('+')[0]return [int(v) for v in torch_ver]if __name__ == '__main__':if torch_version() < [1, 9, 0]:import torchprint("Your torch version is '{}', in order to better support the Quantization Aware Training (QAT) model,\n""Please update the torch version to '1.9.0' or higher!".format(torch.__version__))exit(0)MODEL = './resnet18.onnx'# Create RKNN objectrknn = RKNN(verbose=True)# Pre-process configprint('--> Config model')rknn.config(mean_values=[127.5, 127.5, 127.5], std_values=[255, 255, 255], target_platform='rk3568')#rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.395, 58.395, 58.395], target_platform='rk3568')#rknn.config(mean_values=[125.307, 122.961, 113.8575], std_values=[51.5865, 50.847, 51.255], target_platform='rk3568')print('done')# Load modelprint('--> Loading model')#ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)ret = rknn.load_onnx(model=MODEL)if ret != 0:print('Load model failed!')exit(ret)print('done')# Build modelprint('--> Building model')ret = rknn.build(do_quantization=False)if ret != 0:print('Build model failed!')exit(ret)print('done')# Export rknn modelprint('--> Export rknn model')ret = rknn.export_rknn('./resnet_18_100.rknn')if ret != 0:print('Export rknn model failed!')exit(ret)print('done')# Set inputsimg = cv2.imread('./0_125.jpg')img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img,(32,32))#img = np.expand_dims(img, 0)# Init runtime environmentprint('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')exit(ret)print('done')# Inferenceprint('--> Running model')outputs = rknn.inference(inputs=[img])np.save('./pytorch_resnet18_qat_0.npy', outputs[0])#show_outputs(softmax(np.array(outputs[0][0])))print(outputs)print('done')rknn.release()
运行python rknn_transfer.py,正常生成rknn文件。
会在当前目录下生成rknn模型
六、部署
导出rknn后,把rknn和测试图片通过adb上传到开发板。
rknnlite_inference0.py
import numpy as np
import cv2
import os
from rknnlite.api import RKNNLiteIMG_PATH = '2_67.jpg'
RKNN_MODEL = './resnet18.rknn'
img_height = 32
img_width = 32
class_names = ["plane","car","bird","cat","deer","dog","frog","horse","ship","truck"]# Create RKNN object
rknn_lite = RKNNLite()# load RKNN model
print('--> Load RKNN model')
ret = rknn_lite.load_rknn(RKNN_MODEL)
if ret != 0:print('Load RKNN model failed')exit(ret)
print('done')# Init runtime environment
print('--> Init runtime environment')
ret = rknn_lite.init_runtime()
if ret != 0:print('Init runtime environment failed!')exit(ret)
print('done')# load image
img = cv2.imread(IMG_PATH)
img = cv2.resize(img,(32,32))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.expand_dims(img, 0)# runing model
print('--> Running model')
outputs = rknn_lite.inference(inputs=[img])
print("result: ", outputs)
print("This image most likely belongs to {}.".format(class_names[np.argmax(outputs)])
)rknn_lite.release()
在开发板上运行结果
识别出来horse,验证正常。
如有侵权,或需要完整代码,请及时联系博主。
相关文章:

RK3568笔记三:基于ResNet18的Cifar-10分类识别训练部署
若该文为原创文章,转载请注明原文出处。 本篇文章参考的是野火-lubancat的rk3568教程,本篇记录了在正点原子的ATK-DLK3568部署。 一、介绍 ResNet18 是一种卷积神经网络,它有 18 层深度,其中包括带有权重的卷积层和全连接层。它…...
块状数据结构学习笔记
分块 分块的思想和珂朵莉树很类似,就是把原序列分成若干个块,对块进行操作的奇妙思想。复杂度通常带根号。分块的块长也有讲究,通常对于大小为 n n n 的数组,取距离 n \sqrt n n 最近的 2 2 2 的幂数或直接取 n \sqrt n n…...

DOM4J解析.XML文件
<?xml version"1.0" encoding"utf-8" ?> <books><book id"SN123123413241"><name>java编程思想</name><author>华仔</author><price>9.9</price></book><book id"SN1234…...

黑豹程序员-架构师学习路线图-百科:MVC的演变终点SpringMVC
MVC发展史 在我们开发小型项目时,我们代码是混杂在一起的,术语称为紧耦合。 如最终写ASP、PHP。里面既包括服务器端代码,数据库操作的代码,又包括前端页面代码、HTML展现的代码、CSS美化的代码、JS交互的代码。可以看到早期编程就…...

二、BurpSuite Intruder暴力破解
一、介绍 解释: Burp Suite Intruder是一款功能强大的网络安全测试工具,它用于执行暴力破解攻击。它是Burp Suite套件的一部分,具有高度可定制的功能,能够自动化和批量化执行各种攻击,如密码破解、参数枚举和身份验证…...

solidworks 2024新功能之-让您的工作更加高效
您可以创建杰出的设计,并将这些杰出的设计将融入产品体验中。为了帮您简化和加快由概念到成品的产品开发流程,SOLIDWORKS 2024 涵盖全新的用户驱动型增强功能,致力于帮您实现更智能、更快速地与您的团队和外部合作伙伴协同工作。 SOLIDWORKS…...

华为eNSP配置专题-VRRP的配置
文章目录 华为eNSP配置专题-VRRP的配置0、参考文档1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接 2.VRRP的配置2.1、PC1的配置2.2、接入交换机acsw的配置2.3、核心交换机coresw1的配置2.4、核心交换机coresw2的配置2.5、配置VRRP2.6、配置出口…...
LuatOS-SOC接口文档(air780E)--lcd - lcd驱动模块
常量 常量 类型 解释 lcd.font_opposansm8 font 8号字体 lcd.font_unifont_t_symbols font 符号字体 lcd.font_open_iconic_weather_6x_t font 天气字体 lcd.font_opposansm10 font 10号字体 lcd.font_opposansm12 font 12号字体 lcd.font_opposansm16 font…...

敏捷是怎么提高工作效率的
敏捷管理是一门极力减少不必要工作量的艺术。 谷歌、亚马逊、苹果、微信、京东等全球 500 强企业都在用的管理方法,适用于各行各业,被盛赞为应获“管理学的诺贝尔奖”。 它专注于让员工不受种种杂事的羁绊,激发个体斗志,释放出巨大…...

【C++】哈希的应用 -- 布隆过滤器
文章目录 一、布隆过滤器提出二、布隆过滤器概念三、布隆过滤器哈希函数个数的选择四、布隆过滤器的实现1.布隆过滤器的插入2.布隆过滤器的查找3.布隆过滤器删除4.完整代码实现 五、布隆过滤器总结1.布隆过滤器优点2.布隆过滤器缺陷3.布隆过滤器的应用4.布隆过滤器相关面试题 一…...
如何在Git中修改远程仓库地址
原文(可不登录复制代码):如何在Git中修改远程仓库地址-北的杂货间 Git是广泛使用的分布式版本控制系统,它允许开发者在本地仓库上工作,并将更改上传到远程仓库。然而,有时候你可能需要修改远程仓库的地址&…...
Go语言的sync.Once()函数
sync.Once 是 Go 语言标准库 sync 包提供的一个类型,它用于确保一个函数只会被执行一次,即使在多个 goroutine 中同时调用。 sync.Once 包含一个 Do 方法,其签名如下: func (o *Once) Do(f func()) Do 方法接受一个函数作为参数…...
修改 Stable Diffusion 使 api 接口增加模型参数
参考:https://zhuanlan.zhihu.com/p/644545784 1、修改 modules/api/models.py 中的 StableDiffusionTxt2ImgProcessingAPI 增加模型名称 StableDiffusionTxt2ImgProcessingAPI PydanticModelGenerator("StableDiffusionProcessingTxt2Img",StableDiff…...

微信小程序自定义组件及会议管理与个人中心界面搭建
一、自定义tabs组件 1.1 创建自定义组件 新建一个components文件夹 --> tabs文件夹 --> tabs文件 创建好之后win7 以上的系统会报个错误:提示代码分析错误,已经被其他模块引用,只需要在 在project.config.json文件里添加两行配置 &…...

UiPath:一家由生成式AI驱动的流程自动化软件公司
来源:猛兽财经 作者:猛兽财经 总结: (1)UiPath(PATH)的股价并没有因为生成式AI的炒作而上涨,但很可能会成为主要受益者。 (2)即使在严峻的宏观环境下,UiPath的收入还在不…...

使用AI编写测试用例——详细教程
随着今年chatGPT的大热,每个行业都试图从这项新技术当中获得一些收益我之前也写过一篇测试领域在AI技术中的探索:软件测试中的AI——运用AI编写测试用例现阶段AI还不能完全替代人工测试用例编写,但是如果把AI当做一个提高效率的工具ÿ…...

又哭又笑,这份面试宝典要是早遇到就好了
01、算法原理 选择排序(Selection sort)是一种简单直观的排序算法。 第一次从待排序的数据元素中选出最小(或最大)的一个元素,存放在序列的起始位置,然后再从剩余的未排序元素中寻找到最小(大)元素&#…...

订单30分钟自动关闭的五种解决方案
1 前言 在开发中,往往会遇到一些关于延时任务的需求。例如 生成订单30分钟未支付,则自动取消生成订单60秒后,给用户发短信 对上述的任务,我们给一个专业的名字来形容,那就是延时任务 。那么这里就会产生一个问题,这…...

【vSphere 8 自签名 VMCA 证书】企业 CA 签名证书替换 vSphere VMCA CA 证书Ⅰ—— 生成 CSR
目录 替换拓扑图证书关系示意图说明 & 关联博文1. 默认证书截图2. 使用 certificate-manager 生成CSR2.1 创建存放CSR的目录2.2 记录PNID和IP2.3 生成CSR2.4 验证CSR 参考资料 替换拓扑图 证书关系示意图 本系列博文要实现的拓扑是 说明 & 关联博文 因为使用企业 …...

【diffusion model】扩散模型入门
写在最前,参加DataWhale 10月组队学习。 参考资料: HuggingFace 开源diffusion-models-class 1.扩散模型介绍 2.调用模型生成一张赛博风格的猫咪图片 2.1 安装依赖包 %pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow9…...

(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...

如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
MinIO Docker 部署:仅开放一个端口
MinIO Docker 部署:仅开放一个端口 在实际的服务器部署中,出于安全和管理的考虑,我们可能只能开放一个端口。MinIO 是一个高性能的对象存储服务,支持 Docker 部署,但默认情况下它需要两个端口:一个是 API 端口(用于存储和访问数据),另一个是控制台端口(用于管理界面…...

《Docker》架构
文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器,docker,镜像,k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...
comfyui 工作流中 图生视频 如何增加视频的长度到5秒
comfyUI 工作流怎么可以生成更长的视频。除了硬件显存要求之外还有别的方法吗? 在ComfyUI中实现图生视频并延长到5秒,需要结合多个扩展和技巧。以下是完整解决方案: 核心工作流配置(24fps下5秒120帧) #mermaid-svg-yP…...

Python训练营-Day26-函数专题1:函数定义与参数
题目1:计算圆的面积 任务: 编写一个名为 calculate_circle_area 的函数,该函数接收圆的半径 radius 作为参数,并返回圆的面积。圆的面积 π * radius (可以使用 math.pi 作为 π 的值)要求:函数接收一个位置参数 radi…...