实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客
这里是根据网络结构,搭建模型,用于图像分类任务。
1. 网络结构和基本组件
2. 搭建组件
(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;
(2)深度可分离卷积:DwCBL = Conv dw+ Conv dp;
Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 } + {Conv2d(1x1) + BN + ReLU6};
Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;
Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;
# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x
3. 搭建网络
class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2), # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2), # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2), # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2), # 下采样/16DwCBN(512, 512, 1), # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2), # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax() # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x) # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4) # /32. 7x7x = self.avg_pooling(scale5) # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1) # (b,1024,1,1)->(b,1024,)x = self.fc(x) # (b,1024,) -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)
4. 训练验证
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optimfrom mobilenetv1 import MobileNetV1def validate(model, val_loader, criterion, device):model.eval() # Set the model to evaluation modetotal_correct = 0total_samples = 0with torch.no_grad():for val_inputs, val_labels in val_loader:val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)val_outputs = model(val_inputs)_, predicted = torch.max(val_outputs, 1)total_samples += val_labels.size(0)total_correct += (predicted == val_labels).sum().item()accuracy = total_correct / total_samplesmodel.train() # Set the model back to training modereturn accuracyif __name__ == '__main__':# 下载并准备数据集# Define image transformations (adjust as needed)transform = transforms.Compose([transforms.Resize((224, 224)), # Resize images to a consistent sizetransforms.ToTensor(), # converts to PIL Image to a Pytorch Tensor and scales values to the range [0, 1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # Adjust normalization values. val = (val - mean) / std.])# Create ImageFolder datasetdata_folder = r"D:\zxq\data\car_or_dog"dataset = torchvision.datasets.ImageFolder(root=data_folder, transform=transform)# Optionally, split the dataset into training and validation sets# Adjust the `split_ratio` as neededsplit_ratio = 0.8train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# Create DataLoader for training and validationtrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)# 初始化模型、损失函数和优化器net = MobileNetV1(class_num=2)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 训练模型device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(device)net.to(device)for epoch in range(20): # 例如,训练 20 个周期for i, data in enumerate(train_loader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device) # 将数据移动到GPUoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print("epoch/step: {}/{}: loss: {}".format(epoch, i, loss.item()))# Validation after each epochval_accuracy = validate(net, val_loader, criterion, device)print("Epoch {} - Validation Accuracy: {:.2%}".format(epoch, val_accuracy))print('Finished Training')
待续。。。
相关文章:

实现pytorch版的mobileNetV1
mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客 这里是根据网络结构,搭建模型,用于图像分类任务。 1. 网络结构和基本组件 2. 搭建组件 (1)普通的卷积组件:CBL …...
vue多tab页面全部关闭后自动退出登录
业务场景:主项目是用vue写的单页面应用,但是有多开页面的需求,现在需要在用户关闭了所有的浏览器标签页面后,自动退出登录。 思路:因为是不同的tab页面,我只能用localStorage来通信,新打开一个…...
记一个集群环境部署不完整导致的BUG
一 背景 产品有三个环境:开发测试环境、验收环境、生产环境。 开发测试环境,保持最新的更新; 验收环境,阶段待发布内容; 生产环境,部署稳定内容。 产品为BS架构,后端采用微服务…...
Go zero copy,复制文件
这里使用零拷贝技术复制文件,从内核态操作源文件和目标文件。避免了在用户态开辟缓冲区,然后从内核态复制文件到用户态的问题。 由内核态完成文件复制操作。 调用的是syscall.Sendfile系统调用函数。 //go:build linuxpackage zero_copyimport ("f…...
http协议九种请求方法介绍及常见状态码
http1.0定义了三种: GET: 向服务器获取资源,比如常见的查询请求POST: 向服务器提交数据而发送的请求Head: 和get类似,返回的响应中没有具体的内容,用于获取报头 http1.1定义了六种 PUT:一般是用于更新请求,…...

详解flink exactly-once和两阶段提交
以下是我们常见的三种 flink 处理语义: 最多一次(At-most-Once):用户的数据只会被处理一次,不管成功还是失败,不会重试也不会重发。 至少一次(At-least-Once):系统会保…...
Qt/QML编程学习之心得:QDbus实现service接口调用(28)
D-Bus协议用于进程间通讯的。 QString value = retrieveValue();QDBusPendingCall pcall = interface->asyncCall(QLatin1String("Process"), value);QDBusPendingCallWatcher *watcher = new QDBusPendingCallWatcher(pcall, this);QObject::connect(watcher, SI…...

前端nginx配置指南
前端项目发布后,有些接口需要在服务器配置反向代理,资源配置gzip压缩,配置跨域允许访问等 配置文件模块概览 配置示例 反向代理 反向代理是Nginx的核心功能之一,是指客户端发送请求到代理服务器,代理服务器再将请求…...

接口测试到底怎么做,5分钟时间看完这篇文章彻底搞清楚
01、通用的项目架构 02、什么是接口 接口:服务端程序对外提供的一种统一的访问方式,通常采用HTTP协议,通过不同的url,不同的请求类型(GET、POST),不同的参数,来执行不同的业务逻辑。…...

显示管理磁盘分区 fdisk
显示管理磁盘分区 fdisk fdisk是用于检查一个磁盘上分区信息最通用的命令。 fdisk可以显示分区信息及一些细节信息,比如文件系统类型等。 设备的名称通常是/dev/sda、/dev/sdb 等。 对于以前的设备有可能还存在设备名为 /dev/hd* (IDE)的设备,这个设…...

Hyperledger Fabric 管理链码 peer lifecycle chaincode 指令使用
链上代码(Chaincode)简称链码,包括系统链码和用户链码。系统链码(System Chaincode)指的是 Fabric Peer 中负责系统配置、查询、背书、验证等平台功能的代码逻辑,运行在 Peer 进程内,将在第 14 …...
L1-011 A-B(Java)
题目 本题要求你计算A−B。不过麻烦的是,A和B都是字符串 —— 即从字符串A中把字符串B所包含的字符全删掉,剩下的字符组成的就是字符串A−B。 输入格式: 输入在2行中先后给出字符串A和B。两字符串的长度都不超过10的四次方,并且…...

系列七、Ribbon
一、Ribbon 1.1、概述 Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具,是Netflix发布的一款开源项目,其主要功能是提供客户端的软件负载均衡算法和服务调用,Ribbon客户端组件提供一系列完善的配置项,例如:…...

山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展
山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…...
BGP公认必遵属性——Next-hop(一)
BGP公认必遵属性共有三个,分别是:Next-hop、Origin、As-path,本期介绍Next-hop 点赞关注,持续更新!!! Next-hop 华为BGP路由下一跳特点: 默认情况下传给EBGP邻居的BGP路由的下一跳…...

增强Wi-Fi信号的10种方法,值得去尝试
Wi-Fi信号丢失,无线盲区。在一个对一些人来说,上网和呼吸一样必要的世界里,这些问题中的每一个都令人抓狂。 如果你觉得你的Wi-Fi变得迟钝,有很多工具可以用来测试你的互联网速度。你还可以尝试一些技巧来解决网络问题。然而,如果你能获得良好接收的唯一方法是站在无线路…...
第十五章 ECMAScript6新增的常用语法
文章目录 一、声明关键字二、箭头函数三、解构赋值四、展开运算符五、对字符的补充六、Symbol七、对象的简写语法八、Set和Map九、for-of 一、声明关键字 ES6新增的声明关键字: let,const:声明变量class:声明类import,…...

vulhub中的Apache SSI 远程命令执行漏洞
Apache SSI 远程命令执行漏洞 1.cd到ssi-rce cd /opt/vulhub/httpd/ssi-rce/ 2.执行docker-compose up -d docker-compose up -d 3.查看靶场是否开启成功 dooker ps 拉取成功了 4.访问url 这里已经执行成功了,注意这里需要加入/upload.php 5.写入一句话木马 &…...

MSB20M-ASEMI迷你贴片整流桥MSB20M
编辑:ll MSB20M-ASEMI迷你贴片整流桥MSB20M 型号:MSB20M 品牌:ASEMI 封装:UMSB-4 特性:贴片、整流桥 最大平均正向电流:2A 最大重复峰值反向电压:1000V 恢复时间:࿱…...

工程管理系统功能设计与实践:实现高效、透明的工程管理
在现代化的工程项目管理中,一套功能全面、操作便捷的系统至关重要。本文将介绍一个基于Spring Cloud和Spring Boot技术的Java版工程项目管理系统,结合Vue和ElementUI实现前后端分离。该系统涵盖了项目管理、合同管理、预警管理、竣工管理、质量管理等多个…...
《Playwright:微软的自动化测试工具详解》
Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

cf2117E
原题链接:https://codeforces.com/contest/2117/problem/E 题目背景: 给定两个数组a,b,可以执行多次以下操作:选择 i (1 < i < n - 1),并设置 或,也可以在执行上述操作前执行一次删除任意 和 。求…...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
【HTTP三个基础问题】
面试官您好!HTTP是超文本传输协议,是互联网上客户端和服务器之间传输超文本数据(比如文字、图片、音频、视频等)的核心协议,当前互联网应用最广泛的版本是HTTP1.1,它基于经典的C/S模型,也就是客…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容
目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法,当前调用一个医疗行业的AI识别算法后返回…...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...
在Ubuntu24上采用Wine打开SourceInsight
1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...

【Linux】Linux 系统默认的目录及作用说明
博主介绍:✌全网粉丝23W,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...
第7篇:中间件全链路监控与 SQL 性能分析实践
7.1 章节导读 在构建数据库中间件的过程中,可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中,必须做到: 🔍 追踪每一条 SQL 的生命周期(从入口到数据库执行)&#…...