当前位置: 首页 > news >正文

Pytorch网络模型训练

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 简化操作,并且按顺序进行操作self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()# 将优化器中的梯度缓存(如果有的话)清零loss.backward()# 计算损失函数(loss)相对于模型参数的梯度optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:# .item()是将tensor张量变为正常的数字print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# loss.item()是当前步骤的损失值writer.add_scalar("train_loss", loss.item(), total_train_step)# 使用add_scalar可以将一个标量添加到之前的所有标量值中,# 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()# argmax(1)是横向看,argmax(0)是纵向看accuracy = (outputs.argmax(1) == targets).sum()# argmax在找到模型预测的最大概率对应的类别# 预测正确的个数total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

利用GPU训练

# 定义训练的设备
# device = torch.device("cpu/cuda")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# GPU训练的关键
tudui = tudui.cuda()
# tudui = tudui.to(device)

相关文章:

Pytorch网络模型训练

现有网络模型的使用与修改 vgg16_false torchvision.models.vgg16(pretrainedFalse) # 加载一个未预训练的模型 vgg16_true torchvision.models.vgg16(pretrainedTrue) # 把数据分为了1000个类别print(vgg16_true) 以下是vgg16预训练模型的输出 VGG((features): S…...

webgoat-Path traversal

Path traversal 路径(目录)遍历是一种漏洞,攻击者能够访问或存储外部的文件和目录 应用程序运行的位置。这可能会导致从其他目录读取文件,如果是文件,则会导致读取文件 上传覆盖关键系统文件。 它是如何工作的&#…...

P8976 「DTOI-4」排列,贪心

题目背景 Update on 2023.2.1:新增一组针对 yuanjiabao 的 Hack 数据,放置于 #21。 Update on 2023.2.2:新增一组针对 CourtesyWei 和 bizhidaojiaosha 的 Hack 数据,放置于 #22。 题目描述 小 L 给你一个偶数 n 和两个整数a,b…...

使用 Python 进行自然语言处理第 5 部分:文本分类

一、说明 关于文本分类,文章已经很多,本文这里有实操代码,明确而清晰地表述这种过程,是实战工程师所可以参照和依赖的案例版本。 本文是 2023 年 1 月的 WomenWhoCode 数据科学跟踪活动提供的会议系列文章中的一篇。 之前的文章在…...

uni-app---- 点击按钮拨打电话功能点击按钮调用高德地图进行导航的功能【安卓app端】

uniapp---- 点击按钮拨打电话功能&&点击按钮调用高德地图进行导航的功能【安卓app端】 先上效果图: 1. 在封装方法的文件夹下新建一个js文件,然后把这些功能进行封装 // 点击按钮拨打电话 export function getActionSheet(phone) {uni.showAct…...

通讯录详解(静态版,动态版,文件版)

💓博客主页:江池俊的博客⏩收录专栏:C语言进阶之路👉专栏推荐:✅C语言初阶之路 ✅数据结构探索✅C语言刷题专栏💻代码仓库:江池俊的代码仓库🎉欢迎大家点赞👍评论&#x…...

在windows中搭建vue开发环境

1.环境搭建 具体环境搭建步骤参考链接 注意该博客中初始化命令: vue init webpack MyPortalProject需改为小写: vue init webpack myportalproject不然会报错 Warning: name can no longer contain capital letters2.创建第一个vueelement ui项目 …...

数字化转型:云表低代码开发助力制造业腾飞

数字化转型已成为制造业不可避免的趋势。为了应对市场快速变化、提高运营效率以及降低成本,制造业企业积极追求更加智能化、敏捷的生产方式。在这个转型过程中,低代码技术作为一种强大的工具,正逐渐崭露头角,有望加速制造业的数字…...

Linux学习之vim跳转到特定行数

参考的博客:《Vim跳到最后一行的方法》 《oeasy教您玩转vim - 14 - # 行头行尾》 《Linux:vim 中跳到首行和最后一行》 想要跳到特定行的话,可以在命令模式和正常模式进行跳转。要是对于vim的四种模式不太熟的话,可以到博客《Linu…...

详解基于Android的Appium+Python自动化脚本编写

📢专注于分享软件测试干货内容,欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢交流讨论:欢迎加入我们一起学习!📢资源分享:耗时200小时精选的「软件测试」资…...

【马蹄集】—— 百度之星 2023

百度之星 2023 目录 BD202301 公园⭐BD202302 蛋糕划分⭐⭐⭐BD202303 第五维度⭐⭐ BD202301 公园⭐ 难度:钻石    时间限制:1秒    占用内存:64M 题目描述 今天是六一节,小度去公园玩,公园一共 N N N 个景点&am…...

大数据毕业设计选题推荐-无线网络大数据平台-Hadoop-Spark-Hive

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…...

【jvm】虚拟机之本地方法接口与本地方法库

目录 一、本地方法1.1 说明1.2 代码示例1.3 为什么要使用native method 二、现状 一、本地方法 1.1 说明 1.一个Native Method就是一个Java调用非Java代码的接口。 2.一个Native Method是这样一个Java方法:该方法的实现由非Java语言实现,比如C。 3.这个…...

HDFS系统操作命令大全

一,前言 HDFS作为分布式存储的文件系统,有其对数据的路径表达方式 HDFS同linux系统一样,均是以/作为根目录的组织形式 linux:/usr/local/hello.txt HDFS:/usr/local/hello.txt 二,如何区分呢? L…...

雷尼绍探头编程 9810

9810 ​ 安全移动 使用参数 参数含义#9移动速度 F#117移动速度 F#148#24X 移动 终点绝对坐标#25Y 移动 终点绝对坐标#26Z 移动 终点绝对坐标#123机床移动到终点的绝对坐标 与 终点的理论值 的 差#5041当前绝对坐标 X 值#5042当前绝对坐标 Y 值#5043当前绝对坐标 Z 值#116刀具…...

el-table 列分页

<template><div><el-table:data"tableData":key"tampTime"style"width: 100%"><el-table-columnprop"name"label"姓名"width"180"></el-table-column><el-table-columnprop&quo…...

APP攻防--ADB基础

进入app包 先使用 adb devices查看链接状态 手机连接成功的 adb shell 获取到手机的一个shell 此时想进入app包时没有权限的&#xff0c;APP包一般在data/data/下。没有执行权限&#xff0c;如图 Permission denied 权限被拒绝 此时需要手机root&#xff0c;root后输入 su …...

【Linux】第十站:git和gdb的基本使用

文章目录 一、git的基本操作1.gitee新建仓库注意事项2.git的安装3.git的克隆4.git的add5.git的commit6.git的push7.git log8.git status9. .gitignore 二、Linux调试器---gdb1.背景2.gdb安装、进入与退出3.list/l4.r/run运行程序5. break/b 打断点6.info/i b 查看断点7.delete/…...

Single Image Haze Removal Using Dark Channel Prior(暗通道先验)

去雾算法都会依赖于很强的先验以及假设&#xff0c;并结合相应的物理模型&#xff0c;完成去雾过程。本文作者何凯明及其团队通过大量的无雾图像和有雾图像&#xff0c;归纳总结出无雾图像在其对应的暗通道图像上具有极低的强度值&#xff08;趋近于0&#xff09;&#xff0c;并…...

力扣382.链表随机节点(java利用数组随机返回节点值)

Problem: 382. 链表随机节点 文章目录 思路解题方法复杂度Code 思路 注意链表与数组的特性&#xff0c;对于随机访问读取的操作利用数组可以较方便实现&#xff0c;所以我们可以将链表中的节点值先存入到数组中最后再取出随机生成节点位置的值。 解题方法 1.生成List集合与Rand…...

在jupyter中使用R

如果想在Jupyter Notebook中使用R语言&#xff0c;以下几个步骤操作可行&#xff1a; 1、启动Anaconda Prompt 2、进入R的安装位置&#xff0c;切换到R的安装位置&#xff1a;D:\Program Files\R\R-3.4.3\bin&#xff0c;启动R&#xff0c;具体代码操作步骤如下&#xff0c;在…...

2023(第四届)江西开放数据创新应用大赛等你来挑战!

邀请函 这是一个友好的邀请。无论你是数据领域的专家、学生还是爱好者&#xff0c;我们都欢迎你加入这个平台。这不仅仅是一场比赛&#xff0c;更是一个交流、学习和展示自己的机会。 丰厚奖金&#xff1a;我们为参赛者准备了总计15W的奖金池&#xff0c;期待你的才华在这里得…...

2023-mac rz sz 安装

之前安装过一次&#xff0c;没问题&#xff0c;这次按照之前教程装了就不管上传下载都会卡住&#xff1b; step1: brew install lrzsz step2&#xff1a;在/usr/local/bin 路径下配置两个sh,之前从网上找到的直接用都不对&#xff0c;下面这个是调试过的正式可用的 iterm2…...

使用Matplotlib绘画3D图时运行不出结果,也不报错,图片是空白 !!

1.问题&#xff1a; 我使用如下代码运用matplotlib中的Axes3D绘画3D图&#xff0c;但是运行出来的结果是空白。 import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D #导入3D包 fig plt.figure() #窗口 #ax Axes3D(fig) # X, Y …...

Matlab函数——find

介绍 当你需要返回某个数组中符合指定条件的所有元素的索引时&#xff0c;可以使用 MATLAB 中的 find 函数。 find 函数语法&#xff1a; indices find(X) indices find(X, k) indices find(X, k, first) indices find(X, k, last) 其中&#xff0c;X 是一个数组&#xf…...

mac安装python3

文章目录 1. 安装1.1 brew安装&#xff08;失败&#xff09;2. 下载安装包 2. 查看版本3. 配置 1. 安装 1.1 brew安装&#xff08;失败&#xff09; brew install python3下载完成后报错&#xff1a;Error: python3.10: unknown or unsupported macOS version: :dunno 解决&a…...

【星海出品】VUE(一)

Windows安装nvm控制器 Windows里找都PowerShell。右击点击管理员运行。 1.安装choco Set-ExecutionPolicy Bypass -Scope Process -Force; iex ((New-Object System.Net.WebClient).DownloadString(https://chocolatey.org/install.ps1))2.安装NVM choco install nvm 3.查看可…...

Stable Diffusion 的提示词使用技巧

推荐Stable Diffusion自动纹理工具&#xff1a; DreamTexture.js自动纹理化开发包 什么是提示语&#xff1f; 提示语是人工智能中的一个重要组成部分&#xff0c;尤其是自然语言处理 &#xff08;NLP&#xff09;。在AI自人工智能中&#xff0c;想要获得好的效果&#xff0c;简…...

Hook函数

在嵌入式系统中&#xff0c;hook函数&#xff08;也被称为钩子函数&#xff09;是一种特殊类型的函数&#xff0c;它会在特定的事件发生时被操作系统内部调用。例如&#xff0c;在实时操作系统&#xff08;RTOS&#xff09;中&#xff0c;如果删除了一个任务&#xff0c;就会调…...

USB简介系列-01

文章目录 USB简介一、电气USB简介 通用串行总线(USB)是由Compaq,Intel,Microsoft和NEC开发的规范,后来惠普,朗讯和飞利浦加入。这些公司成立了 USB Implementers Forum, Inc 作为一家非营利性公司,以发布规范并组织 USB 的进一步开发。 USB-IF的目的是为当时使用的PC…...