实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客
这里是根据网络结构,搭建模型,用于图像分类任务。
1. 网络结构和基本组件
2. 搭建组件
(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;
(2)深度可分离卷积:DwCBL = Conv dw+ Conv dp;
Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 } + {Conv2d(1x1) + BN + ReLU6};
Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;
Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;
# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x
3. 搭建网络
class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2), # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2), # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2), # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2), # 下采样/16DwCBN(512, 512, 1), # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2), # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax() # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x) # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4) # /32. 7x7x = self.avg_pooling(scale5) # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1) # (b,1024,1,1)->(b,1024,)x = self.fc(x) # (b,1024,) -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)
4. 训练验证
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optimfrom mobilenetv1 import MobileNetV1def validate(model, val_loader, criterion, device):model.eval() # Set the model to evaluation modetotal_correct = 0total_samples = 0with torch.no_grad():for val_inputs, val_labels in val_loader:val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)val_outputs = model(val_inputs)_, predicted = torch.max(val_outputs, 1)total_samples += val_labels.size(0)total_correct += (predicted == val_labels).sum().item()accuracy = total_correct / total_samplesmodel.train() # Set the model back to training modereturn accuracyif __name__ == '__main__':# 下载并准备数据集# Define image transformations (adjust as needed)transform = transforms.Compose([transforms.Resize((224, 224)), # Resize images to a consistent sizetransforms.ToTensor(), # converts to PIL Image to a Pytorch Tensor and scales values to the range [0, 1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Adjust normalization values. val = (val - mean) / std.])# Create ImageFolder datasetdata_folder = r"D:\zxq\data\car_or_dog"dataset = torchvision.datasets.ImageFolder(root=data_folder, transform=transform)# Optionally, split the dataset into training and validation sets# Adjust the `split_ratio` as neededsplit_ratio = 0.8train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# Create DataLoader for training and validationtrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)# 初始化模型、损失函数和优化器net = MobileNetV1(class_num=2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 训练模型device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)net.to(device)for epoch in range(20): # 例如,训练 20 个周期for i, data in enumerate(train_loader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPUoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print("epoch/step: {}/{}: loss: {}".format(epoch, i, loss.item()))# Validation after each epochval_accuracy = validate(net, val_loader, criterion, device)print("Epoch {} - Validation Accuracy: {:.2%}".format(epoch, val_accuracy))print('Finished Training')
待续。。。
相关文章:

实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客 这里是根据网络结构,搭建模型,用于图像分类任务。 1. 网络结构和基本组件 2. 搭建组件 (1)普通的卷积组件:CBL …...
vue多tab页面全部关闭后自动退出登录
业务场景:主项目是用vue写的单页面应用,但是有多开页面的需求,现在需要在用户关闭了所有的浏览器标签页面后,自动退出登录。 思路:因为是不同的tab页面,我只能用localStorage来通信,新打开一个…...
记一个集群环境部署不完整导致的BUG
一 背景 产品有三个环境:开发测试环境、验收环境、生产环境。 开发测试环境,保持最新的更新; 验收环境,阶段待发布内容; 生产环境,部署稳定内容。 产品为BS架构,后端采用微服务…...
Go zero copy,复制文件
这里使用零拷贝技术复制文件,从内核态操作源文件和目标文件。避免了在用户态开辟缓冲区,然后从内核态复制文件到用户态的问题。 由内核态完成文件复制操作。 调用的是syscall.Sendfile系统调用函数。 //go:build linuxpackage zero_copyimport ("f…...
http协议九种请求方法介绍及常见状态码
http1.0定义了三种: GET: 向服务器获取资源,比如常见的查询请求POST: 向服务器提交数据而发送的请求Head: 和get类似,返回的响应中没有具体的内容,用于获取报头 http1.1定义了六种 PUT:一般是用于更新请求,…...

详解flink exactly-once和两阶段提交
以下是我们常见的三种 flink 处理语义: 最多一次(At-most-Once):用户的数据只会被处理一次,不管成功还是失败,不会重试也不会重发。 至少一次(At-least-Once):系统会保…...
Qt/QML编程学习之心得:QDbus实现service接口调用(28)
D-Bus协议用于进程间通讯的。 QString value = retrieveValue();QDBusPendingCall pcall = interface->asyncCall(QLatin1String("Process"), value);QDBusPendingCallWatcher *watcher = new QDBusPendingCallWatcher(pcall, this);QObject::connect(watcher, SI…...

前端nginx配置指南
前端项目发布后,有些接口需要在服务器配置反向代理,资源配置gzip压缩,配置跨域允许访问等 配置文件模块概览 配置示例 反向代理 反向代理是Nginx的核心功能之一,是指客户端发送请求到代理服务器,代理服务器再将请求…...

接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚
01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…...

显示管理磁盘分区 fdisk
显示管理磁盘分区 fdisk fdisk是用于检查一个磁盘上分区信息最通用的命令。 fdisk可以显示分区信息及一些细节信息,比如文件系统类型等。 设备的名称通常是/dev/sda、/dev/sdb 等。 对于以前的设备有可能还存在设备名为 /dev/hd* (IDE)的设备,这个设…...

Hyperledger Fabric 管理链码 peer lifecycle chaincode 指令使用
链上代码(Chaincode)简称链码,包括系统链码和用户链码。系统链码(System Chaincode)指的是 Fabric Peer 中负责系统配置、查询、背书、验证等平台功能的代码逻辑,运行在 Peer 进程内,将在第 14 …...
L1-011 A-B(Java)
题目 本题要求你计算A−B。不过麻烦的是,A和B都是字符串 —— 即从字符串A中把字符串B所包含的字符全删掉,剩下的字符组成的就是字符串A−B。 输入格式: 输入在2行中先后给出字符串A和B。两字符串的长度都不超过10的四次方,并且…...

系列七、Ribbon
一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如:…...

山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展
山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…...
BGP公认必遵属性——Next-hop(一)
BGP公认必遵属性共有三个,分别是:Next-hop、Origin、As-path,本期介绍Next-hop 点赞关注,持续更新!!! Next-hop 华为BGP路由下一跳特点: 默认情况下传给EBGP邻居的BGP路由的下一跳…...

增强Wi-Fi信号的10种方法,值得去尝试
Wi-Fi信号丢失,无线盲区。在一个对一些人来说,上网和呼吸一样必要的世界里,这些问题中的每一个都令人抓狂。 如果你觉得你的Wi-Fi变得迟钝,有很多工具可以用来测试你的互联网速度。你还可以尝试一些技巧来解决网络问题。然而,如果你能获得良好接收的唯一方法是站在无线路…...
第十五章 ECMAScript6新增的常用语法
文章目录 一、声明关键字二、箭头函数三、解构赋值四、展开运算符五、对字符的补充六、Symbol七、对象的简写语法八、Set和Map九、for-of 一、声明关键字 ES6新增的声明关键字: let,const:声明变量class:声明类import,…...

vulhub中的Apache SSI 远程命令执行漏洞
Apache SSI 远程命令执行漏洞 1.cd到ssi-rce cd /opt/vulhub/httpd/ssi-rce/ 2.执行docker-compose up -d docker-compose up -d 3.查看靶场是否开启成功 dooker ps 拉取成功了 4.访问url 这里已经执行成功了,注意这里需要加入/upload.php 5.写入一句话木马 &…...

MSB20M-ASEMI迷你贴片整流桥MSB20M
编辑:ll MSB20M-ASEMI迷你贴片整流桥MSB20M 型号:MSB20M 品牌:ASEMI 封装:UMSB-4 特性:贴片、整流桥 最大平均正向电流:2A 最大重复峰值反向电压:1000V 恢复时间:࿱…...

工程管理系统功能设计与实践:实现高效、透明的工程管理
在现代化的工程项目管理中,一套功能全面、操作便捷的系统至关重要。本文将介绍一个基于Spring Cloud和Spring Boot技术的Java版工程项目管理系统,结合Vue和ElementUI实现前后端分离。该系统涵盖了项目管理、合同管理、预警管理、竣工管理、质量管理等多个…...

CMake基础:构建流程详解
目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...
蓝桥杯 2024 15届国赛 A组 儿童节快乐
P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...

基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

Docker 本地安装 mysql 数据库
Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker ;并安装。 基础操作不再赘述。 打开 macOS 终端,开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...
C#中的CLR属性、依赖属性与附加属性
CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...

计算机基础知识解析:从应用到架构的全面拆解
目录 前言 1、 计算机的应用领域:无处不在的数字助手 2、 计算机的进化史:从算盘到量子计算 3、计算机的分类:不止 “台式机和笔记本” 4、计算机的组件:硬件与软件的协同 4.1 硬件:五大核心部件 4.2 软件&#…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现企业微信功能
1. 开发环境准备 安装DevEco Studio 3.1: 从华为开发者官网下载最新版DevEco Studio安装HarmonyOS 5.0 SDK 项目配置: // module.json5 {"module": {"requestPermissions": [{"name": "ohos.permis…...
关于uniapp展示PDF的解决方案
在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项: 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库: npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...
Python实现简单音频数据压缩与解压算法
Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中,压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言,提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...
DiscuzX3.5发帖json api
参考文章:PHP实现独立Discuz站外发帖(直连操作数据库)_discuz 发帖api-CSDN博客 简单改造了一下,适配我自己的需求 有一个站点存在多个采集站,我想通过主站拿标题,采集站拿内容 使用到的sql如下 CREATE TABLE pre_forum_post_…...