当前位置: 首页 > news >正文

pytorch进阶学习(四):使用不同分类模型进行数据训练(alexnet、resnet、vgg等)

课程资源:5、帮各位写好了十多个分类模型,直接运行即可【小学生都会的Pytorch】_哔哩哔哩_bilibili

 

目录

一、项目介绍

 1. 数据集准备

2. 运行CreateDataset.py

3. 运行TrainModal.py 

4. 如何切换显卡型号

二、代码

1. CreateDataset.py

2.TrainModal.py 

3. 运行结果


一、项目介绍

 1. 数据集准备

数据集在data文件夹下

 

2. 运行CreateDataset.py

运行CreateDataset.py来生成train.txt和test.txt的数据集文件。

 

3. 运行TrainModal.py 

进行模型的训练,从torchvision中的models模块import了alexnet, vgg, resnet的多个网络模型,使用时直接取消注释掉响应的代码即可,比如我现在训练的是vgg11的网络。

    # 不使用预训练参数# model = alexnet(pretrained=False, num_classes=5).to(device) # 29.3%'''        VGG系列    '''model = vgg11(weights=False, num_classes=5).to(device)   #  23.1%# model = vgg13(weights=False, num_classes=5).to(device)   # 30.0%# model = vgg16(weights=False, num_classes=5).to(device)'''        ResNet系列    '''# model = resnet18(weights=False, num_classes=5).to(device)    # 43.6%# model = resnet34(weights=False, num_classes=5).to(device)# model = resnet50(weights= False, num_classes=5).to(device)#model = resnet101(weights=False, num_classes=5).to(device)   #  26.2%# model = resnet152(weights=False, num_classes=5).to(device)

 同时需要注意的是, vgg11中的weights参数设置为false,我们进入到vgg的定义页发现weights即为是否使用预训练参数,设置为FALSE说明我们不使用预训练参数,因为vgg网络的预训练类别数为1000,而我们自己的数据集没有那么多类,因此不使用预训练。

 

把最后一行中产生的pth的文件名称改成对应网络的名称,如model_vgg11.pth。 

    # 保存训练好的模型torch.save(model.state_dict(), "model_vgg11.pth")print("Saved PyTorch Model Success!")

4. 如何切换显卡型号

我在运行过程中遇到了“torch.cuda.OutOfMemoryError”的问题,显卡的显存不够,在服务器中使用查看显卡占用情况命令:

nvidia -smi

可以看到我一直用的是默认显卡0,使用情况已经到了100%,但是显卡1使用了67%,还用显存使用空间,所以使用以下代码来把显卡0换成显卡1.

# 设置显卡型号为1
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

 

二、代码

1. CreateDataset.py

'''
生成训练集和测试集,保存在txt文件中
'''
##相当于模型的输入。后面做数据加载器dataload的时候从里面读他的数据
import os
import random#打乱数据用的# 百分之60用来当训练集
train_ratio = 0.6# 用来当测试集
test_ratio = 1-train_ratiorootdata = r"data"  #数据的根目录train_list, test_list = [],[]#读取里面每一类的类别
data_list = []#生产train.txt和test.txt
class_flag = -1
for a,b,c in os.walk(rootdata):print(a)for i in range(len(c)):data_list.append(os.path.join(a,c[i]))for i in range(0,int(len(c)*train_ratio)):train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n'train_list.append(train_data)for i in range(int(len(c) * train_ratio),len(c)):test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'test_list.append(test_data)class_flag += 1print(train_list)
random.shuffle(train_list)#打乱次序
random.shuffle(test_list)with open('train.txt','w',encoding='UTF-8') as f:for train_img in train_list:f.write(str(train_img))with open('test.txt','w',encoding='UTF-8') as f:for test_img in test_list:f.write(test_img)

2.TrainModal.py 

'''加载pytorch自带的模型,从头训练自己的数据
'''
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils import LoadDataimport os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'from torchvision.models import alexnet  #最简单的模型
from torchvision.models import vgg11, vgg13, vgg16, vgg19   # VGG系列
from torchvision.models import resnet18, resnet34,resnet50, resnet101, resnet152    # ResNet系列
from torchvision.models import inception_v3     # Inception 系列# 定义训练函数,需要
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)# 从数据加载器中读取batch(一次读取多少张,即批次数),X(图片数据),y(图片真实标签)。for batch, (X, y) in enumerate(dataloader):# 将数据存到显卡X, y = X.cuda(), y.cuda()# 得到预测的结果predpred = model(X)# 计算预测的误差# print(pred,y)loss = loss_fn(pred, y)# 反向传播,更新模型参数optimizer.zero_grad()loss.backward()optimizer.step()# 每训练10次,输出一次当前信息if batch % 10 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test(dataloader, model):size = len(dataloader.dataset)# 将模型转为验证模式model.eval()# 初始化test_loss 和 correct, 用来统计每次的误差test_loss, correct = 0, 0# 测试时模型参数不用更新,所以no_gard()# 非训练, 推理期用到with torch.no_grad():# 加载数据加载器,得到里面的X(图片数据)和y(真实标签)for X, y in dataloader:# 将数据转到GPUX, y = X.cuda(), y.cuda()# 将图片传入到模型当中就,得到预测的值predpred = model(X)# 计算预测值pred和真实值y的差距test_loss += loss_fn(pred, y).item()# 统计预测正确的个数correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"correct = {correct}, Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")if __name__=='__main__':batch_size = 8# # 给训练集和测试集分别创建一个数据集加载器train_data = LoadData("train.txt", True)valid_data = LoadData("test.txt", False)train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)# 如果显卡可用,则用显卡进行训练device = "cuda" if torch.cuda.is_available() else "cpu"print(f"Using {device} device")'''随着模型的加深,需要训练的模型参数量增加,相同的训练次数下模型训练准确率起来得更慢'''# 不使用预训练参数# model = alexnet(pretrained=False, num_classes=5).to(device) # 29.3%'''        VGG系列    '''model = vgg11(weights=False, num_classes=5).to(device)   #  23.1%# model = vgg13(weights=False, num_classes=5).to(device)   # 30.0%# model = vgg16(weights=False, num_classes=5).to(device)'''        ResNet系列    '''# model = resnet18(weights=False, num_classes=5).to(device)    # 43.6%# model = resnet34(weights=False, num_classes=5).to(device)# model = resnet50(weights= False, num_classes=5).to(device)#model = resnet101(weights=False, num_classes=5).to(device)   #  26.2%# model = resnet152(weights=False, num_classes=5).to(device)print(model)# 定义损失函数,计算相差多少,交叉熵,loss_fn = nn.CrossEntropyLoss()# 定义优化器,用来训练时候优化模型参数,随机梯度下降法optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # 初始学习率# 一共训练1次epochs = 1for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")time_start = time.time()train(train_dataloader, model, loss_fn, optimizer)time_end = time.time()print(f"train time: {(time_end-time_start)}")test(test_dataloader, model)print("Done!")# 保存训练好的模型torch.save(model.state_dict(), "model_vgg11.pth")print("Saved PyTorch Model Success!")

3. 运行结果

vgg11的运行结果:,可以看到最后的准确率是24.8%,因为没有用预训练模型,所以准确率很低。

 

相关文章:

pytorch进阶学习(四):使用不同分类模型进行数据训练(alexnet、resnet、vgg等)

课程资源:5、帮各位写好了十多个分类模型,直接运行即可【小学生都会的Pytorch】_哔哩哔哩_bilibili 目录 一、项目介绍 1. 数据集准备 2. 运行CreateDataset.py 3. 运行TrainModal.py 4. 如何切换显卡型号 二、代码 1. CreateDataset.py 2.Train…...

Java面向对象高级【注解和反射】

目录 注解 什么是注解? 自定义注解 元注解 反射 什么是反射 静态语言和动态语言 动态语言 静态语言 对比 Class类 Java内存分析 类加载过程 类加载器 获取运行时类的完整结构 通过Class对象实例化对象 1.调用Class对象的newInstance 2.Constructor…...

Pytorch基础 - 4. torch.expand() 和 torch.repeat()

目录 1. torch.expand(*sizes) 2. torch.repeat(*sizes) 3. 两者内存占用的区别 在PyTorch中有两个函数可以用来扩展某一维度的张量,即 torch.expand() 和 torch.repeat() 1. torch.expand(*sizes) 【含义】将输入张量在大小为1的维度上进行拓展,…...

《LeetCode》——LeetCode刷题日记

本期,将给大家带来的是关于 LeetCode 的关于二叉树的题目讲解。 目录 (一)606. 根据二叉树创建字符串 💥题意分析 💥解题思路 (二)102. 二叉树的层序遍历 💥题意分析 &#…...

mysql数据库审计(1)

1.数据库审计工具介绍及选择 1.1. 数据库审计工具介绍 MySQL 分支的审计功能包含在企业版中,社区版可以使用其他分支提供的工具。目前已知的审计工具,社区版本有 Percona 的 Percona Server Audit Log 、MariaDB 的 MariaDB Audit Plugin 和 McAfee 的…...

Kafka---kafka概述和kafka基础架构

kafka概述和kafka基础架构 文章目录kafka概述和kafka基础架构Kafka定义消息队列传统消息队列应用场景缓存/消峰解耦异步通信消息队列的两种模式点对点模式发布/订阅模式kafka基础架构producerConsumerConsumer Group(CG)BrokerTopicPartitionReplicaLead…...

《JavaEE初阶》多线程基础

《JavaEE初阶》多线程基础 文章目录《JavaEE初阶》多线程基础前言:多线程的概念简单创建线程并运行:简述Thread中run方法与start方法的区别创建线程的几种方法:探讨串行执行与并行执行的执行时间多线程的使用场景:Thread类简单介绍:构造方法:获取线程的常见属性:线程的常用方法…...

技术分享 | OMS 初识

作者:高鹏 DBA,负责项目日常问题排查,广告位长期出租 。 本文来源:原创投稿 *爱可生开源社区出品,原创内容未经授权不得随意使用,转载请联系小编并注明来源。 本文主要贡献者:进行OMS源码分析的…...

【Elastic (ELK) Stack 实战教程】10、ELK 架构升级-引入消息队列 Redis、Kafka

目录 一、ELK 架构面临的问题 1.1 耦合度过高 1.2 性能瓶颈 二、ELK 对接 Redis 实践 2.1 配置 Redis 2.1.1 安装 Redis 2.1.2 配置 Redis 2.1.3 启动 Redis 2.2 配置 Filebeat 2.3 配置 Logstash 2.4 数据消费 2.5 配置 kibana 三、消息队列基本概述 3.1 什么是…...

优先、双端队列-我的基础算法刷题之路(八)

本篇博客旨在整理记录自已对优先队列、双端队列的一些总结,以及刷题的解题思路,同时希望可给小伙伴一些帮助。本人也是算法小白,水平有限,如果文章中有什么错误之处,希望小伙伴们可以在评论区指出来,共勉 &…...

Python3 os.symlink() 方法、Python 质数判断

Python3 os.symlink() 方法 概述 os.symlink() 方法用于创建一个软链接。 语法 symlink()方法语法格式如下: os.symlink(src, dst)参数 src -- 源地址。 dst -- 目标地址。 返回值 该方法没有返回值。 实例 以下实例演示了 symlink() 方法的使用&#xff1…...

P1972 [SDOI2009] HH的项链

[SDOI2009] HH的项链 题目描述 HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来…...

​力扣解法汇总1026. 节点与其祖先之间的最大差值

目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣 描述: 给定二叉树的根节点 root,找出存在于 不同 节点 A 和 B 之间的最大值…...

010:Mapbox GL移动鼠标mousemove,显示坐标信息

第010个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+mapbox中移动鼠标mousemove,显示坐标信息。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(共81行)相关API参考:专栏目标示例效果 配置方式 1)查看基础…...

【两阶段鲁棒优化】利用列-约束生成方法求解两阶段鲁棒优化问题(Python代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

百度暑期实习 C++ 一面

1.数组 链表 数组是一种线性数据结构,其中相同类型的元素连续存储在一段内存中,并且可以通过索引来访问每个元素。数组的优点是随机访问元素非常快速,但缺点是插入或删除元素可能需要移动其他元素。 链表也是一种线性数据结构,但…...

计算机网络第一章(概述)【湖科大教书匠】

1. 各种网络 网络(Network)由若干**结点(Node)和连接这些结点的链路(Link)**组成多个网络还可以通过路由器互连起来,这样就构成了一个覆盖范围更大的网络,即互联网(互连网)。因此,互联网是"网络的网络(Network of Networks)"**因特…...

【JS】vis.js使用之vis-timeline使用攻略,vis-timeline在vue3中实现时间轴、甘特图

vis.js使用之vis-timeline使用攻略,vis-timeline实现时间轴、甘特图1、vis-timeline简介2、安装插件及依赖3、简单示例4、疑难问题集合1. 中文zh-cn本地化2. 关于自定义class样式无法被渲染3. 关于双向数据绑定vis.js是一个基于浏览器的可视化库,它提供了…...

机器学习——数据处理

机器学习简介 机器学习是人工智能的一个实现途径深度学习是机器学习的一个方法发展而来 机器学习:从数据中自动分析获得模型,并利用模型对未知数据进行预测。 数据集的格式: 特征值目标值 比如上图中房子的各种属性是特征值,然…...

多种文字翻译软件-翻译常用软件

整篇文档翻译软件 整篇文档翻译软件是一种实现全文翻译的自动翻译工具,它能够快速、准确地将整篇文档的内容翻译成目标语言。与单词、句子翻译不同,整篇文档翻译软件不仅需要具备准确的语言识别和翻译技术,还需要考虑上下文语境和文档格式等多…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...

[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?

🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里&#xf…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...

【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)

要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况,可以通过以下几种方式模拟或触发: 1. 增加CPU负载 运行大量计算密集型任务,例如: 使用多线程循环执行复杂计算(如数学运算、加密解密等)。运行图…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容,使用AI(2025)可以参考以下方法: 四个洞见 模型已经比人聪明:以ChatGPT o3为代表的AI非常强大,能运用高级理论解释道理、引用最新学术论文,生成对顶尖科学家都有用的…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目,该项目是一个 Spring AI 快速入门的样例工程项目,旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计,每个模块都专注于特定的功能领域,便于学习和…...

保姆级【快数学会Android端“动画“】+ 实现补间动画和逐帧动画!!!

目录 补间动画 1.创建资源文件夹 2.设置文件夹类型 3.创建.xml文件 4.样式设计 5.动画设置 6.动画的实现 内容拓展 7.在原基础上继续添加.xml文件 8.xml代码编写 (1)rotate_anim (2)scale_anim (3)translate_anim 9.MainActivity.java代码汇总 10.效果展示 逐帧…...

相关类相关的可视化图像总结

目录 一、散点图 二、气泡图 三、相关图 四、热力图 五、二维密度图 六、多模态二维密度图 七、雷达图 八、桑基图 九、总结 一、散点图 特点 通过点的位置展示两个连续变量之间的关系,可直观判断线性相关、非线性相关或无相关关系,点的分布密…...