如何使用.pth训练模型
一.使用.pth训练模型的步骤如下:
1.导入必要的库和模型
import torch
import torchvision.models as models# 加载预训练模型
model = models.resnet50(pretrained=True)
2.定义数据集和数据加载器
# 定义数据集和数据加载器
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
3.定义损失函数和优化器
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
4.训练模型
# 训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(dataloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0
5.保存模型
# 保存模型
torch.save(model.state_dict(), 'model.pth')
二,使用自己训练的.pth模型进行训练的步骤如下:
1.导入必要的库和模型
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_dataset import MyDataset # 自定义数据集
from my_model import MyModel # 自定义模型
2.设置超参数和路径
batch_size = 32 # 批大小
num_epochs = 10 # 训练轮数
learning_rate = 0.001 # 学习率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设置设备
train_data_path = 'train_data/' # 训练数据集路径
test_data_path = 'test_data/' # 测试数据集路径
model_path = 'my_model.pth' # 模型保存路径
3.加载数据集
train_transforms = transforms.Compose([transforms.Resize((224, 224)), # 调整图像大小transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 转换为张量transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化
])test_transforms = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])train_dataset = MyDataset(train_data_path, train_transforms) # 自定义数据集
test_dataset = MyDataset(test_data_path, test_transforms)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 训练集加载器
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 测试集加载器
4.加载模型
model = MyModel() # 自定义模型
model.load_state_dict(torch.load(model_path)) # 加载.pth模型
model.to(device) # 将模型移动到设备上
5.定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Adam优化器
6.训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))torch.save(model.state_dict(), 'fine_tuned_model.pth') # 保存.pth模型
7.测试模型
model.eval() # 切换到评估模式
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the test images: {} %'.format(100 * correct / total))
相关文章:
如何使用.pth训练模型
一.使用.pth训练模型的步骤如下: 1.导入必要的库和模型 import torch import torchvision.models as models# 加载预训练模型 model models.resnet50(pretrainedTrue) 2.定义数据集和数据加载器 # 定义数据集和数据加载器 dataset MyDataset() dataloader to…...
C++11线程以及线程同步
C11中提供的线程类std::thread,基于此类创建一个新的线程相对简单,只需要提供线程函数和线程对象即可 一.命名空间 this_thread C11 添加一个关于线程的命名空间std::this_pthread ,此命名空间中提供四个公共的成员函数; 1.1 get_id() 调用命名空间s…...
深度学习之基于YoloV3杂草识别系统
欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 深度学习在图像识别领域已经取得了显著的成果,其中基于YOLO(You Only Look Once)…...
Linux 命令vim(编辑器)
(一)vim编辑器的介绍 vim是文件编辑器,是vi的升级版本,兼容vi的所有指令,同时做了优化和延伸。vim有多种模式,其中常用的模式有命令模式、插入模式、末行模式:。 (二)vim编辑器基本操作 1 进入vim编辑文件 1 vim …...
轻松配置PPPoE连接:路由器设置和步骤详解
在家庭网络环境中,我们经常使用PPPoE(点对点协议过夜)连接来接入宽带互联网。然而,对于一些没有网络专业知识的人来说,配置PPPoE连接可能会有些困难。在本文中,我将详细介绍如何轻松配置PPPoE连接ÿ…...
电源控制系统架构(PCSA)之系统分区电源域
目录 4.2 电源域 4.2.1 电源模式 4.2.2 电源域的选择 4.2.3 系统逻辑 4.2.4 Always-On域 4.2.5 处理器Clusters 4.2.6 CoreSight逻辑 4.2.7 图像处理器 4.2.8 显示处理器 4.2.9 其他功能 4.2.10 电源域层次结构要求 4.2.11 SOC域示例 4.2 电源域 电源域在这里被定…...
Linux:docker基础操作(3)
docker的介绍 Linux:Docker的介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/134146721?spm1001.2014.3001.5502 通过yum安装docker Linux:Docker-yum安装(2)-CSDN博客https://blog.…...
【Axure教程】用中继器制作卡片多条件搜索效果
卡片设计通过提供清晰的信息结构、可视化吸引力、易扩展性和强大的交互性,为用户界面设计带来了许多优势,使得用户能够更轻松地浏览、理解和互动。 那今天就教大家如何用中继器制作卡片的模板,以及完成多条件搜索的效果,我们会以…...
Linux中vi常用命令-批量替换
在日常服务器日志查看中常用到的命令有grep、tail等,有时想查看详细日志,用到vi命令,记录下来,方便查看。 操作文件:test.properites 一、查看与编辑 查看命令:vi 文件名 编辑命令:按键 i&…...
logback-spring.xml的内容格式
目录 一、logback-spring.xml 二、Logback 中的三种日志文件类型 一、logback-spring.xml <?xml version"1.0" encoding"UTF-8"?> <configuration scan"true" scanPeriod"10 seconds" ><!-- <statusListener…...
nodejs+vue+elementui+express青少年编程课程在线考试系统
针对传统线下考试存在的老师阅卷工作量较大,统计成绩数据时间长等问题,实现一套高效、灵活、功能强大的管理系统是非常必要的。该系统可以迅速完成随机组卷,及时阅卷、统计考试成绩排名的效果。该考试系统要求:该系统将采用B/S结构…...
Navicat 技术指引 | GaussDB 数据查看器
Navicat Premium(16.2.8 Windows版或以上) 已支持对GaussDB 主备版的管理和开发功能。它不仅具备轻松、便捷的可视化数据查看和编辑功能,还提供强大的高阶功能(如模型、结构同步、协同合作、数据迁移等),这…...
Docker的registry
简介 地址:https://hub.docker.com/_/registry Dcoker registry是存储Dcoker image的仓库,运行push,pull,search时,是通过Dcoker daemon与docker registry通信。有时候会用Dcoker Hub这样的公共仓库可能不方便&#x…...
【vue_3】关于超链接的问题
1、需求2、修改前的代码3、修改之后(1)第一次(2)第二次(3)第三次(4)第四次(5)第五次 1、需求 需求:要给没有超链接的列表添加软超链接 2、修改前…...
redis优化秒杀和消息队列
redis优化秒杀 1. 异步秒杀思路1.1 在redis存入库存和订单信息1.2 具体流程图 2. 实现2.1 总结 3. Redis的消息队列3.1 基于list实现消息队列3.2 基于PubSub实现消息队列3.3 基于stream实现消息队列3.3.1 stream的单消费模式3.3.2 stream的消费者组模式 3.4 基于stream消息队列…...
arm-eabi-gcc 和 arm-none-eabi-gcc 都是基于 GCC 的交叉编译器
arm-eabi-gcc 和 arm-none-eabi-gcc 都是基于 GCC 的交叉编译器,用于编译 ARM 架构的嵌入式系统。它们的命名规则如下: arm 表示目标架构是 ARM。eabi 表示嵌入式应用程序二进制接口(Embedded Application Binary Interface)&…...
《大话设计模式》(持续更新中)
《大话设计模式》 序 为什么要学设计模式第0章 面向对象基础什么是对象?什么是类?什么是构造方法?什么是重载?属性与字段有什么区别?什么是封装?什么是继承?什么是多态?抽象类的目的…...
人工智能原理复习--绪论
文章目录 人工智能原理概述图灵测试人工智能的研究方法符号主义连接主义行为主义总结 人工智能原理概述 人工智能是计算机科学基础理论研究的重要组成部分 现代人工智能一般认为起源于美国1956你那夏季的达特茅斯会议,在这次会议上,John McCarthy第一次…...
[网络] 字节一面~ 2. HTTP 2 与 HTTP 1.x 有什么区别
头部压缩 在 HTTP2 当中,如果你发出了多个请求,并且它们的头部(header)是相同的,那么 HTTP2 协议会帮你消除同样的部分。(其实就是在客户端和服务端维护一张索引表来实现)二进制格式 HTTP1.1 采用明文的形式 HTTP/2 全⾯采⽤了⼆进制格式&…...
自己动手实现一个深度学习算法——八、深度学习
深度学习是加深了层的深度神经网络。 1.加深网络 1)向更深的网络出发 创建一个如下图所示的网络结构的CNN 这个网络的层比之前实现的网络都更深。这里使用的卷积层全都是33 的小型滤波器,特点是随着层的加深,通道数变大(卷积…...
C++初阶-list的底层
目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...
vscode(仍待补充)
写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh? debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...
蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练
前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...
在四层代理中还原真实客户端ngx_stream_realip_module
一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡(如 HAProxy、AWS NLB、阿里 SLB)发起上游连接时,将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后,ngx_stream_realip_module 从中提取原始信息…...
Keil 中设置 STM32 Flash 和 RAM 地址详解
文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...
论文笔记——相干体技术在裂缝预测中的应用研究
目录 相关地震知识补充地震数据的认识地震几何属性 相干体算法定义基本原理第一代相干体技术:基于互相关的相干体技术(Correlation)第二代相干体技术:基于相似的相干体技术(Semblance)基于多道相似的相干体…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
【C++特殊工具与技术】优化内存分配(一):C++中的内存分配
目录 一、C 内存的基本概念 1.1 内存的物理与逻辑结构 1.2 C 程序的内存区域划分 二、栈内存分配 2.1 栈内存的特点 2.2 栈内存分配示例 三、堆内存分配 3.1 new和delete操作符 4.2 内存泄漏与悬空指针问题 4.3 new和delete的重载 四、智能指针…...
Golang——7、包与接口详解
包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...
