当前位置: 首页 > news >正文

Pytorch网络模型训练

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 简化操作,并且按顺序进行操作self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()# 将优化器中的梯度缓存(如果有的话)清零loss.backward()# 计算损失函数(loss)相对于模型参数的梯度optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:# .item()是将tensor张量变为正常的数字print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# loss.item()是当前步骤的损失值writer.add_scalar("train_loss", loss.item(), total_train_step)# 使用add_scalar可以将一个标量添加到之前的所有标量值中,# 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()# argmax(1)是横向看,argmax(0)是纵向看accuracy = (outputs.argmax(1) == targets).sum()# argmax在找到模型预测的最大概率对应的类别# 预测正确的个数total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

利用GPU训练

# 定义训练的设备
# device = torch.device("cpu/cuda")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# GPU训练的关键
tudui = tudui.cuda()
# tudui = tudui.to(device)

相关文章:

Pytorch网络模型训练

现有网络模型的使用与修改 vgg16_false torchvision.models.vgg16(pretrainedFalse) # 加载一个未预训练的模型 vgg16_true torchvision.models.vgg16(pretrainedTrue) # 把数据分为了1000个类别print(vgg16_true) 以下是vgg16预训练模型的输出 VGG((features): S…...

webgoat-Path traversal

Path traversal 路径(目录)遍历是一种漏洞,攻击者能够访问或存储外部的文件和目录 应用程序运行的位置。这可能会导致从其他目录读取文件,如果是文件,则会导致读取文件 上传覆盖关键系统文件。 它是如何工作的&#…...

P8976 「DTOI-4」排列,贪心

题目背景 Update on 2023.2.1:新增一组针对 yuanjiabao 的 Hack 数据,放置于 #21。 Update on 2023.2.2:新增一组针对 CourtesyWei 和 bizhidaojiaosha 的 Hack 数据,放置于 #22。 题目描述 小 L 给你一个偶数 n 和两个整数a,b…...

使用 Python 进行自然语言处理第 5 部分:文本分类

一、说明 关于文本分类,文章已经很多,本文这里有实操代码,明确而清晰地表述这种过程,是实战工程师所可以参照和依赖的案例版本。 本文是 2023 年 1 月的 WomenWhoCode 数据科学跟踪活动提供的会议系列文章中的一篇。 之前的文章在…...

uni-app---- 点击按钮拨打电话功能点击按钮调用高德地图进行导航的功能【安卓app端】

uniapp---- 点击按钮拨打电话功能&&点击按钮调用高德地图进行导航的功能【安卓app端】 先上效果图: 1. 在封装方法的文件夹下新建一个js文件,然后把这些功能进行封装 // 点击按钮拨打电话 export function getActionSheet(phone) {uni.showAct…...

通讯录详解(静态版,动态版,文件版)

💓博客主页:江池俊的博客⏩收录专栏:C语言进阶之路👉专栏推荐:✅C语言初阶之路 ✅数据结构探索✅C语言刷题专栏💻代码仓库:江池俊的代码仓库🎉欢迎大家点赞👍评论&#x…...

在windows中搭建vue开发环境

1.环境搭建 具体环境搭建步骤参考链接 注意该博客中初始化命令: vue init webpack MyPortalProject需改为小写: vue init webpack myportalproject不然会报错 Warning: name can no longer contain capital letters2.创建第一个vueelement ui项目 …...

数字化转型:云表低代码开发助力制造业腾飞

数字化转型已成为制造业不可避免的趋势。为了应对市场快速变化、提高运营效率以及降低成本,制造业企业积极追求更加智能化、敏捷的生产方式。在这个转型过程中,低代码技术作为一种强大的工具,正逐渐崭露头角,有望加速制造业的数字…...

Linux学习之vim跳转到特定行数

参考的博客:《Vim跳到最后一行的方法》 《oeasy教您玩转vim - 14 - # 行头行尾》 《Linux:vim 中跳到首行和最后一行》 想要跳到特定行的话,可以在命令模式和正常模式进行跳转。要是对于vim的四种模式不太熟的话,可以到博客《Linu…...

详解基于Android的Appium+Python自动化脚本编写

📢专注于分享软件测试干货内容,欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢交流讨论:欢迎加入我们一起学习!📢资源分享:耗时200小时精选的「软件测试」资…...

【马蹄集】—— 百度之星 2023

百度之星 2023 目录 BD202301 公园⭐BD202302 蛋糕划分⭐⭐⭐BD202303 第五维度⭐⭐ BD202301 公园⭐ 难度:钻石    时间限制:1秒    占用内存:64M 题目描述 今天是六一节,小度去公园玩,公园一共 N N N 个景点&am…...

大数据毕业设计选题推荐-无线网络大数据平台-Hadoop-Spark-Hive

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…...

【jvm】虚拟机之本地方法接口与本地方法库

目录 一、本地方法1.1 说明1.2 代码示例1.3 为什么要使用native method 二、现状 一、本地方法 1.1 说明 1.一个Native Method就是一个Java调用非Java代码的接口。 2.一个Native Method是这样一个Java方法:该方法的实现由非Java语言实现,比如C。 3.这个…...

HDFS系统操作命令大全

一,前言 HDFS作为分布式存储的文件系统,有其对数据的路径表达方式 HDFS同linux系统一样,均是以/作为根目录的组织形式 linux:/usr/local/hello.txt HDFS:/usr/local/hello.txt 二,如何区分呢? L…...

雷尼绍探头编程 9810

9810 ​ 安全移动 使用参数 参数含义#9移动速度 F#117移动速度 F#148#24X 移动 终点绝对坐标#25Y 移动 终点绝对坐标#26Z 移动 终点绝对坐标#123机床移动到终点的绝对坐标 与 终点的理论值 的 差#5041当前绝对坐标 X 值#5042当前绝对坐标 Y 值#5043当前绝对坐标 Z 值#116刀具…...

el-table 列分页

<template><div><el-table:data"tableData":key"tampTime"style"width: 100%"><el-table-columnprop"name"label"姓名"width"180"></el-table-column><el-table-columnprop&quo…...

APP攻防--ADB基础

进入app包 先使用 adb devices查看链接状态 手机连接成功的 adb shell 获取到手机的一个shell 此时想进入app包时没有权限的&#xff0c;APP包一般在data/data/下。没有执行权限&#xff0c;如图 Permission denied 权限被拒绝 此时需要手机root&#xff0c;root后输入 su …...

【Linux】第十站:git和gdb的基本使用

文章目录 一、git的基本操作1.gitee新建仓库注意事项2.git的安装3.git的克隆4.git的add5.git的commit6.git的push7.git log8.git status9. .gitignore 二、Linux调试器---gdb1.背景2.gdb安装、进入与退出3.list/l4.r/run运行程序5. break/b 打断点6.info/i b 查看断点7.delete/…...

Single Image Haze Removal Using Dark Channel Prior(暗通道先验)

去雾算法都会依赖于很强的先验以及假设&#xff0c;并结合相应的物理模型&#xff0c;完成去雾过程。本文作者何凯明及其团队通过大量的无雾图像和有雾图像&#xff0c;归纳总结出无雾图像在其对应的暗通道图像上具有极低的强度值&#xff08;趋近于0&#xff09;&#xff0c;并…...

力扣382.链表随机节点(java利用数组随机返回节点值)

Problem: 382. 链表随机节点 文章目录 思路解题方法复杂度Code 思路 注意链表与数组的特性&#xff0c;对于随机访问读取的操作利用数组可以较方便实现&#xff0c;所以我们可以将链表中的节点值先存入到数组中最后再取出随机生成节点位置的值。 解题方法 1.生成List集合与Rand…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制&#xff0c;因此这个了16进制的数据既可以翻译成为这个机器码&#xff0c;也可以翻译成为这个国标码&#xff0c;所以这个时候很容易会出现这个歧义的情况&#xff1b; 因此&#xff0c;我们的这个国…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis&#xff1f;2.为什么要使用redis作为mysql的缓存&#xff1f;3.什么是缓存雪崩、缓存穿透、缓存击穿&#xff1f;3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器&#xff0c;其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机&#xff08;Virtual Host&#xff09;。 1. 简介 Nginx 使用 server_name 指令来确定…...

linux 下常用变更-8

1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行&#xff0c;YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID&#xff1a; YW3…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...