RK3568笔记三:基于ResNet18的Cifar-10分类识别训练部署
若该文为原创文章,转载请注明原文出处。
本篇文章参考的是野火-lubancat的rk3568教程,本篇记录了在正点原子的ATK-DLK3568部署。
一、介绍
ResNet18 是一种卷积神经网络,它有 18 层深度,其中包括带有权重的卷积层和全连接层。它是ResNet 系列网络的一个变体,使用了残差连接(residual connection)来解决深度网络的退化问题。
ResNet(Residual Neural Network)由微软研究院的 Kaiming He 等人在 2015 年提出,ResNet 的结 构可以极快的加速神经网络的训练,模型的准确率也有比较大的提升。 ResNet 是一种残差网络,可以把它理解为一个子网络,这个子网络经过堆叠可以构成一个很深的 网络。ResNet 系列有多种变体,如 ResNet18,ResNet34,ResNet50,ResNet101 和 ResNet152,其 网络结构如下:

这里我们主要看下 ResNet18,ResNet18 基本含义是网络的基本架构是 ResNet,网络的深度是 18层,是带有权重的 18 层,不包括 BN 层,池化层。ResNet18 使用的基本残差单元,每个单元由两 个 3x3 卷积层组成,中间有一个 BN 层和一个 ReLU 激活函数。
PyTorch 中的 ResNet18 源码实现:https://github.com/pytorch/vision/blob/main/torchvision/models/ resnet.py
二、环境安装
环境分为两部分:一是训练的环境;二是rknn环境;
rknn环境前面有介绍,自行安装;训练的环境是windows电脑无gpu,使用的是CPU
1、创建虚拟环境
conda create -n ResNet18_env python=3.8 -y
2、激活环境
conda activate ResNet18_env
3、安装环境
pip install torchvision
pip install onnxruntime
三、训练
自定义一个 ResNet18 网络结构,并使用 CIFAR-10 数据集进行简单测试。CIFAR-10 数据集由 10 个类别的 60000 张 32x32 彩色图像组成,每个类别有 6000 张图像,总共分为 50000 张训练图像和 10000 张测试图像。
resnet18.py
import os
import torchvision
import torch
import torch.nn as nn#device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device=torch.device("cpu")# Transform configuration and data augmentation
transform_train=torchvision.transforms.Compose([torchvision.transforms.Pad(4),torchvision.transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转torchvision.transforms.RandomCrop(32), #图像随机裁剪成32*32# torchvision.transforms.RandomVerticalFlip(),# torchvision.transforms.RandomRotation(15),torchvision.transforms.ToTensor(), #转为Tensor ,归一化torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])#torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))])
transform_test=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])#torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
# epoch时才对数据集进行以上数据增强操作num_classes=10
batch_size=128
learning_rate=0.001
num_epoches=100
classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)# Data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# Define 3*3 convolutional neural network
def conv3x3(in_channels, out_channels, stride=1):return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(ResidualBlock, self).__init__()self.conv1 = conv3x3(in_channels, out_channels, stride)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(out_channels, out_channels)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampledef forward(self, x):residual=xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if(self.downsample):residual = self.downsample(x)out += residualout = self.relu(out)return out# 自定义一个神经网络,使用nn.model,,通过__init__初始化每一层神经网络。
# 使用forward连接数据
class ResNet(nn.Module):def __init__(self, block, layers, num_classes):super(ResNet, self).__init__()self.in_channels = 16self.conv = conv3x3(3, 16)self.bn = torch.nn.BatchNorm2d(16)self.relu = torch.nn.ReLU(inplace=True)self.layer1 = self._make_layers(block, 16, layers[0])self.layer2 = self._make_layers(block, 32, layers[1], 2)self.layer3 = self._make_layers(block, 64, layers[2], 2)self.layer4 = self._make_layers(block, 128, layers[3], 2)self.avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(128, num_classes)def _make_layers(self, block, out_channels, blocks, stride=1):downsample = Noneif (stride != 1) or (self.in_channels != out_channels):downsample = torch.nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),torch.nn.BatchNorm2d(out_channels))layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channelsfor i in range(1, blocks):layers.append(block(out_channels, out_channels))return torch.nn.Sequential(*layers)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# Make model,使用cpu
model=ResNet(ResidualBlock, [2,2,2,2], num_classes).to(device=device)# 打印model结构
# print(f"Model structure: {model}\n\n")# 优化器和损失函数
criterion = nn.CrossEntropyLoss() #交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #优化器随机梯度下降if __name__ == "__main__":# Train the modeltotal_step = len(train_loader)for epoch in range(0,num_epoches):for i, (images, labels) in enumerate(train_loader):images = images.to(device=device)labels = labels.to(device=device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()#sum_loss += loss.item()#_, predicted = torch.max(outputs.data, dim=1)#total += labels.size(0)#correct += predicted.eq(labels.data).cpu().sum()if (i+1) % total_step == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epoches, i+1, total_step, loss.item()))print("Finished Tranining")# 保存权重文件#torch.save(model.state_dict(), 'model_weights.pth')#torch.save(model, 'model.pt')print('\nTest the model')model.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device=device)labels = labels.to(device=device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('在10000张测试集图片上的准确率:{:.4f} %'.format(100 * correct / total))# 导出onnx模型x = torch.randn((1, 3, 32, 32))torch.onnx.export(model, x, './resnet18.onnx', opset_version=12, input_names=['input'], output_names=['output'])
这里有个要注意的,数据集已经提前下载好了,所以没有在线下载
数据集下载是通过下面代码,数据集放在data目录下:
# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)
类型分类为10类
classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")
等待大概1小时,训练结束后会在当前目录下生成resnet18.onnx模型
四、测试onnx模型
测试代码如下:
test_resnet18_onnx.py
import os, syssys.path.append(os.getcwd())
import onnxruntime
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Imagedef to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 自定义的数据增强
def get_test_transform(): return transforms.Compose([transforms.Resize([32, 32]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# 推理的图片路径
image = Image.open('./horse5.jpg').convert('RGB')img = get_test_transform()(image)
img = img.unsqueeze_(0) # -> NCHW, 1,3,224,224
# 模型加载
onnx_model_path = "resnet18.onnx"
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
outs = resnet_session.run(None, inputs)[0]print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])

测试预测结果7,对应的是马。
五、RKNN模型转换
打开ATK搭建好的虚拟机,进入环境rknn2_env,RKNN环境要确保安装好。
转换代码在rknn-toolkit2目录下的example的pytorch里也有,参考代码如下:
rknn_transfer.py
import numpy as np
import cv2
from rknn.api import RKNN
import torchvision.models as models
import torch
import osdef softmax(x):return np.exp(x)/sum(np.exp(x))def torch_version():import torchtorch_ver = torch.__version__.split('.')torch_ver[2] = torch_ver[2].split('+')[0]return [int(v) for v in torch_ver]if __name__ == '__main__':if torch_version() < [1, 9, 0]:import torchprint("Your torch version is '{}', in order to better support the Quantization Aware Training (QAT) model,\n""Please update the torch version to '1.9.0' or higher!".format(torch.__version__))exit(0)MODEL = './resnet18.onnx'# Create RKNN objectrknn = RKNN(verbose=True)# Pre-process configprint('--> Config model')rknn.config(mean_values=[127.5, 127.5, 127.5], std_values=[255, 255, 255], target_platform='rk3568')#rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.395, 58.395, 58.395], target_platform='rk3568')#rknn.config(mean_values=[125.307, 122.961, 113.8575], std_values=[51.5865, 50.847, 51.255], target_platform='rk3568')print('done')# Load modelprint('--> Loading model')#ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)ret = rknn.load_onnx(model=MODEL)if ret != 0:print('Load model failed!')exit(ret)print('done')# Build modelprint('--> Building model')ret = rknn.build(do_quantization=False)if ret != 0:print('Build model failed!')exit(ret)print('done')# Export rknn modelprint('--> Export rknn model')ret = rknn.export_rknn('./resnet_18_100.rknn')if ret != 0:print('Export rknn model failed!')exit(ret)print('done')# Set inputsimg = cv2.imread('./0_125.jpg')img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img,(32,32))#img = np.expand_dims(img, 0)# Init runtime environmentprint('--> Init runtime environment')ret = rknn.init_runtime()if ret != 0:print('Init runtime environment failed!')exit(ret)print('done')# Inferenceprint('--> Running model')outputs = rknn.inference(inputs=[img])np.save('./pytorch_resnet18_qat_0.npy', outputs[0])#show_outputs(softmax(np.array(outputs[0][0])))print(outputs)print('done')rknn.release()
运行python rknn_transfer.py,正常生成rknn文件。

会在当前目录下生成rknn模型

六、部署
导出rknn后,把rknn和测试图片通过adb上传到开发板。
rknnlite_inference0.py
import numpy as np
import cv2
import os
from rknnlite.api import RKNNLiteIMG_PATH = '2_67.jpg'
RKNN_MODEL = './resnet18.rknn'
img_height = 32
img_width = 32
class_names = ["plane","car","bird","cat","deer","dog","frog","horse","ship","truck"]# Create RKNN object
rknn_lite = RKNNLite()# load RKNN model
print('--> Load RKNN model')
ret = rknn_lite.load_rknn(RKNN_MODEL)
if ret != 0:print('Load RKNN model failed')exit(ret)
print('done')# Init runtime environment
print('--> Init runtime environment')
ret = rknn_lite.init_runtime()
if ret != 0:print('Init runtime environment failed!')exit(ret)
print('done')# load image
img = cv2.imread(IMG_PATH)
img = cv2.resize(img,(32,32))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.expand_dims(img, 0)# runing model
print('--> Running model')
outputs = rknn_lite.inference(inputs=[img])
print("result: ", outputs)
print("This image most likely belongs to {}.".format(class_names[np.argmax(outputs)])
)rknn_lite.release()
在开发板上运行结果

识别出来horse,验证正常。
如有侵权,或需要完整代码,请及时联系博主。
相关文章:
RK3568笔记三:基于ResNet18的Cifar-10分类识别训练部署
若该文为原创文章,转载请注明原文出处。 本篇文章参考的是野火-lubancat的rk3568教程,本篇记录了在正点原子的ATK-DLK3568部署。 一、介绍 ResNet18 是一种卷积神经网络,它有 18 层深度,其中包括带有权重的卷积层和全连接层。它…...
块状数据结构学习笔记
分块 分块的思想和珂朵莉树很类似,就是把原序列分成若干个块,对块进行操作的奇妙思想。复杂度通常带根号。分块的块长也有讲究,通常对于大小为 n n n 的数组,取距离 n \sqrt n n 最近的 2 2 2 的幂数或直接取 n \sqrt n n…...
DOM4J解析.XML文件
<?xml version"1.0" encoding"utf-8" ?> <books><book id"SN123123413241"><name>java编程思想</name><author>华仔</author><price>9.9</price></book><book id"SN1234…...
黑豹程序员-架构师学习路线图-百科:MVC的演变终点SpringMVC
MVC发展史 在我们开发小型项目时,我们代码是混杂在一起的,术语称为紧耦合。 如最终写ASP、PHP。里面既包括服务器端代码,数据库操作的代码,又包括前端页面代码、HTML展现的代码、CSS美化的代码、JS交互的代码。可以看到早期编程就…...
二、BurpSuite Intruder暴力破解
一、介绍 解释: Burp Suite Intruder是一款功能强大的网络安全测试工具,它用于执行暴力破解攻击。它是Burp Suite套件的一部分,具有高度可定制的功能,能够自动化和批量化执行各种攻击,如密码破解、参数枚举和身份验证…...
solidworks 2024新功能之-让您的工作更加高效
您可以创建杰出的设计,并将这些杰出的设计将融入产品体验中。为了帮您简化和加快由概念到成品的产品开发流程,SOLIDWORKS 2024 涵盖全新的用户驱动型增强功能,致力于帮您实现更智能、更快速地与您的团队和外部合作伙伴协同工作。 SOLIDWORKS…...
华为eNSP配置专题-VRRP的配置
文章目录 华为eNSP配置专题-VRRP的配置0、参考文档1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接 2.VRRP的配置2.1、PC1的配置2.2、接入交换机acsw的配置2.3、核心交换机coresw1的配置2.4、核心交换机coresw2的配置2.5、配置VRRP2.6、配置出口…...
LuatOS-SOC接口文档(air780E)--lcd - lcd驱动模块
常量 常量 类型 解释 lcd.font_opposansm8 font 8号字体 lcd.font_unifont_t_symbols font 符号字体 lcd.font_open_iconic_weather_6x_t font 天气字体 lcd.font_opposansm10 font 10号字体 lcd.font_opposansm12 font 12号字体 lcd.font_opposansm16 font…...
敏捷是怎么提高工作效率的
敏捷管理是一门极力减少不必要工作量的艺术。 谷歌、亚马逊、苹果、微信、京东等全球 500 强企业都在用的管理方法,适用于各行各业,被盛赞为应获“管理学的诺贝尔奖”。 它专注于让员工不受种种杂事的羁绊,激发个体斗志,释放出巨大…...
【C++】哈希的应用 -- 布隆过滤器
文章目录 一、布隆过滤器提出二、布隆过滤器概念三、布隆过滤器哈希函数个数的选择四、布隆过滤器的实现1.布隆过滤器的插入2.布隆过滤器的查找3.布隆过滤器删除4.完整代码实现 五、布隆过滤器总结1.布隆过滤器优点2.布隆过滤器缺陷3.布隆过滤器的应用4.布隆过滤器相关面试题 一…...
如何在Git中修改远程仓库地址
原文(可不登录复制代码):如何在Git中修改远程仓库地址-北的杂货间 Git是广泛使用的分布式版本控制系统,它允许开发者在本地仓库上工作,并将更改上传到远程仓库。然而,有时候你可能需要修改远程仓库的地址&…...
Go语言的sync.Once()函数
sync.Once 是 Go 语言标准库 sync 包提供的一个类型,它用于确保一个函数只会被执行一次,即使在多个 goroutine 中同时调用。 sync.Once 包含一个 Do 方法,其签名如下: func (o *Once) Do(f func()) Do 方法接受一个函数作为参数…...
修改 Stable Diffusion 使 api 接口增加模型参数
参考:https://zhuanlan.zhihu.com/p/644545784 1、修改 modules/api/models.py 中的 StableDiffusionTxt2ImgProcessingAPI 增加模型名称 StableDiffusionTxt2ImgProcessingAPI PydanticModelGenerator("StableDiffusionProcessingTxt2Img",StableDiff…...
微信小程序自定义组件及会议管理与个人中心界面搭建
一、自定义tabs组件 1.1 创建自定义组件 新建一个components文件夹 --> tabs文件夹 --> tabs文件 创建好之后win7 以上的系统会报个错误:提示代码分析错误,已经被其他模块引用,只需要在 在project.config.json文件里添加两行配置 &…...
UiPath:一家由生成式AI驱动的流程自动化软件公司
来源:猛兽财经 作者:猛兽财经 总结: (1)UiPath(PATH)的股价并没有因为生成式AI的炒作而上涨,但很可能会成为主要受益者。 (2)即使在严峻的宏观环境下,UiPath的收入还在不…...
使用AI编写测试用例——详细教程
随着今年chatGPT的大热,每个行业都试图从这项新技术当中获得一些收益我之前也写过一篇测试领域在AI技术中的探索:软件测试中的AI——运用AI编写测试用例现阶段AI还不能完全替代人工测试用例编写,但是如果把AI当做一个提高效率的工具ÿ…...
又哭又笑,这份面试宝典要是早遇到就好了
01、算法原理 选择排序(Selection sort)是一种简单直观的排序算法。 第一次从待排序的数据元素中选出最小(或最大)的一个元素,存放在序列的起始位置,然后再从剩余的未排序元素中寻找到最小(大)元素&#…...
订单30分钟自动关闭的五种解决方案
1 前言 在开发中,往往会遇到一些关于延时任务的需求。例如 生成订单30分钟未支付,则自动取消生成订单60秒后,给用户发短信 对上述的任务,我们给一个专业的名字来形容,那就是延时任务 。那么这里就会产生一个问题,这…...
【vSphere 8 自签名 VMCA 证书】企业 CA 签名证书替换 vSphere VMCA CA 证书Ⅰ—— 生成 CSR
目录 替换拓扑图证书关系示意图说明 & 关联博文1. 默认证书截图2. 使用 certificate-manager 生成CSR2.1 创建存放CSR的目录2.2 记录PNID和IP2.3 生成CSR2.4 验证CSR 参考资料 替换拓扑图 证书关系示意图 本系列博文要实现的拓扑是 说明 & 关联博文 因为使用企业 …...
【diffusion model】扩散模型入门
写在最前,参加DataWhale 10月组队学习。 参考资料: HuggingFace 开源diffusion-models-class 1.扩散模型介绍 2.调用模型生成一张赛博风格的猫咪图片 2.1 安装依赖包 %pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow9…...
CloudCompare实战:点云二次曲面拟合精度分析与优化策略
1. 二次曲面拟合基础与CloudCompare实现 点云数据处理中,曲面拟合是个绕不开的话题。我第一次接触CloudCompare的二次曲面拟合功能时,就被它的简洁界面吸引,但实际用起来发现没那么简单。二次曲面拟合的本质,是用数学方程来描述点…...
收藏!SaaS小白必看:AI大模型落地实战路线图,从功能堆砌到价值创造
本文分析了SaaS公司在整合AI大模型时应避免“功能堆砌”陷阱,并介绍了三大AI技术路线:Prompt/RAG/微调的特点及适用场景。文章强调SaaSAI产品的成功关键在于技术路线与客户价值的适配,提出了分阶段组合策略,即初创期以提示词为主&…...
车载T-BOX中MCU与SoC的SPI通信协议设计与实现
1. 车载T-BOX中的MCU与SoC通信需求解析 在车载T-BOX(Telematics BOX)这个黑匣子里,MCU(微控制器单元)和SoC(系统级芯片)就像两个性格迥异但必须密切配合的搭档。MCU通常负责实时性要求高的底层控…...
5分钟上手:用Python工具免费下载B站4K大会员视频终极指南
5分钟上手:用Python工具免费下载B站4K大会员视频终极指南 【免费下载链接】bilibili-downloader B站视频下载,支持下载大会员清晰度4K,持续更新中 项目地址: https://gitcode.com/gh_mirrors/bil/bilibili-downloader 你是否遇到过这样…...
Golang怎么用reflect获取类型名称_Golang如何动态获取变量的类型名称字符串【方法】
应使用 reflect.TypeOf(v).String() 获取稳定类型名,因 .Name() 仅对命名类型有效;需结合 .PkgPath() 和 .Elem() 等方法处理指针、接口、别名等场景。用 reflect.TypeOf 拿到类型,再调 .Name() 不一定行得通直接对变量调 reflect.TypeOf(v).…...
揭秘LLM代码生成落地困局:5类典型业务场景的个性化适配路径(含可复用决策树)
第一章:智能代码生成个性化适配策略 2026奇点智能技术大会(https://ml-summit.org) 智能代码生成已从通用模板输出迈向深度个性化适配阶段。开发者背景、项目约束、团队规范与运行时环境共同构成多维适配边界,单一模型输出无法满足真实工程场景的差异化…...
Matlab小波去噪实战:从wden函数参数优化到实际信号处理
1. 小波去噪与wden函数基础入门 第一次接触小波去噪时,我被它神奇的去噪效果惊艳到了。记得当时处理一组工业传感器数据,传统滤波方法怎么调参数都效果不佳,直到尝试了小波去噪才解决问题。Matlab中的wden函数是小波去噪的核心工具ÿ…...
Node.js服务器架构深度剖析:从事件驱动到多进程负载均衡
Node.js服务器架构深度剖析:从事件驱动到多进程负载均衡 【免费下载链接】understand-nodejs 通过源码分析nodejs原理 项目地址: https://gitcode.com/gh_mirrors/un/understand-nodejs Node.js作为基于事件驱动的单进程单线程应用,通过独特的架构…...
Spring Boot应用远程监控实战:用JConsole连接Docker容器里的JMX端口
Spring Boot应用远程监控实战:用JConsole连接Docker容器里的JMX端口 在云原生时代,Spring Boot应用越来越多地运行在Docker容器中。当我们需要监控这些容器化应用的性能指标、内存使用情况或线程状态时,JMX(Java Management Exte…...
5步掌握Open WebUI:企业级自托管AI平台部署实战指南
5步掌握Open WebUI:企业级自托管AI平台部署实战指南 【免费下载链接】open-webui User-friendly AI Interface (Supports Ollama, OpenAI API, ...) 项目地址: https://gitcode.com/GitHub_Trending/op/open-webui Open WebUI是一个功能丰富、可完全离线运行…...
