当前位置: 首页 > news >正文

深度学习:yolov3的使用--建立模型

使用argparse模块来定义和解析命令行参数

创建一个ArgumentParser对象

parser = argparse.ArgumentParser()

训练的轮数,每批图像的大小,更新模型参数之前累积梯度的次数,模型定义文件的路径。

parser.add_argument("--epochs", type=int, default=100, help="number of epochs") #训练次数parser.add_argument("--batch_size", type=int, default=1, help="size of each image batch")   #batch的大小parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")#在每一步(更新模型参数)之前累积梯度的次数”parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file") #模型的配置文件

数据配置文件的路径,从预训练的模型权重开始训练,生成批次数据时使用的CPU线程数。

parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file") #数据的配置文件parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model") #预训练文件parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")#数据加载过程中应使用的CPU线程数。

每张图像的尺寸,每隔多少个epoch保存一次模型权重,每隔多少个epoch在验证集上进行一次评估,每十批计算一次平均精度(mAP),是否允许多尺度训练,

parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")parser.add_argument("--checkpoint_interval", type=int, default=20, help="interval between saving model weights")#隔多少个epoch保存一次模型权重parser.add_argument("--evaluation_interval", type=int, default=20, help="interval evaluations on validation set")#多少个epoch进行一次验证集的验证parser.add_argument("--compute_map", default=False, help="if True computes mAP every tenth batch")#parser.add_argument("--multiscale_training", default=True, help="allow for multi-scale training")

使用parse_args方法解析命令行参数

opt = parser.parse_args()

使用TensorFlow 2.0以上版本中的tf.summary模块创建日志记录器(Logger)

import tensorflow as tf  # 导入TensorFlow库,并简写为tf
# 确保使用的是TensorFlow 2.0或更高版本class Logger(object):def __init__(self, log_dir):"""Create a summary writer logging to log_dir.这个类的构造函数接受一个参数log_dir,它表示日志文件将要保存的目录。函数的作用是创建一个日志记录器,用于记录TensorFlow的摘要信息(例如训练过程中的损失、准确率等)。"""self.writer = tf.summary.create_file_writer(log_dir)  # 创建一个文件写入器,用于将摘要信息写入到指定的日志目录

调用Logger,创建目录

logger = Logger("logs")
# 创建Logger类的实例,并将日志目录设置为"logs"。这意味着所有的日志信息将被写入到当前工作目录下的"logs"文件夹中。device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 这行代码使用PyTorch的torch.device来确定运行设备。如果系统有可用的CUDA(即NVIDIA的GPU),则device将被设置为"cuda",否则将使用CPU。os.makedirs("output", exist_ok=True)
# 使用os模块的makedirs函数创建一个名为"output"的目录。exist_ok=True参数意味着如果"output"目录已经存在,不会抛出错误。os.makedirs("checkpoints", exist_ok=True)
# 类似地,这行代码创建一个名为"checkpoints"的目录。这个目录通常用于存储模型的检查点或保存的状态,以便后续可以恢复训练或进行模型评估。

定义 parse_data_config 的函数,它用于解析数据配置文件(文件内容分类种类,训练集路径,测试集路径,文件名称路径)

def parse_data_config(path):"""Parses the data configuration file"""options = dict()options['gpus'] = '0,1,2,3'options['num_workers'] = '10'with open(path, 'r') as fp:lines = fp.readlines()for line in lines:line = line.strip()if line == '' or line.startswith('#'):#startswith()用于检查字符串是否以特定的子字符串开始。如果是,它将返回True,否则返回False。continuekey, value = line.split('=')options[key.strip()] = value.strip()return options

调用函数

    data_config = parse_data_config(opt.data_config)train_path = data_config["train"]valid_path = data_config["valid"]

定义了一个名为 load_classes 的函数,它用于从指定路径加载类别标签

def load_classes(path):"""Loads class labels at 'path'"""fp = open(path, "r")names = fp.read().split("\n")[:-1]return names

调用函数

class_names = load_classes(data_config["names"])

定义了一个名为 parse_model_config 的函数,它用于解析 YOLOv3 模型的配置文件

读取文件划分如卷积、池化、上采样、路由、快捷连接和 YOLO 层

def parse_model_config(path):"""Parses the yolo-v3 layer configuration file and returns module definitions"""file = open(path, 'r')lines = file.read().split('\n')lines = [x for x in lines if x and not x.startswith('#')]   #x.startswith('#')用于检查字符串变量x是否以#前缀开始。如果x以该前缀开头,该方法将返回一个布尔值,通常是True,否则返回False。lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespacesmodule_defs = []for line in lines:if line.startswith('['): # This marks the start of a new blockmodule_defs.append({})module_defs[-1]['type'] = line[1:-1].rstrip()if module_defs[-1]['type'] == 'convolutional':module_defs[-1]['batch_normalize'] = 0else:key, value = line.split("=")value = value.strip()module_defs[-1][key.rstrip()] = value.strip()return module_defs

这个函数的目的是将 YOLOv3 模型配置文件中的文本描述转换成 PyTorch 可以理解的网络层模块。它首先处理超参数,然后逐个处理每个模块定义,根据模块的类型(如卷积、池化、上采样、路由、快捷连接和 YOLO 层)创建相应的 PyTorch 层,并添加到 module_list 中。

def create_modules(module_defs):"""Constructs module list of layer blocks from module configuration in module_defs"""# 从模块定义列表中弹出第一个元素,它包含了超参数(hyperparameters),例如输入图像的尺寸等。hyperparams = module_defs.pop(0)# 初始化输出过滤器列表,它将存储每一层的输出通道数(即卷积核的数量)。# 这里假设第一个超参数中的 'channels' 键对应的值是网络输入层的通道数。output_filters = [int(hyperparams["channels"])]# 创建一个 PyTorch 的 ModuleList 对象,用于存储网络层模块。module_list = nn.ModuleList()# 遍历模块定义列表,module_i 是索引,module_def 是当前模块的定义。for module_i, module_def in enumerate(module_defs):# 对于每个模块,创建一个 PyTorch 的 Sequential 对象,用于线性堆叠网络层。modules = nn.Sequential()# 如果模块类型是 "convolutional":if module_def["type"] == "convolutional":# 获取当前模块是否使用批归一化(batch normalization)。bn = int(module_def["batch_normalize"])# 获取卷积核的数量(即输出通道数)。filters = int(module_def["filters"])# 获取卷积核的大小。kernel_size = int(module_def["size"])# 计算填充值,以保持输出尺寸与输入尺寸相同。pad = (kernel_size - 1) // 2# 添加一个卷积层到 Sequential 对象中。modules.add_module(f"conv_{module_i}",nn.Conv2d(in_channels=output_filters[-1],  # 输入特征图的数量。out_channels=filters,  # 输出特征图的数量。kernel_size=kernel_size,  # 卷积核的大小。stride=int(module_def["stride"]),  # 卷积核滑动的步长。padding=pad,  # 填充值。bias=not bn,  # 是否添加偏置项。),)# 如果使用批归一化,则添加一个批归一化层。if bn:modules.add_module(f"batch_norm_{module_i}", nn.BatchNorm2d(filters, momentum=0.9))# 如果激活函数是 "leaky",则添加一个 LeakyReLU 激活层。if module_def["activation"] == "leaky":modules.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))# 如果模块类型是 "maxpool":elif module_def["type"] == "maxpool":# 获取池化层的大小和步长。kernel_size = int(module_def["size"])stride = int(module_def["stride"])# 如果池化核大小为2且步长为1,添加一个填充层以保持尺寸。if kernel_size == 2 and stride == 1:modules.add_module(f"_debug_padding_{module_i}", nn.ZeroPad2d((0, 1, 0, 1)))# 添加一个最大池化层。maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2))modules.add_module(f"maxpool_{module_i}", maxpool)# 如果模块类型是 "upsample":elif module_def["type"] == "upsample":# 添加一个上采样层。upsample = Upsample(scale_factor=int(module_def["stride"]), mode="nearest")modules.add_module(f"upsample_{module_i}", upsample)# 如果模块类型是 "route":elif module_def["type"] == "route":# 获取路由层的层索引。layers = [int(x) for x in module_def["layers"].split(",")]# 计算路由层的输出通道数。filters = sum([output_filters[1:][i] for i in layers])# 添加一个空层作为路由层。modules.add_module(f"route_{module_i}", EmptyLayer())# 如果模块类型是 "shortcut":elif module_def["type"] == "shortcut":# 获取快捷连接的输入通道数。filters = output_filters[1:][int(module_def["from"])]# 添加一个空层作为快捷连接层。modules.add_module(f"shortcut_{module_i}", EmptyLayer())# 如果模块类型是 "yolo":elif module_def["type"] == "yolo":# 获取 YOLO 层的锚点索引和锚点值。anchor_idxs = [int(x) for x in module_def["mask"].split(",")]anchors = [int(x) for x in module_def["anchors"].split(",")]anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]anchors = [anchors[i] for i in anchor_idxs]# 获取 YOLO 层的类别数和图像尺寸。num_classes = int(module_def["classes"])img_size = int(hyperparams["height"])# 创建一个 YOLO 层并添加到模块中。yolo_layer = YOLOLayer(anchors, num_classes, img_size)modules.add_module(f"yolo_{module_i}", yolo_layer)# 将构建好的模块添加到模块列表中。module_list.append(modules)# 更新输出过滤器列表,添加当前模块的输出通道数。output_filters.append(filters)# 返回超参数和构建好的模块列表。return hyperparams, module_list

定义了一个名为 Darknet 的类,它是用于构建 YOLOv3 目标检测模型的 PyTorch 神经网络类。

class Darknet(nn.Module):"""YOLOv3 object detection model"""# 这个类继承自 PyTorch 的 nn.Module 类,表示它是一个神经网络模型。# YOLOv3 是一个流行的目标检测算法,这个类实现了 YOLOv3 的网络结构。def __init__(self, config_path, img_size=416):super(Darknet, self).__init__()# 类的构造函数接受两个参数:config_path(模型配置文件的路径)和 img_size(输入图像的尺寸,默认为416)。# super() 函数用于调用父类的构造函数,即初始化 PyTorch 的 nn.Module。self.module_defs = parse_model_config(config_path)# 调用 parse_model_config 函数解析配置文件,并存储解析后的模块定义。self.hyperparams, self.module_list = create_modules(self.module_defs)# 调用 create_modules 函数根据模块定义创建网络层,并存储超参数和模块列表。self.yolo_layers = [layer[0] for layer in self.module_list if hasattr(layer[0], "metrics")]# 从模块列表中提取出包含 'metrics' 属性的层,这些层通常是 YOLO 层,用于目标检测。# hasattr() 函数检查对象是否具有给定的属性。self.img_size = img_size# 存储输入图像的尺寸。self.seen = 0# 用于跟踪训练过程中看到的数据量(例如,图像数量)。self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)# 创建一个 NumPy 数组,用于存储与模型相关的头部信息,如版本、修订号、seen、未知字段和 epoch 数。

调用函数

    model = Darknet(opt.model_def).to(device)model.apply(weights_init_normal)#model.apply(fn)表示将fn函数应用到神经网络的各个模块上,包括该神经网络本身。这通常在初始化神经网络的参数时使用,本处用于初始化神经网络的权值

相关文章:

深度学习:yolov3的使用--建立模型

使用argparse模块来定义和解析命令行参数 创建一个ArgumentParser对象 parser argparse.ArgumentParser() 训练的轮数,每批图像的大小,更新模型参数之前累积梯度的次数,模型定义文件的路径。 parser.add_argument("--epochs", typeint, d…...

关于我、重生到500年前凭借C语言改变世界科技vlog.13——深入理解指针(3)

文章目录 1.字符指针变量2.数组指针变量3.函数指针变量4.函数指针数组5.二维数组传参本质6.拓展补充希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力! 本章节接着学习常见的指针变量类型 1.字符指针变量 字符指针变量,顾名思义就是字…...

每日算法一练:剑指offer——数组篇(6)

1.点名 某班级 n 位同学的学号为 0 ~ n-1。点名结果记录于升序数组 records。假定仅有一位同学缺席&#xff0c;请返回他的学号。 示例 1: 输入: records [0,1,2,3,5] 输出: 4示例 2: 输入: records [0, 1, 2, 3, 4, 5, 6, 8] 输出: 7提示&#xff1a; 1 < records.le…...

【环境搭建】Apache ZooKeeper 3.8.4 Stable

软件环境 Ubuntu 20.04 、OpenJDK 11 OpenJDK 11&#xff08;如果已经安装&#xff0c;可以跳过这一步&#xff09; 安装OpenJDK 11&#xff1a; $ sudo apt-get update$ sudo apt-get install -y openjdk-11-jdk 设置 JAVA_HOME 环境变量&#xff1a; $ sudo gedit ~/.bash…...

算法练习——双指针

前言&#xff1a;大佬写博客给别人看&#xff0c;菜鸟写博客给自己看&#xff0c;我是菜鸟。 学前须知&#xff08;对自己&#xff09;&#xff1a;这里的指针不一定指地址&#xff01;也可能是数组下标。 1&#xff1a;移动零(双指针) 题目要求&#xff1a; 解题思路&#x…...

vue中el-table显示文本过长提示

1.el-table设置轻提示:show-overflow-tooltip“true“&#xff0c;改变轻提示宽度...

JS 字符串拼接并去重

1、includes 循环数组将某个字段拼接成新的字符串并去重&#xff08;数组里面包含的一个对象&#xff0c;或者其他都OK&#xff09; // 定义一个数组 let arr[.......] // 定义拼接的字符串 let a //循环数组将里面某个字段拼接在一起并去重 arr.forEach(item > {if(!a.in…...

opencv 图像预处理

图像预处理 ​ 在计算机视觉和图像处理领域&#xff0c;图像预处理是一个重要的步骤&#xff0c;它能够提高后续处理&#xff08;如特征提取、目标检测等&#xff09;的准确性和效率。OpenCV 提供了许多图像预处理的函数和方法&#xff0c;以下是一些常见的图像预处理操作&…...

SAP B1 功能模块字段介绍 - 价格清单(下)

目录 背景 五、业务伙伴的特殊价格 1. 单据逻辑功能 2. 部分字段解释 3. 操作流程 3.1 时间相关 3.2 数量相关 4. 实例 六、复制特殊价格到选择标准 1. 单据逻辑功能 2. 部分字段解释 七、全局更新特殊价格 ​编辑 1. 单据逻辑功能 2. 部分字段解释 八、价格更…...

传智杯 第六届-复赛-D

题目描述&#xff1a; 小红定义两个字符串同构&#xff0c;当且仅当对于i∈[1,n],b[i]−a[i]i∈[1,n],b[i]-a[i]i∈[1,n],b[i]−a[i]是定值。例如&#xff0c;"bacd"和"edfg"是同构的。 现在小红拿到了一个长度为n的字符串a&#xff0c;她想知道&a…...

Java - 数组实现大顶堆

题目描述 实现思路 要实现一个堆&#xff0c;我们首先要了解堆的概念。 堆是一种完全二叉树&#xff0c;分为大顶堆和小顶堆。 大顶堆&#xff1a;每个节点的值都大于或等于其子节点的值。 小顶堆&#xff1a;每个节点的值都小于或等于其子节点的值。 完全二叉树&#xff…...

ifuse挂载后,在python代码中访问iOS沙盒目录获取app日志

上一次使用pymobiledevice3&#xff0c;在python代码中访问app的沙盒目录并分析业务日志&#xff0c;在使用过程中发现&#xff0c;在获取app日志的时候速度很慢&#xff0c;执行时间很长&#xff0c;需要30-61秒&#xff0c;所以这次尝试使用libimobiledevic和ifuse&#xff0…...

Windows WSL环境下安装 pytorch +ROCM 支持AMD显卡

官方文档&#xff1a;Install PyTorch for ROCm — Use ROCm on Radeon GPUs 一、操作系统及驱动 windows 下安装WSL 环境( windows subsystem for Linux), 安装ubuntu 22.04环境。 安装 rocm 软件包&#xff1a; sudo apt update wget https://repo.radeon.com/amdgpu-insta…...

uniapp中skymap.html(8100端口)提示未登录的排查与解决方法

问题&#xff1a; 目前账号已经登录&#xff0c;uniapp的其他端口均可以访问到数据&#xff0c;唯独skymap.html中的8100会提示未登录。&#xff08;8100是后端网关gateway端口&#xff09; 分析&#xff1a; 在 skymap.html 中遇到未登录提示的问题&#xff0c;通常是由于该…...

训练模型时梯度出现NAN或者INF(禁用amp的不同level)

判断参数梯度位nan或inf的代码&#xff1a; for name, param in model.named_parameters():if param.grad is not None:if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():print(f"grad layer [{name}] is NaN or Inf") 首先来说可能得原因&…...

Maven核心概念

一、项目对象模型&#xff08;POM&#xff09; 1. 定义 POM&#xff08;Project Object Model&#xff09;是 Maven 项目的核心配置文件&#xff0c;它以 XML 格式描述了项目的基本信息、项目依赖、构建配置等。可以说&#xff0c;POM 是 Maven 理解和处理项目的基础。 2. 基…...

Sonatype Nexus 部署手册

文章目录 一、前言二、软件环境2.1 版本变更&#xff1a;2.1.1 变更存储的原因2.2.2 H2作为存储的注意点 三、资源配置四、开始部署4.1 部署jdk174.2 离线部署nexus4.2.1 下载4.2.2 部署1. 上传到服务器2. 解压3. 添加用户4. 修改启动参数5. 迁移sonatype-work &#xff0c;并授…...

TLV320AIC3104IRHBR 数据手册 一款低功耗立体声音频编解码器 立体声耳机放大器芯片麦克风

TLV320AIC3104 是一款低功耗立体声音频编解码器&#xff0c;具有立体声耳机放大器以及在单端或全差分配置下可编程的多个输入和输出。该器件包括基于寄存器的全面电源控制&#xff0c;可实现立体声 48kHz DAC 回放&#xff0c;在 3.3V 模拟电源电压下的功耗低至 14mW&#xff0…...

(8)结构体、共用体和枚举类型数据

1. 结构体、共用体的定义及区别,typedef 定义别名 结构体的定义 结构体是一种用户自定义的数据类型,它可以将不同类型的数据组合在一起。例如,定义一个表示学生信息的结构体: // 定义结构体类型 struct Student struct Student {char name[20];int age;float score; };共…...

Jedis操作和springboot整合redis

Jedis-springboot整合redis Jedis 引入jedis依赖 注意事项 测试相关数据类型 Key String List set hash zset 案例 spring boot整合redis 引入相关依赖 在application.properties中配置redis 配置 创建redis配置类 创建测试类 Jedis 引入jedis依赖 <depen…...

模型参数、模型存储精度、参数与显存

模型参数量衡量单位 M&#xff1a;百万&#xff08;Million&#xff09; B&#xff1a;十亿&#xff08;Billion&#xff09; 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的&#xff0c;但是一个参数所表示多少字节不一定&#xff0c;需要看这个参数以什么…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

Mac软件卸载指南,简单易懂!

刚和Adobe分手&#xff0c;它却总在Library里给你写"回忆录"&#xff1f;卸载的Final Cut Pro像电子幽灵般阴魂不散&#xff1f;总是会有残留文件&#xff0c;别慌&#xff01;这份Mac软件卸载指南&#xff0c;将用最硬核的方式教你"数字分手术"&#xff0…...

2025季度云服务器排行榜

在全球云服务器市场&#xff0c;各厂商的排名和地位并非一成不变&#xff0c;而是由其独特的优势、战略布局和市场适应性共同决定的。以下是根据2025年市场趋势&#xff0c;对主要云服务器厂商在排行榜中占据重要位置的原因和优势进行深度分析&#xff1a; 一、全球“三巨头”…...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生&#xff0c;小白用户&#xff0c;想学习知识的 有点基础&#xff0c;想要通过项…...

LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf

FTP 客服管理系统 实现kefu123登录&#xff0c;不允许匿名访问&#xff0c;kefu只能访问/data/kefu目录&#xff0c;不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...

搭建DNS域名解析服务器(正向解析资源文件)

正向解析资源文件 1&#xff09;准备工作 服务端及客户端都关闭安全软件 [rootlocalhost ~]# systemctl stop firewalld [rootlocalhost ~]# setenforce 0 2&#xff09;服务端安装软件&#xff1a;bind 1.配置yum源 [rootlocalhost ~]# cat /etc/yum.repos.d/base.repo [Base…...

手机平板能效生态设计指令EU 2023/1670标准解读

手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读&#xff0c;综合法规核心要求、最新修正及企业合规要点&#xff1a; 一、法规背景与目标 生效与强制时间 发布于2023年8月31日&#xff08;OJ公报&…...

加密通信 + 行为分析:运营商行业安全防御体系重构

在数字经济蓬勃发展的时代&#xff0c;运营商作为信息通信网络的核心枢纽&#xff0c;承载着海量用户数据与关键业务传输&#xff0c;其安全防御体系的可靠性直接关乎国家安全、社会稳定与企业发展。随着网络攻击手段的不断升级&#xff0c;传统安全防护体系逐渐暴露出局限性&a…...

Qt的学习(一)

1.什么是Qt Qt特指用来进行桌面应用开发&#xff08;电脑上写的程序&#xff09;涉及到的一套技术Qt无法开发网页前端&#xff0c;也不能开发移动应用。 客户端开发的重要任务&#xff1a;编写和用户交互的界面。一般来说和用户交互的界面&#xff0c;有两种典型风格&…...