当前位置: 首页 > news >正文

完整模型的训练套路

从心所欲

不逾矩

天大地大

皆可去

一、官方模型的初使用

使用VGG16模型

 VGG模型使用代码示例:

import torchvision.models
from torch import nndataset = torchvision.datasets.CIFAR10('/cifar10', False, transform=torchvision.transforms.ToTensor())vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)# 改造VGG,增加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)# 改造vgg,修改一层
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

说明:

  1. pretrained=True:当设置为True时,模型将加载在大规模图像数据集(如ImageNet)上预训练的权重。这些预训练的权重经过了在大量图像上的训练,可以捕捉到通用的图像特征。通过加载预训练权重,可以将VGG模型初始化为在ImageNet上训练得到的状态,并且这些权重可以作为初始参数用于特定任务的微调或迁移学习。

  2. pretrained=False:当设置为False时,模型将使用随机初始化的权重。这意味着模型的权重没有经过预训练,需要从头开始进行训练。在这种情况下,模型将不会具备捕捉通用图像特征的能力,而是需要根据特定任务的数据进行训练。

pretrained=Truepretrained=False区别在于是否加载预训练的权重。如果你想要在特定任务上使用VGG模型,并且你的任务与图像分类或特征提取相关,那么通常建议使用pretrained=True,以便利用预训练权重的优势。如果你的任务与图像分类或特征提取无关,或者你希望从头开始训练模型以适应特定数据集,那么可以选择pretrained=False

二、模型的保存与加载

模型的保存:

两种保存模式,官方推荐第二种,只保存参数,不保存模型

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1: 既保存模型结构,也保存参数
torch.save(vgg16, 'vgg16_model1.pth')# 保存方式2:把参数保存成字典,不保存结构(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_model2.pth')print("end")

模型的加载:
 

import torch
import torchvision.models# 加载方式1 - 保存方式1
model = torch.load('vgg16_model1.pth')# 加载方式2 - 保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_model2.pth'))

三、完整的模型训练套路

以CIFAR10数据集来一个完整的模型训练。

代码示例:

model.py

from torch import nn# 搭建神经网络
class Lh(nn.Module):def __init__(self):super(Lh, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x

train.py

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import Lh# 准备数据集
train_data = torchvision.datasets.CIFAR10('./cifar10', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10('./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 搭建神经网络 - 10分类
lh = Lh()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(lh.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("train_logs")for i in range(epoch):print("-----第{}轮训练开始了-----".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)total_test_loss += lossaccuracy = (output.argmax(1) - - tragets).sum()total_accuracy += accuracyprint("整体测试机上误差:{}".format(total_test_loss))print("整体测试机上的正确率:{}".format(total_accuracy / test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / total_test_step)total_test_step += 1# torch.save(lh, "lhy_{}.pth".format(i))# print("模型已保存")writer.close()

输出结果:

 在tensorboard打开

相关文章:

完整模型的训练套路

从心所欲 不逾矩 天大地大 皆可去 一、官方模型的初使用 使用VGG16模型 VGG模型使用代码示例: import torchvision.models from torch import nndataset torchvision.datasets.CIFAR10(/cifar10, False, transformtorchvision.transforms.ToTensor())vgg16_true …...

PtahDAO:全球首个DAO治理资产信托计划的金融平台

金融科技是当今世界最具创新力和影响力的领域之一,区块链技术作为金融科技的核心驱动力,正在颠覆传统的金融模式,为全球用户提供更加普惠、便捷、安全的金融服务。在这个变革的浪潮中,PtahDAO(普塔道)作为全…...

从零搭建一个react + electron项目

最近打算搭建一个react electron的项目,发现并不是那么傻瓜式 于是记录一下自己的实践步骤 通过create-react-app 创建react项目 npx create-react-app my-app 安装electron依赖 npm i electron -D暴露react项目的配置文件(这一步看自己需求&#xff0c…...

MATLAB /Simulink 快速开发STM32(使用st官方工具 STM32-MAT/TARGET),以及开发过程

配置好环境以后就是开发: stm32cube配置芯片,打开matlab添加ioc文件,写处理逻辑,生成代码,下载到板子中去。 配置需要注意事项: STM32CUBEMAX6.5.0 MABLAB2022BkeilV5.2 Matlab生成的代码CTRLB 其中关键的…...

LeetCode 热题 100 JavaScript--102. 二叉树的层序遍历

给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 输入:root [3,9,20,null,null,15,7] 输出:[[3],[9,20],[15,7]] 示例 2: 输入:root [1…...

常见Git命令

Git常见命令 1. 添加单个文件 git add a.txt2. 添加多个文件 git add a.txt b.txt c.txt3. 添加(commit)修改,此时修改还未push到服务器上 git commit -m "修改了a.txt内容"4. 提交(push)修改,此时修改会同步到服务器上 git push5. 查看当…...

在C语言中调用汇编语言的函数

在C语言中调用汇编文件中的函数,要做的主要工作有两个: 一是在C语言中声明函数原型,并加extern关键字; 二是在汇编中用EXPORT导出函数名,并用该函数名作为汇编代码段的标识,最后用mov pc, lr返回。然后&a…...

Delphi Professional Crack,IDE插件开发和扩展IDE

Delphi Professional Crack,IDE插件开发和扩展IDE 构建具有强大视觉设计功能的单源多平台本机应用程序。 Delphi帮助您使用Object Pascal为Windows、Mac、Mobile、IoT和Linux构建和更新数据丰富、超连接、可视化的应用程序。Delphi Professional适合个人开发人员和小型团队构建…...

程序框架-公共MONO模块

作用:让没有继承MONO的类可以开启协程,可以update更新,可以统一管理update MonoController脚本继承MonoBehaviour使得脚本过场不移除,并通过UnityAction可以添加多个函数(多播委托),实现Update…...

采用鲁棒随机森林实现的流异常检测:Python应用的一种新型机器学习方法

在数字化和互联网化日益普遍的现代社会,处理海量的网络流量数据是网络安全分析中不可或缺的一部分。流异常检测是一种重要的技术,用于发现可能的安全威胁,例如:网络攻击、恶意行为和系统故障等。随机森林是一种普遍用于解决这类问题的机器学习算法。在本文中,我们将介绍一…...

缓存友好在实际编程中的重要性

引入 当CPU执行程序时,需要频繁地访问主存储器(RAM)中的数据和指令。然而,主存储器的访问速度相对较慢,与CPU的运算速度相比存在显著差异,每次都从主存中读取数据都会导致相对较长的等待时间,从…...

uni-ajax网络请求库使用

uni-ajax网络请求库使用 uni-ajax是什么 uni-ajax是基于 Promise 的轻量级 uni-app 网络请求库,具有开箱即用、轻量高效、灵活开发 特点。 下面是安装和使用教程 安装该请求库到项目中 npm install uni-ajax编辑工具类request.js // ajax.js// 引入 uni-ajax 模块 import ajax…...

MYSQL进阶-事务

1.什么是数据库事务? 事务是一个不可分割的数据库操作序列,也是数据库并发控制的基本单位,其执 行的结果必须使数据库从一种一致性状态变到另一种一致性状态。事务是逻辑上 的一组操作,要么都执行,要么都不执行。 事务…...

python 常见数据类型和方法

不可变数据类型 不支持直接增删改 只能查 str 字符串 int 整型 bool 布尔值 None None型特殊常量 tuple 元组(,,,)回到顶部 可变数据类型,支持增删改查 list 列表[,,,] dic 字典{"":"","": ,} set 集合("",""…...

a-date-picker报错TypeError: date4.locale is not a function

问题描述 使用日期选择器,数据从后端获得,再赋值给a-date-picker做数据回显,遇到这个报错,排错后定位到a-date-picker组件本身接收数据的问题。 如果使用了dayjs或moment库来处理时间字符串,并且使用.format对时间数据…...

LNMP安装

目录 1、LNMP简述: 1.1、概述 1.2、LNMP是一个缩写词,及每个字母的含义 1.3、编译安装与yum安装差异 1.4、编译安装的优点 2、通过LNMP创建论坛 2.1、 安装nginx服务 2.1.1、关闭防火墙 2.1.2、创建运行用户 2.1.3、 编译安装 2.1.4、 优化路…...

matplotlib绘图风格

文章目录 绘图风格测试代码默认和mpl风格复制风格seaborn风格 绘图风格 matplotlib功能强大,可以定制各种绘图要素,以满足个性化的绘图需求,而更换绘图风格也十分便捷,一个matplotlib.style.use函数轻松搞定,而可用的…...

【初级教程】Appium 启动应用 log 日志分析

刚开始学习 appium 时,老师给我布置了 appium 启动应用 log 分析的作业。由于工作比较忙,再者自己想先动手用 appium 写个公司的 app 的 UI 测试(目前简单的框架基本完成,在不断完善用例管理中)。写这篇文章是为了完成…...

FANUC机器人SRVO-300机械手断裂故障报警原因分析及处理办法

FANUC机器人SRVO-300机械手断裂故障报警原因分析及处理办法 首先,我们查看报警说明书上的介绍: 总结:即在机械手断裂设置为无效时,机器人检测出了机械手断裂信号(不该有的信号,现在检测到了,所以报警) 使机械手断裂设定为无效/有效的具体方法:  按下示教器的MENU菜单…...

MobPush iOS SDK iOS实时活动

开发工具:Xcode 功能需要: SwiftUI实现UI页面,iOS16.1以上系统使用 功能使用: 需应用为启动状态 功能说明 iOS16.1 系统支持实时活动功能,可以在锁定屏幕上实时获知各种事情的进展,MobPushSDK iOS 4.0.3版本已完成适配&#xf…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度​

一、引言:多云环境的技术复杂性本质​​ 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,​​基础设施的技术债呈现指数级积累​​。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

微信小程序云开发平台MySQL的连接方式

注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件,这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下,实现高效测试与快速迭代?这一命题正考验着…...

Kafka入门-生产者

生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...

Linux 中如何提取压缩文件 ?

Linux 是一种流行的开源操作系统,它提供了许多工具来管理、压缩和解压缩文件。压缩文件有助于节省存储空间,使数据传输更快。本指南将向您展示如何在 Linux 中提取不同类型的压缩文件。 1. Unpacking ZIP Files ZIP 文件是非常常见的,要在 …...

【C++】纯虚函数类外可以写实现吗?

1. 答案 先说答案&#xff0c;可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...

Linux部署私有文件管理系统MinIO

最近需要用到一个文件管理服务&#xff0c;但是又不想花钱&#xff0c;所以就想着自己搭建一个&#xff0c;刚好我们用的一个开源框架已经集成了MinIO&#xff0c;所以就选了这个 我这边对文件服务性能要求不是太高&#xff0c;单机版就可以 安装非常简单&#xff0c;几个命令就…...

面试高频问题

文章目录 &#x1f680; 消息队列核心技术揭秘&#xff1a;从入门到秒杀面试官1️⃣ Kafka为何能"吞云吐雾"&#xff1f;性能背后的秘密1.1 顺序写入与零拷贝&#xff1a;性能的双引擎1.2 分区并行&#xff1a;数据的"八车道高速公路"1.3 页缓存与批量处理…...