学习基于pytorch的VGG图像分类 day3
注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.
目录
VGG模型训练
1.导入必要的库
2.主函数部分
2.1使用cpu或gpu
2.2对数据进行预处理
2.3 训练集部分
2.4索引与标签
2.5创建数据加载器
2.6验证集部分
2.7模型的初始化
2.8训练部分
2.9评估验证部分
2.10主函数入口
小结
VGG模型训练
1.导入必要的库
导入所需的库,以及导入自定义的VGG模型模版
import os
import sys
import json import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm from model import vgg # 导入自定义的VGG模型模块
2.主函数部分
def main():
2.1使用cpu或gpu
# 检查是否有可用的CUDA设备,如果有则使用GPU,否则使用CPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device))
2.2对数据进行预处理
# 定义数据预处理操作,包括随机裁剪、随机水平翻转、转为Tensor格式、标准化 data_transform = { "train": transforms.Compose([ transforms.RandomResizedCrop(224), # 随机裁剪为224x224大小 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.ToTensor(), # 将PIL Image或ndarray转换为torch.FloatTensor,并归一化到[0.0, 1.0] transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化处理,减均值除标准差 ]), "val": transforms.Compose([ transforms.Resize((224, 224)), # 调整图片大小到224x224 transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 标准化处理 }
2.3 训练集部分
读取脚本的路径,构建一个图像数据的跟路径(用条件判断断言路径是否存在,不存在进行报错)加载训练数据集
# 获取当前脚本的绝对路径 current_file_dir = os.path.dirname(os.path.abspath(__file__)) # 构建图像数据的根路径 image_root = os.path.join(current_file_dir, "image_path") train_dir = os.path.join(image_root, "train") # 训练集目录 # 确保训练集目录存在 assert os.path.exists(train_dir), "{} path does not exist.".format(train_dir) # 使用torchvision.datasets的ImageFolder类加载训练数据集,它假设每个子文件夹的名称是其对应的类别 train_dataset = datasets.ImageFolder(root=train_dir, transform=data_transform["train"]) # 定义保存类别标签与索引对应关系的json文件路径 image_path = os.path.join(image_root) # 计算训练集样本数量 train_num = len(train_dataset)
2.4索引与标签
# 获取类别标签与索引的对应关系 flower_list = train_dataset.class_to_idx # 反转字典,将索引映射到类别标签 cla_dict = {val: key for key, val in flower_list.items()} # 将类别索引到标签的映射关系写入json文件 json_str = json.dumps(cla_dict, indent=4) # 使用json库将字典转化为格式化字符串 with open('class_indices.json', 'w') as json_file:
2.5创建数据加载器
定义每个数据加载器使用的工作进程数量,若自身内存不够,可以小一点!!!
# 定义每个数据加载器使用的工作进程数量 batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # 取CPU核心数, batch_size(如果大于1)和8中的最小值 print('Using {} dataloader workers every process'.format(nw)) # 打印每个进程使用的数据加载器工作进程数 # 创建训练集数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
2.6验证集部分
道理同训练集。
# 验证集目录 val_dir = os.path.join(image_path, "val") # 验证集文件夹路径 # 检查验证集目录是否存在 assert os.path.exists(val_dir), "{} path does not exist.".format(val_dir) # 加载验证数据集 validate_dataset = datasets.ImageFolder(root=val_dir, transform=data_transform["val"]) # 获取验证集样本数量 val_num = len(validate_dataset) # 创建验证集数据加载器 validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=nw) # 打印训练集和验证集的样本数量 print("using {} images for training, {} images for validation.".format(train_num, val_num))
2.7模型的初始化
对模型各个参数进行初始化,确定分类个数,训练轮数,最佳准确率,学习率
# 初始化模型 model_name = "vgg16" net = vgg(model_name=model_name, num_classes = 4, init_weights=True) # 创建VGG16模型,类别数为4,并初始化权重 net.to(device) # 将模型转移到指定的设备上(CPU或GPU) # 定义损失函数和优化器 loss_function = nn.CrossEntropyLoss() # 交叉熵损失函数,用于分类问题 optimizer = optim.Adam(net.parameters(), lr=0.0001) # 使用Adam优化器,学习率为0.0001 # 设置训练轮数 epochs = 60# 初始化最佳准确率 best_acc = 0.0 # 设置模型保存路径 save_path = './{}Net.pth'.format(model_name) # 计算训练步骤数 train_steps = len(train_loader)
2.8训练部分
训练及展示进度
# 开始训练循环 for epoch in range(epochs): # 将模型设置为训练模式 net.train() # 初始化运行损失 running_loss = 0.0 # 使用tqdm库创建进度条,用于显示训练进度 train_bar = tqdm(train_loader, file=sys.stdout) # 开始每个epoch的训练步骤循环 for step, data in enumerate(train_bar): # 从数据加载器中获取图像和标签 images, labels = data # 梯度清零 optimizer.zero_grad() # 前向传播,计算输出 outputs = net(images.to(device)) # 计算损失 loss = loss_function(outputs, labels.to(device)) # 反向传播,计算梯度 loss.backward() # 更新模型参数 optimizer.step() # 更新运行损失(这部分代码在原始代码中被省略了,通常需要用于记录或展示) # 将当前损失值添加到运行损失中 running_loss += loss.item() # 打印训练过程中的统计信息 # 格式化字符串,显示当前epoch、总epoch数和当前损失值 train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
2.9评估验证部分
# 验证模型性能 net.eval() # 将模型设置为评估模式 acc = 0.0 # 初始化累积的正确预测数量 with torch.no_grad(): # 不计算梯度,节省计算资源 val_bar = tqdm(validate_loader, file=sys.stdout) # 创建验证集的进度条 for val_data in val_bar: # 遍历验证集数据 val_images, val_labels = val_data # 获取图像和标签 outputs = net(val_images.to(device)) # 前向传播,获取模型输出 predict_y = torch.max(outputs, dim=1)[1] # 获取预测类别 # 计算预测正确的数量,并累加到acc中 acc += torch.eq(predict_y, val_labels.to(device)).sum().item() # 计算验证集的准确率 val_accurate = acc / val_num # 打印当前epoch的训练损失和验证准确率 print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate)) # 如果当前验证准确率高于最佳准确率,则更新最佳准确率并保存模型状态 if val_accurate > best_acc: best_acc = val_accurate torch.save(net.state_dict(), save_path) # 保存模型权重到指定路径 print('Finished Training') # 训练完成,打印提示信息
2.10主函数入口
# 主函数入口
if __name__ == '__main__': main() # 调用main函数,开始训练过程
小结
1.对内存不够的情况要降低batch_size的值,否则模型无法训练
2.在构建路径后,一定要用条件判断断言路径是否存在(这是一个好习惯)
3.在模型训练时最好实时更新数据(可以更加直观的体现)
相关文章:
学习基于pytorch的VGG图像分类 day3
注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主. 目录 VGG模型训练 1.导入必要的库 2.主函数部分 2.1使用cpu或gpu 2.2对数据…...
Spring Boot统一功能处理之拦截器
本篇主要介绍Spring Boot的统一功能处理中的拦截器。 目录 一、拦截器的基本使用 二、拦截器实操 三、浅尝源码 初始化DispatcherServerlet 处理请求(doDispatch) 四、适配器模式 一、拦截器的基本使用 在一般的学校或者社区门口,通常会安排几个…...
stm32之基本定时器的使用
在上文我们使用到了HAL库的自带的延时函数,HAL_Delay();我们来看一下函数的原型 __weak void HAL_Delay(uint32_t Delay) {uint32_t tickstart HAL_GetTick();uint32_t wait Delay;/* Add a freq to guarantee minimum wait */…...
单片机为什么还在用C语言编程?
单片机产品的成本是非常敏感的。因此对于单片机开发来说,最重要的是在极其有限的ROM和RAM中实现最多产品的功能。或者反过来说,实现相同的产品功能,所需要的ROM和RAM越小越好,在开始前我有一些资料,是我根据网友给的问…...
IO流的基础详解
文件【1】File类: 封装文件/目录的各种信息,对目录/文件进行操作,但是我们不可以获取到文件/目录中的内容。 【2】引入:IO流: I/O : Input/Output的缩写,用于处理设备之间的数据的传输。 【3】…...
实战攻防 | 记一次项目上的任意文件下载
1、开局 开局一个弱口令,正常来讲我们一般是弱口令或者sql,或者未授权 那么这次运气比较好,直接弱口令进去了 直接访问看看有没有功能点,正常做测试我们一定要先找功能点 发现一个文件上传点,不过老规矩,还…...
熔断之神:探寻Hystrix的秘密与实践指南
引言: 在微服务架构中,服务之间的依赖复杂且难以控制,容灾机制成为确保系统稳定性的重要手段。Hystrix作为Netflix开源的断路器实现,提供了一系列强健的容错功能。 Hystrix的核心概念与作用: Hystrix是一个由Netflix开…...
Web功能测试测试点总结!
web测试就是基于BS架构的软件产品的测试,通俗点来说就是web网站的测试。 一 、界面检查 当我们进入一个页面时,首先应该检查title,页面排版(即页面的展示),而不是马上进入字段校验页面面包屑导航是否正确当前位置是否可见 您的位…...
关于vue3的简单学习
Vue 3 简介 Vue 3 是一个流行的开源Java框架,用于构建用户界面和单页面应用。它带来了许多新特性和改进,包括更好的性能、更小的打包大小、更好的Type支持、全新的组合式 API,以及一些新的内置组件。 一. Vue 3 的新特性 Vue 3引入了许多新…...
windows server 2019 -DNS服务器搭建
前面是有关DNS的相关理论知识,懂了的可以直接跳到第五点。 说明一下:作为服务器ip最好固定下来,以DNS服务器为例子,如果客户机的填写DNS信息的之后,服务器的ip如果变动了的话,客户机都得跟着改,…...
使用 XCTest 进行 iOS UI 自动化测试
使用 XCTest 进行 iOS UI 自动化测试是一种有效的方法,可以帮助你验证应用界面的行为和功能。以下是使用 XCTest 进行 iOS UI 自动化测试的基本步骤: 设置项目: 确保你的 Xcode 项目已经包含了 XCTest 测试目标。在测试目标中创建一个新的测试类…...
【Python】FANUC机器人OPC UA通信并记录数据
目录 引言机器人仿真环境准备代码实现1. 导入库2. 设置参数3. 日志配置4. OPC UA通信5. 备份旧CSV文件6. 主函数 总结 引言 OPC UA(Open Platform Communications Unified Architecture)是一种跨平台的、开放的数据交换标准,常用于工业自动化…...
Linux 中断处理
一、基本概念 1、中断及中断上下文 中断是一种由硬件设备产生的信号,不同设备产生的中断通过中断号来区分。CPU在接收到中断信号后,根据中断号执行对应的中断处理程序(Interrupt Service Routine) 内核对异常和中断的处理类似&a…...
人大金昌netcore适配,调用oracle模式下存储过程\包,返回参数游标
using KdbndpConnection conn new KdbndpConnection("Host192.168.133.221;Port54321;Databasedb1;Poolingtrue;User IDsystem;Password123");conn.Open();//存储过程调用也是类似using var cmd conn.CreateCommand();cmd.CommandText "模式.包名称.存储过程…...
pandas常用的一些操作
EXCLE操作 读取Excel data1 pd.read_excel(excle_dir) 读Excel取跳过前几行: data1 pd.read_excel(excle_dir,skiprows1) 获取总行数 data1.shape[0] 获取总列数 data1.shape[1] 指定某列数据类型 data1 pd.read_excel("C:数据导入.xlsx",dtype…...
【鸿蒙开发】系统组件Row
Row组件 Row沿水平方向布局容器 接口: Row(value?:{space?: number | string }) 参数: 参数名 参数类型 必填 参数描述 space string | number 否 横向布局元素间距。 从API version 9开始,space为负数或者justifyContent设置为…...
Hadoop和zookeeper集群相关执行脚本(未完,持续更新中~)
1、Hadoop集群查看状态 搭建Hadoop数据集群时,按以下路径操作即可生成脚本 [test_1analysis01 bin]$ pwd /home/test_1/hadoop/bin [test_01analysis01 bin]$ vim jpsall #!/bin/bash for host in analysis01 analysis02 analysis03 do echo $host s…...
蓝桥杯算法题:栈(Stack)
这道题考的是递推动态规划,可能不是很难,不过这是自己第一次靠自己想出状态转移方程,所以纪念一下: 要做这些题目,首先要把题目中会出现什么状态给找出来,然后想想他们的状态可以通过什么操作转移…...
JavaWeb-监听器
文章目录 1.基本介绍2.ServletContextListener1.基本介绍2.创建maven项目,导入依赖3.代码演示1.实现ServletContextListener接口2.配置web.xml3.结果 3.ServletContextAttributeListener监听器1.基本介绍2.代码实例1.ServletContextAttributeListener.java2.配置web…...
系统架构设计基础知识
一. 系统架构概述系统架构的定义 系统架构(System Architecture)是系统的一种整体的高层次的结构表示,是系统的骨架和根基,支撑和链接各个部分,包括构件、连接件、约束规范以及指导这些内容设计与演化的原理࿰…...
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器的上位机配置操作说明
LBE-LEX系列工业语音播放器|预警播报器|喇叭蜂鸣器专为工业环境精心打造,完美适配AGV和无人叉车。同时,集成以太网与语音合成技术,为各类高级系统(如MES、调度系统、库位管理、立库等)提供高效便捷的语音交互体验。 L…...
C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...
【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)
更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...
Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek
文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...
无人机侦测与反制技术的进展与应用
国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机(无人驾驶飞行器,UAV)技术的快速发展,其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统,无人机的“黑飞”&…...
如何更改默认 Crontab 编辑器 ?
在 Linux 领域中,crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用,用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益,允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...
