当前位置: 首页 > news >正文

如何使用.pth训练模型

一.使用.pth训练模型的步骤如下:

1.导入必要的库和模型

import torch
import torchvision.models as models# 加载预训练模型
model = models.resnet50(pretrained=True)

2.定义数据集和数据加载器

# 定义数据集和数据加载器
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

3.定义损失函数和优化器

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练模型

# 训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(dataloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0

5.保存模型

# 保存模型
torch.save(model.state_dict(), 'model.pth')

二,使用自己训练的.pth模型进行训练的步骤如下:

1.导入必要的库和模型

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_dataset import MyDataset # 自定义数据集
from my_model import MyModel # 自定义模型

2.设置超参数和路径

batch_size = 32 # 批大小
num_epochs = 10 # 训练轮数
learning_rate = 0.001 # 学习率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设置设备
train_data_path = 'train_data/' # 训练数据集路径
test_data_path = 'test_data/' # 测试数据集路径
model_path = 'my_model.pth' # 模型保存路径

3.加载数据集

train_transforms = transforms.Compose([transforms.Resize((224, 224)), # 调整图像大小transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 转换为张量transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化
])test_transforms = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])train_dataset = MyDataset(train_data_path, train_transforms) # 自定义数据集
test_dataset = MyDataset(test_data_path, test_transforms)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 训练集加载器
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 测试集加载器

4.加载模型

model = MyModel() # 自定义模型
model.load_state_dict(torch.load(model_path)) # 加载.pth模型
model.to(device) # 将模型移动到设备上

5.定义损失函数和优化器

criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Adam优化器

6.训练模型

total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))torch.save(model.state_dict(), 'fine_tuned_model.pth') # 保存.pth模型

7.测试模型

model.eval() # 切换到评估模式
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: {} %'.format(100 * correct / total))

相关文章:

如何使用.pth训练模型

一.使用.pth训练模型的步骤如下: 1.导入必要的库和模型 import torch import torchvision.models as models# 加载预训练模型 model models.resnet50(pretrainedTrue) 2.定义数据集和数据加载器 # 定义数据集和数据加载器 dataset MyDataset() dataloader to…...

C++11线程以及线程同步

C11中提供的线程类std::thread,基于此类创建一个新的线程相对简单,只需要提供线程函数和线程对象即可 一.命名空间 this_thread C11 添加一个关于线程的命名空间std::this_pthread ,此命名空间中提供四个公共的成员函数; 1.1 get_id() 调用命名空间s…...

深度学习之基于YoloV3杂草识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 深度学习在图像识别领域已经取得了显著的成果,其中基于YOLO(You Only Look Once&#xff09…...

Linux 命令vim(编辑器)

(一)vim编辑器的介绍 vim是文件编辑器,是vi的升级版本,兼容vi的所有指令,同时做了优化和延伸。vim有多种模式,其中常用的模式有命令模式、插入模式、末行模式:。 (二)vim编辑器基本操作 1 进入vim编辑文件 1 vim …...

轻松配置PPPoE连接:路由器设置和步骤详解

在家庭网络环境中,我们经常使用PPPoE(点对点协议过夜)连接来接入宽带互联网。然而,对于一些没有网络专业知识的人来说,配置PPPoE连接可能会有些困难。在本文中,我将详细介绍如何轻松配置PPPoE连接&#xff…...

电源控制系统架构(PCSA)之系统分区电源域

目录 4.2 电源域 4.2.1 电源模式 4.2.2 电源域的选择 4.2.3 系统逻辑 4.2.4 Always-On域 4.2.5 处理器Clusters 4.2.6 CoreSight逻辑 4.2.7 图像处理器 4.2.8 显示处理器 4.2.9 其他功能 4.2.10 电源域层次结构要求 4.2.11 SOC域示例 4.2 电源域 电源域在这里被定…...

Linux:docker基础操作(3)

docker的介绍 Linux:Docker的介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/134146721?spm1001.2014.3001.5502 通过yum安装docker Linux:Docker-yum安装(2)-CSDN博客https://blog.…...

【Axure教程】用中继器制作卡片多条件搜索效果

卡片设计通过提供清晰的信息结构、可视化吸引力、易扩展性和强大的交互性,为用户界面设计带来了许多优势,使得用户能够更轻松地浏览、理解和互动。 那今天就教大家如何用中继器制作卡片的模板,以及完成多条件搜索的效果,我们会以…...

Linux中vi常用命令-批量替换

在日常服务器日志查看中常用到的命令有grep、tail等,有时想查看详细日志,用到vi命令,记录下来,方便查看。 操作文件:test.properites 一、查看与编辑 查看命令:vi 文件名 编辑命令:按键 i&…...

logback-spring.xml的内容格式

目录 一、logback-spring.xml 二、Logback 中的三种日志文件类型 一、logback-spring.xml <?xml version"1.0" encoding"UTF-8"?> <configuration scan"true" scanPeriod"10 seconds" ><!-- <statusListener…...

nodejs+vue+elementui+express青少年编程课程在线考试系统

针对传统线下考试存在的老师阅卷工作量较大&#xff0c;统计成绩数据时间长等问题&#xff0c;实现一套高效、灵活、功能强大的管理系统是非常必要的。该系统可以迅速完成随机组卷&#xff0c;及时阅卷、统计考试成绩排名的效果。该考试系统要求&#xff1a;该系统将采用B/S结构…...

Navicat 技术指引 | GaussDB 数据查看器

Navicat Premium&#xff08;16.2.8 Windows版或以上&#xff09; 已支持对GaussDB 主备版的管理和开发功能。它不仅具备轻松、便捷的可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结构同步、协同合作、数据迁移等&#xff09;&#xff0c;这…...

Docker的registry

简介 地址&#xff1a;https://hub.docker.com/_/registry Dcoker registry是存储Dcoker image的仓库&#xff0c;运行push&#xff0c;pull&#xff0c;search时&#xff0c;是通过Dcoker daemon与docker registry通信。有时候会用Dcoker Hub这样的公共仓库可能不方便&#x…...

【vue_3】关于超链接的问题

1、需求2、修改前的代码3、修改之后&#xff08;1&#xff09;第一次&#xff08;2&#xff09;第二次&#xff08;3&#xff09;第三次&#xff08;4&#xff09;第四次&#xff08;5&#xff09;第五次 1、需求 需求&#xff1a;要给没有超链接的列表添加软超链接 2、修改前…...

redis优化秒杀和消息队列

redis优化秒杀 1. 异步秒杀思路1.1 在redis存入库存和订单信息1.2 具体流程图 2. 实现2.1 总结 3. Redis的消息队列3.1 基于list实现消息队列3.2 基于PubSub实现消息队列3.3 基于stream实现消息队列3.3.1 stream的单消费模式3.3.2 stream的消费者组模式 3.4 基于stream消息队列…...

arm-eabi-gcc 和 arm-none-eabi-gcc 都是基于 GCC 的交叉编译器

arm-eabi-gcc 和 arm-none-eabi-gcc 都是基于 GCC 的交叉编译器&#xff0c;用于编译 ARM 架构的嵌入式系统。它们的命名规则如下&#xff1a; arm 表示目标架构是 ARM。eabi 表示嵌入式应用程序二进制接口&#xff08;Embedded Application Binary Interface&#xff09;&…...

《大话设计模式》(持续更新中)

《大话设计模式》 序 为什么要学设计模式第0章 面向对象基础什么是对象&#xff1f;什么是类&#xff1f;什么是构造方法&#xff1f;什么是重载&#xff1f;属性与字段有什么区别&#xff1f;什么是封装&#xff1f;什么是继承&#xff1f;什么是多态&#xff1f;抽象类的目的…...

人工智能原理复习--绪论

文章目录 人工智能原理概述图灵测试人工智能的研究方法符号主义连接主义行为主义总结 人工智能原理概述 人工智能是计算机科学基础理论研究的重要组成部分 现代人工智能一般认为起源于美国1956你那夏季的达特茅斯会议&#xff0c;在这次会议上&#xff0c;John McCarthy第一次…...

[网络] 字节一面~ 2. HTTP 2 与 HTTP 1.x 有什么区别

头部压缩 在 HTTP2 当中&#xff0c;如果你发出了多个请求&#xff0c;并且它们的头部(header)是相同的&#xff0c;那么 HTTP2 协议会帮你消除同样的部分。(其实就是在客户端和服务端维护一张索引表来实现)二进制格式 HTTP1.1 采用明文的形式 HTTP/2 全⾯采⽤了⼆进制格式&…...

自己动手实现一个深度学习算法——八、深度学习

深度学习是加深了层的深度神经网络。 1.加深网络 1&#xff09;向更深的网络出发 创建一个如下图所示的网络结构的CNN 这个网络的层比之前实现的网络都更深。这里使用的卷积层全都是33 的小型滤波器&#xff0c;特点是随着层的加深&#xff0c;通道数变大&#xff08;卷积…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周&#xff0c;有很多同学在写期末Java web作业时&#xff0c;运行tomcat出现乱码问题&#xff0c;经过多次解决与研究&#xff0c;我做了如下整理&#xff1a; 原因&#xff1a; IDEA本身编码与tomcat的编码与Windows编码不同导致&#xff0c;Windows 系统控制台…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

Cursor实现用excel数据填充word模版的方法

cursor主页&#xff1a;https://www.cursor.com/ 任务目标&#xff1a;把excel格式的数据里的单元格&#xff0c;按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例&#xff0c;…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...