Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能
前言
近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中加载和推理训练好的深度学习模型,尤其适用于需要人工智能(AI)或机器学习(ML)推理的游戏或应用。
YOLO(You Only Look Once)是一种用于目标检测的深度学习模型,它是由Joseph Redmon等人在2015年提出的。YOLO的核心思想是将目标检测问题转化为一个回归问题,在单一的神经网络中同时预测图像中的多个目标位置和类别标签。它通过将目标检测转化为回归问题,极大地提高了检测速度,并且在精度上也能达到非常好的水平。随着版本的更新和技术的不断进步,YOLO逐渐成为了计算机视觉领域中最重要和最广泛应用的模型之一,特别适用于实时处理、嵌入式设备和大规模部署。
本文依托上述两个技术,在Unity3d中实现YOLO的目标检测功能,基于Barracuda(2.0.0)的跨平台性,将实现包含移动端(目前测试了安卓)的目标检测功能,能检测出日常物体桌、椅、人、狗、羊、马等对象。
理论上本工程可以在Windows/Mac/iPhone/Android/Magic Leap/Switch/PS4/Xbox等系统和平台正常工作,目前仅测试了Windows和Android平台,相比Windows平台的流畅,Android手机上运行有明显的掉帧和卡顿,具体可以对比效果图。
官方给出支持的平台说明:
CPU 推理:支持所有 Unity 平台。
GPU 推理:支持所有 Unity 平台,但以下平台:
OpenGL ESon :使用 Vulkan/Metal。Android/iOS
OpenGL Core上:使用 Metal。Mac
WebGL:使用 CPU 推理。
关注并私信 U3D目标检测免费获取应用包(底部公众号)。
效果
手机端效果:
PC端效果:
实现
Barracuda 是一个简单、对开发人员友好的API,只需编写少量代码即可开始使用Barracuda:
var model = ModelLoader.Load(filename);
var engine = WorkerFactory.CreateWorker(model, WorkerFactory.Device.GPU);
var input = new Tensor(1, 1, 1, 10);
var output = engine.Execute(input).PeekOutput();
Barracuda 神经网络导入管道基于ONNX(Open Neural Network Exchange)格式的模型,允许您从各种外部框架(包括Pytorch、TensorFlow和Keras)引入神经网络模型。
关于模型
Barracuda目前仅支持推理,所以模型靠TensorFlow/Pytorch/Keras训练、导入,而且必须先将其转换为 ONNX,然后将其加载到 Unity中。ONNX(Open Neural Network Exchange)是一种用于ML 模型的开放格式。它允许您在各种ML框架和工具之间轻松交换模型。
Pytorch将模型导出到ONNX很容易
# network
net = ...# Input to the model
x = torch.randn(1, 3, 256, 256)# Export the model
torch.onnx.export(net, # model being runx, # model input (or a tuple for multiple inputs)"example.onnx", # where to save the model (can be a file or file-like object)export_params=True, # store the trained parameter weights inside the model fileopset_version=9, # the ONNX version to export the model todo_constant_folding=True, # whether to execute constant folding for optimizationinput_names = ['X'], # the model's input namesoutput_names = ['Y'] # the model's output names)
我这里准备的是很简单的模型,如下图:
确保ONNX模型的输入尺寸、通道顺序(NCHW)与Barracuda兼容。
因为要兼顾移动端效果,所以模型检测识别对象较少,以防止在移动设备上的推理慢。
UI搭建
运行时的UI相对简单,两个button用于打开摄像头和打开视频功能,一个Slider用于控制标记框的显示阈值,就是检测的可信度从0-1(0%-100%)的范围;一个rawImage组件用于显示检测的画面:
其次是标记框的UI,由一个图片和Text构成:
编码
加载模型
var model = ModelLoader.Load(resources.model);
其中模型类型是NNModel。
创建推理引擎 (Worker)并执行模型:
_worker = model.CreateWorker();
using (var t = new Tensor(_config.InputShape, _buffers.preprocess))_worker.Execute(t);
提取神经网络输出:
_worker.CopyOutput("Identity", _buffers.feature1);
_worker.CopyOutput("Identity_1", _buffers.feature2);
将网络的两个输出复制到缓冲区。
第一阶段后处理,检测数据:
var post1 = _resources.postprocess1;
post1.SetInt("ClassCount", _config.ClassCount);
post1.SetFloat("Threshold", threshold);
post1.SetBuffer(0, "Output", _buffers.post1);
post1.SetBuffer(0, "OutputCount", _buffers.counter);var width1 = _config.FeatureMap1Width;
post1.SetTexture(0, "Input", _buffers.feature1);
post1.SetInt("InputSize", width1);
post1.SetFloats("Anchors", _config.AnchorArray1);
post1.DispatchThreads(0, width1, width1, 1);var width2 = _config.FeatureMap2Width;
post1.SetTexture(0, "Input", _buffers.feature2);
post1.SetInt("InputSize", width2);
post1.SetFloats("Anchors", _config.AnchorArray2);
post1.DispatchThreads(0, width2, width2, 1);
聚合检测结果,使用两个特征图进行目标检测,执行目标定位(Bounding Box)预测。
第二阶段后处理,重叠移除:
var post2 = _resources.postprocess2;
post2.SetFloat("Threshold", 0.5f);
post2.SetBuffer(0, "Input", _buffers.post1);
post2.SetBuffer(0, "InputCount", _buffers.counter);
post2.SetBuffer(0, "Output", _buffers.post2);
post2.Dispatch(0, 1, 1, 1);
移除重叠的边界框。
上面的复杂处理是借Compute Shader的Preprocess、Postprocess1和postprocess2来实现的,Compute Shader 是一种图形编程中的着色器类型,专门用于执行计算任务,而不直接参与渲染。详细内容如下。
Common.hlsl:
// Compile-time constants
#define MAX_DETECTION 512
#define ANCHOR_COUNT 3// Detection data structure - The layout of this structure must be matched
// with the one defined in Detection.cs.
struct Detection
{float x, y, w, h;uint classIndex;float score;
};// Misc math functionsfloat CalculateIOU(in Detection d1, in Detection d2)
{float x0 = max(d1.x - d1.w / 2, d2.x - d2.w / 2);float x1 = min(d1.x + d1.w / 2, d2.x + d2.w / 2);float y0 = max(d1.y - d1.h / 2, d2.y - d2.h / 2);float y1 = min(d1.y + d1.h / 2, d2.y + d2.h / 2);float area0 = d1.w * d1.h;float area1 = d2.w * d2.h;float areaInner = max(0, x1 - x0) * max(0, y1 - y0);return areaInner / (area0 + area1 - areaInner);
}float Sigmoid(float x)
{return 1 / (1 + exp(-x));
}#endif
Postprocess1.compute:
#pragma kernel Postprocess1#include "Common.hlsl"// Input
Texture2D<float> Input;
uint InputSize;
float2 Anchors[ANCHOR_COUNT];
uint ClassCount;
float Threshold;// Output buffer
RWStructuredBuffer<Detection> Output;
RWStructuredBuffer<uint> OutputCount; // Only used as a counter[numthreads(8, 8, 1)]
void Postprocess1(uint2 id : SV_DispatchThreadID)
{if (!all(id < InputSize)) return;// Input reference point:// We have to read the input tensor in reversed order.uint ref_y = (InputSize - 1 - id.y) * InputSize + (InputSize - 1 - id.x);for (uint aidx = 0; aidx < ANCHOR_COUNT; aidx++){uint ref_x = aidx * (5 + ClassCount);// Bounding box / confidencefloat x = Input[uint2(ref_x + 0, ref_y)];float y = Input[uint2(ref_x + 1, ref_y)];float w = Input[uint2(ref_x + 2, ref_y)];float h = Input[uint2(ref_x + 3, ref_y)];float c = Input[uint2(ref_x + 4, ref_y)];// ArgMax[SoftMax[classes]]uint maxClass = 0;float maxScore = exp(Input[uint2(ref_x + 5, ref_y)]);float scoreSum = maxScore;for (uint cidx = 1; cidx < ClassCount; cidx++){float score = exp(Input[uint2(ref_x + 5 + cidx, ref_y)]);if (score > maxScore){maxClass = cidx;maxScore = score;}scoreSum += score;}// Output structureDetection data;data.x = (id.x + Sigmoid(x)) / InputSize;data.y = (id.y + Sigmoid(y)) / InputSize;data.w = exp(w) * Anchors[aidx].x;data.h = exp(h) * Anchors[aidx].y;data.classIndex = maxClass;data.score = Sigmoid(c) * maxScore / scoreSum;// Thresholdingif (data.score > Threshold){// Detected: Count and outputuint count = OutputCount.IncrementCounter();if (count < MAX_DETECTION) Output[count] = data;}}
}
Postprocess2.compute:
#pragma kernel Postprocess2#include "Common.hlsl"// Input
StructuredBuffer<Detection> Input;
RWStructuredBuffer<uint> InputCount; // Only used as a counter
float Threshold;// Output
AppendStructuredBuffer<Detection> Output;// Local arrays for data cache
groupshared Detection _entries[MAX_DETECTION];
groupshared bool _flags[MAX_DETECTION];[numthreads(1, 1, 1)]
void Postprocess2(uint3 id : SV_DispatchThreadID)
{// Initialize data cache arraysuint entry_count = min(MAX_DETECTION, InputCount.IncrementCounter());if (entry_count == 0) return;for (uint i = 0; i < entry_count; i++){_entries[i] = Input[i];_flags[i] = true;}for (i = 0; i < entry_count - 1; i++){if (!_flags[i]) continue;for (uint j = i + 1; j < entry_count; j++){if (!_flags[j]) continue;if (CalculateIOU(_entries[i], _entries[j]) < Threshold)continue;if (_entries[i].score < _entries[j].score){_flags[i] = false;break;}else_flags[j] = false;}}for (i = 0; i < entry_count; i++)if (_flags[i]) Output.Append(_entries[i]);
}
Postprocess.compute:
#pragma kernel Preprocesssampler2D Image;
RWStructuredBuffer<float> Tensor;
uint Size;[numthreads(8, 8, 1)]
void Preprocess(uint2 id : SV_DispatchThreadID)
{// UV (vertically flipped)float2 uv = float2(0.5 + id.x, Size - 0.5 - id.y) / Size;// UV gradientsfloat2 duv_dx = float2(1.0 / Size, 0);float2 duv_dy = float2(0, -1.0 / Size);// Texture samplefloat3 rgb = tex2Dgrad(Image, uv, duv_dx, duv_dy).rgb;// Tensor element outputuint offs = (id.y * Size + id.x) * 3;Tensor[offs + 0] = rgb.r;Tensor[offs + 1] = rgb.g;Tensor[offs + 2] = rgb.b;
}
通过以上的处理,最后输出了一个目标检测的对象结果数组,主要包含如下数据:
public readonly struct Detection
{public readonly float x, y, w, h;public readonly uint classIndex;public readonly float score;
}
通过遍历这个数组,并将结果标记框和对象名称等信息显示出来:
public void SetAttributes(in Detection d)
{var rect = _parent.rect;var x = d.x * rect.width;var y = (1 - d.y) * rect.height;var w = d.w * rect.width;var h = d.h * rect.height;_xform.anchoredPosition = new Vector2(x, y);_xform.SetSizeWithCurrentAnchors(RectTransform.Axis.Horizontal, w);_xform.SetSizeWithCurrentAnchors(RectTransform.Axis.Vertical, h);var name = _labels[(int)d.classIndex];_label.text = $"{name} {(int)(d.score * 100)}%";var hue = d.classIndex * 0.073f % 1.0f;var color = Color.HSVToRGB(hue, 1, 1);_panel.color = color;transform.localScale = Vector3.one;
}
源码
https://download.csdn.net/download/qq_33789001/90242899
相关文章:

Unity3d 基于Barracuda推理库和YOLO算法实现对象检测功能
前言 近年来,随着AI技术的发展,在游戏引擎中实现和运行机器学习模型的需求也逐渐显现。Unity3d引擎官方推出深度学习推理框架–Barracuda ,旨在帮助开发者在Unity3d中轻松地实现和运行机器学习模型,它的主要功能是支持在 Unity 中…...
Lambda离线实时分治架构深度解析与实战
一、引言 在大数据技术日新月异的今天,Lambda架构作为一种经典的数据处理模型,在应对大规模数据应用方面展现出了强大的能力。它整合了离线批处理和实时流处理,为需要同时处理批量和实时数据的应用场景提供了成熟的解决方案。本文将对Lambda…...

Spring Boot教程之五十一:Spring Boot – CrudRepository 示例
Spring Boot – CrudRepository 示例 Spring Boot 建立在 Spring 之上,包含 Spring 的所有功能。由于其快速的生产就绪环境,使开发人员能够直接专注于逻辑,而不必费力配置和设置,因此如今它正成为开发人员的最爱。Spring Boot 是…...

jenkins入门6 --拉取代码
Jenkins代码拉取 需要的插件,缺少的安装下 新建一个item,选择freestyle project 源码管理配置如下:需要添加git库地址,和登录git的用户密码 配置好后执行编译,成功后拉取的代码在工作空间里...
CAPL概述与环境搭建
CAPL概述与环境搭建 目录 CAPL概述与环境搭建1. CAPL简介与应用领域1.1 CAPL简介1.2 CAPL的应用领域 2. CANoe/CANalyzer 安装与配置2.1 CANoe/CANalyzer 简介2.2 安装CANoe/CANalyzer2.2.1 系统要求2.2.2 安装步骤 2.3 配置CANoe/CANalyzer2.3.1 配置CAN通道2.3.2 配置CAPL节点…...

Virgo:增强慢思考推理能力的多模态大语言模型
每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…...

偃动访无穿戴动作捕捉系统:赋能多行业开启动作捕捉新篇章
在当今科技飞速发展的时代,动作捕捉技术正以前所未有的态势深入到社会发展的各个领域,成为众多行业不可或缺的重要助力。从早期的惯性动捕与光捕技术,到如今更为先进的无标记动捕技术,动作捕捉领域不断迎来革新与突破。 无标记动…...
mikro-orm 和typeorm 对比
以下是Mikro-ORM和TypeORM的详细对比: 设计理念与架构 Mikro-ORM:基于数据映射器、工作单元和身份映射模式。这种设计使得它在管理内存中实体状态方面表现优异,能够自动处理事务,当调用em.flush()时,所有计算出的更改…...

Docker入门之docker基本命令
Docker入门之docker基本命令 官方网站:https://www.docker.com/ 1. 拉取官方镜像并创建容器(以redis为例) 拉取官方镜像 docker pull redis# 如果不需要添加到自定义网络使用这个命令,如需要,直接看第二步 docker r…...
mysql的一些函数及其用法
mysql 1-来自于leetcode1517的题目 表: Users------------------------ | Column Name | Type | ------------------------ | user_id | int | | name | varchar | | mail | varchar | ------------------------已知一个表,它的…...

NO.3 《机器学习期末复习篇》以题(问答题)促习(人学习),满满干huo,大胆学大胆补!
目录 🔍 1. 对于非齐次线性模型 ,试将其表示为齐次线性模型形式。 编辑 🔍 2. 某汽车公司一年内各月份的广告投入与月销量数据如表3-28所示,试根据表中数据构造线性回归模型,并使用该模型预测月广告投入为20万元时…...

黑马跟学.苍穹外卖.Day03
黑马跟学.苍穹外卖.Day03 苍穹外卖-day03课程内容1. 公共字段自动填充1.1 问题分析1.2 实现思路1.3 代码开发1.3.1 步骤一1.3.2 步骤二1.3.3 步骤三 1.4 功能测试1.5 代码提交 2. 新增菜品2.1 需求分析与设计2.1.1 产品原型2.1.2 接口设计2.1.3 表设计 2.2 代码开发2.2.1 文件上…...
js -音频变音(听不出说话的人是谁)
学习参考来源: https://zhuanlan.zhihu.com/p/634848804 https://developer.mozilla.org/zh-CN/docs/Web/API/Web_Audio_API 实际效果: http://www.qingkong.zone/laboratory?typeaudio-confusion 前言 本文内容可结合上面学习参考来源,结合…...

鸿蒙UI(ArkUI-方舟UI框架)
参考:https://developer.huawei.com/consumer/cn/doc/harmonyos-guides-V13/arkts-layout-development-overview-V13 ArkUI简介 ArkUI(方舟UI框架)为应用的UI开发提供了完整的基础设施,包括简洁的UI语法、丰富的UI功能ÿ…...
常见的http状态码 + ResponseEntity
常见的http状态码 ResponseStatus(HttpStatus.CREATED) 是 Spring Framework 中的注解,用于指定 HTTP 响应状态码。 1. 基本说明 HttpStatus.CREATED 对应 HTTP 状态码 201表示请求成功且创建了新的资源通常用于 POST 请求的处理方法上 2. 使用场景和示例 基本…...

pikachu - Cross-Site Scripting(XSS)
pikachu - Cross-Site Scripting(XSS) 声明! 笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人无关,切勿触碰法律底线,否则后果自负&#x…...
操作系统之文件系统的基本概念
目录 用户和磁盘视角的文件 文件控制块(FCB)和索引结点(inode) 文件的操作 创建文件(create系统调用) 写文件(write系统调用) 读文件(read系统调用) 重…...

深入探讨 Android 中的 AlarmManager:定时任务调度及优化实践
引言 在 Android 开发中,AlarmManager 是一个非常重要的系统服务,用于设置定时任务或者周期性任务。无论是设置一个闹钟,还是定时进行数据同步,AlarmManager 都是不可或缺的工具之一。然而,随着 Android 系统的不断演…...

西电-算法分析-研究生课程复习笔记
24年秋的应该是张老师最后一次用卷面考试,他说以后这节课的期末考试都是在OJ上刷题了张老师上课还挺有意思的,上完之后能学会独立地思考算法设计问题了。整节课都在强调规模压缩这个概念,考试也是考个人对这些的理解,还挺好玩的哈…...

编译时找不到需要的库,如何在PyCharm中为你的项目添加需要的库
丰富的库支持是 Python 语言的一大特点,但是在使用 PyCharm 进行Python 代码编译的时候,遇到一些需要使用到的库提示不能解析时,该如何添加呢? 比如下图所示的代码,可以看到需要使用 selenium、b4、jieba 这些库&…...

K8s基础一
Kubernetes 架构 Kubernetes 背后的架构概念。 Kubernetes 集群由一个控制平面和一组用于运行容器化应用的工作机器组成, 这些工作机器称作节点(Node)。每个集群至少需要一个工作节点来运行 Pod。 工作节点托管着组成应用负载的 Pod。控制平…...

如何基于Mihomo Party http端口配置git与bash命令行代理
如何基于Mihomo Party http端口配置git与bash命令行代理 1. 确定Mihomo Party http端口配置 点击内核设置后即可查看 默认7892端口,开启允许局域网连接 2. 配置git代理 配置本机代理可以使用 127.0.0.1 配置局域网内其它机代理需要使用本机的非回环地址 IP&am…...
leetcode_206 反转链表
1. 题意 原地反转链表,非常经典的一道题。 2. 解决 2.1 非递归 非递归的比较好理解;链表需要维护前驱和后继两个信息,当我们要更改后继时,先要把原来的后继先存起来。 /*** Definition for singly-linked list.* struct List…...
12.6Swing控件4 JSplitPane JTabbedPane
JSplitPane JSplitPane 是 Java Swing 中用于创建分隔面板的组件,支持两个可调整大小组件的容器。它允许用户通过拖动分隔条来调整两个组件的相对大小,适合用于需要动态调整视图比例的场景。 常用方法: setLeftComponent(Component comp)&a…...
android 之 Tombstone
Android 系统中的 Tombstone 是记录 Native 层崩溃信息的关键日志文件,当应用或系统服务因严重错误(如内存访问异常、空指针解引用等)崩溃时自动生成。以下是其核心机制与分析方法详解: 一、Tombstone 的生成机制 触发条件 当 Na…...

平安养老险蚌埠中心支公司开展金融宣教活动
近日,平安养老保险股份有限公司(以下简称“平安养老险”)蚌埠中心支公司,走进某合作企业开展金融教育宣传活动。 活动现场,平安养老险蚌埠中心支公司工作人员通过发放宣传手册和小礼品等方式,向企业员工普…...

数据库系统概论(十六)数据库安全性(安全标准,控制,视图机制,审计与数据加密)
数据库系统概论(十六)数据库安全性 前言一、数据库安全性1. 什么是数据库安全性?2. 为何会存在安全问题? 二、安全标准的发展1. 早期的“开拓者”:TCSEC标准2. 走向国际统一:CC标准3. TCSEC和CC标准有什么不…...

51单片机基础部分——独立按键检测
前言 在单片机开发中,我们会经常对单片机的状态进行控制,比如我们会控制某个灯点亮,某个灯熄灭,这个时候我们就要开始做控制,我们可以通过什么控制呢,这个地方我们选择按键控制 按键实物及工作原理 生活…...
Github 2025-06-02 开源项目周报 Top11
根据Github Trendings的统计,本周(2025-06-02统计)共有11个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目6Jupyter Notebook项目2Shell项目1Dockerfile项目1TypeScript项目1Vue项目1PowerShell项目1MindsDB:定制企业数据人工智能的开源平台…...
JavaSec-SSTI - 模板引擎注入
简介 SSTI(Server Side Template Injection):模板引擎是一种通过将模板中的占位符替换为实际数据来动态生成内容的工具,如HTML页面、邮件等。它简化了视图层的设计,但如果未对用户输入进行有效校验,可能导致安全风险如任意代码执行…...