实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客
这里是根据网络结构,搭建模型,用于图像分类任务。
1. 网络结构和基本组件

2. 搭建组件
(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;
(2)深度可分离卷积:DwCBL = Conv dw+ Conv dp;
Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 } + {Conv2d(1x1) + BN + ReLU6};
Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;
Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;
# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x
3. 搭建网络
class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2), # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2), # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2), # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2), # 下采样/16DwCBN(512, 512, 1), # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2), # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax() # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x) # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4) # /32. 7x7x = self.avg_pooling(scale5) # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1) # (b,1024,1,1)->(b,1024,)x = self.fc(x) # (b,1024,) -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)
4. 训练验证
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optimfrom mobilenetv1 import MobileNetV1def validate(model, val_loader, criterion, device):model.eval() # Set the model to evaluation modetotal_correct = 0total_samples = 0with torch.no_grad():for val_inputs, val_labels in val_loader:val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)val_outputs = model(val_inputs)_, predicted = torch.max(val_outputs, 1)total_samples += val_labels.size(0)total_correct += (predicted == val_labels).sum().item()accuracy = total_correct / total_samplesmodel.train() # Set the model back to training modereturn accuracyif __name__ == '__main__':# 下载并准备数据集# Define image transformations (adjust as needed)transform = transforms.Compose([transforms.Resize((224, 224)), # Resize images to a consistent sizetransforms.ToTensor(), # converts to PIL Image to a Pytorch Tensor and scales values to the range [0, 1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Adjust normalization values. val = (val - mean) / std.])# Create ImageFolder datasetdata_folder = r"D:\zxq\data\car_or_dog"dataset = torchvision.datasets.ImageFolder(root=data_folder, transform=transform)# Optionally, split the dataset into training and validation sets# Adjust the `split_ratio` as neededsplit_ratio = 0.8train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# Create DataLoader for training and validationtrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)# 初始化模型、损失函数和优化器net = MobileNetV1(class_num=2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 训练模型device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)net.to(device)for epoch in range(20): # 例如,训练 20 个周期for i, data in enumerate(train_loader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPUoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print("epoch/step: {}/{}: loss: {}".format(epoch, i, loss.item()))# Validation after each epochval_accuracy = validate(net, val_loader, criterion, device)print("Epoch {} - Validation Accuracy: {:.2%}".format(epoch, val_accuracy))print('Finished Training')
待续。。。
相关文章:
实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客 这里是根据网络结构,搭建模型,用于图像分类任务。 1. 网络结构和基本组件 2. 搭建组件 (1)普通的卷积组件:CBL …...
vue多tab页面全部关闭后自动退出登录
业务场景:主项目是用vue写的单页面应用,但是有多开页面的需求,现在需要在用户关闭了所有的浏览器标签页面后,自动退出登录。 思路:因为是不同的tab页面,我只能用localStorage来通信,新打开一个…...
记一个集群环境部署不完整导致的BUG
一 背景 产品有三个环境:开发测试环境、验收环境、生产环境。 开发测试环境,保持最新的更新; 验收环境,阶段待发布内容; 生产环境,部署稳定内容。 产品为BS架构,后端采用微服务…...
Go zero copy,复制文件
这里使用零拷贝技术复制文件,从内核态操作源文件和目标文件。避免了在用户态开辟缓冲区,然后从内核态复制文件到用户态的问题。 由内核态完成文件复制操作。 调用的是syscall.Sendfile系统调用函数。 //go:build linuxpackage zero_copyimport ("f…...
http协议九种请求方法介绍及常见状态码
http1.0定义了三种: GET: 向服务器获取资源,比如常见的查询请求POST: 向服务器提交数据而发送的请求Head: 和get类似,返回的响应中没有具体的内容,用于获取报头 http1.1定义了六种 PUT:一般是用于更新请求,…...
详解flink exactly-once和两阶段提交
以下是我们常见的三种 flink 处理语义: 最多一次(At-most-Once):用户的数据只会被处理一次,不管成功还是失败,不会重试也不会重发。 至少一次(At-least-Once):系统会保…...
Qt/QML编程学习之心得:QDbus实现service接口调用(28)
D-Bus协议用于进程间通讯的。 QString value = retrieveValue();QDBusPendingCall pcall = interface->asyncCall(QLatin1String("Process"), value);QDBusPendingCallWatcher *watcher = new QDBusPendingCallWatcher(pcall, this);QObject::connect(watcher, SI…...
前端nginx配置指南
前端项目发布后,有些接口需要在服务器配置反向代理,资源配置gzip压缩,配置跨域允许访问等 配置文件模块概览 配置示例 反向代理 反向代理是Nginx的核心功能之一,是指客户端发送请求到代理服务器,代理服务器再将请求…...
接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚
01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…...
显示管理磁盘分区 fdisk
显示管理磁盘分区 fdisk fdisk是用于检查一个磁盘上分区信息最通用的命令。 fdisk可以显示分区信息及一些细节信息,比如文件系统类型等。 设备的名称通常是/dev/sda、/dev/sdb 等。 对于以前的设备有可能还存在设备名为 /dev/hd* (IDE)的设备,这个设…...
Hyperledger Fabric 管理链码 peer lifecycle chaincode 指令使用
链上代码(Chaincode)简称链码,包括系统链码和用户链码。系统链码(System Chaincode)指的是 Fabric Peer 中负责系统配置、查询、背书、验证等平台功能的代码逻辑,运行在 Peer 进程内,将在第 14 …...
L1-011 A-B(Java)
题目 本题要求你计算A−B。不过麻烦的是,A和B都是字符串 —— 即从字符串A中把字符串B所包含的字符全删掉,剩下的字符组成的就是字符串A−B。 输入格式: 输入在2行中先后给出字符串A和B。两字符串的长度都不超过10的四次方,并且…...
系列七、Ribbon
一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如:…...
山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展
山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…...
BGP公认必遵属性——Next-hop(一)
BGP公认必遵属性共有三个,分别是:Next-hop、Origin、As-path,本期介绍Next-hop 点赞关注,持续更新!!! Next-hop 华为BGP路由下一跳特点: 默认情况下传给EBGP邻居的BGP路由的下一跳…...
增强Wi-Fi信号的10种方法,值得去尝试
Wi-Fi信号丢失,无线盲区。在一个对一些人来说,上网和呼吸一样必要的世界里,这些问题中的每一个都令人抓狂。 如果你觉得你的Wi-Fi变得迟钝,有很多工具可以用来测试你的互联网速度。你还可以尝试一些技巧来解决网络问题。然而,如果你能获得良好接收的唯一方法是站在无线路…...
第十五章 ECMAScript6新增的常用语法
文章目录 一、声明关键字二、箭头函数三、解构赋值四、展开运算符五、对字符的补充六、Symbol七、对象的简写语法八、Set和Map九、for-of 一、声明关键字 ES6新增的声明关键字: let,const:声明变量class:声明类import,…...
vulhub中的Apache SSI 远程命令执行漏洞
Apache SSI 远程命令执行漏洞 1.cd到ssi-rce cd /opt/vulhub/httpd/ssi-rce/ 2.执行docker-compose up -d docker-compose up -d 3.查看靶场是否开启成功 dooker ps 拉取成功了 4.访问url 这里已经执行成功了,注意这里需要加入/upload.php 5.写入一句话木马 &…...
MSB20M-ASEMI迷你贴片整流桥MSB20M
编辑:ll MSB20M-ASEMI迷你贴片整流桥MSB20M 型号:MSB20M 品牌:ASEMI 封装:UMSB-4 特性:贴片、整流桥 最大平均正向电流:2A 最大重复峰值反向电压:1000V 恢复时间:࿱…...
工程管理系统功能设计与实践:实现高效、透明的工程管理
在现代化的工程项目管理中,一套功能全面、操作便捷的系统至关重要。本文将介绍一个基于Spring Cloud和Spring Boot技术的Java版工程项目管理系统,结合Vue和ElementUI实现前后端分离。该系统涵盖了项目管理、合同管理、预警管理、竣工管理、质量管理等多个…...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
【Oracle APEX开发小技巧12】
有如下需求: 有一个问题反馈页面,要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据,方便管理员及时处理反馈。 我的方法:直接将逻辑写在SQL中,这样可以直接在页面展示 完整代码: SELECTSF.FE…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?
论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...
Mysql中select查询语句的执行过程
目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析(Parser) 2.4、执行sql 1. 预处理(Preprocessor) 2. 查询优化器(Optimizer) 3. 执行器…...
【JVM面试篇】高频八股汇总——类加载和类加载器
目录 1. 讲一下类加载过程? 2. Java创建对象的过程? 3. 对象的生命周期? 4. 类加载器有哪些? 5. 双亲委派模型的作用(好处)? 6. 讲一下类的加载和双亲委派原则? 7. 双亲委派模…...
