使用pytorch进行迁移学习的两个步骤
1. 步骤及代码
迁移学习一般都会使用两个步骤进行训练:
- 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
- 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth')) # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names)) # 创建新的全连接层print(model.fc) # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10) # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001) # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model)) # 加载本地模型else:model.load_state_dict(torch.load(first_model)) # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')
2. 完整资源
https://download.csdn.net/download/liugan528/89833913
相关文章:
使用pytorch进行迁移学习的两个步骤
1. 步骤及代码 迁移学习一般都会使用两个步骤进行训练: 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。 impor…...
ChatGPT相关参数示例
max_token 用于控制最大输出长度,若ChatGPT的回复大于max_tokens,则对输出结果进行截断。 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" ) response client.chat.completions.create(model"…...
OWASP发布大模型安全风险与应对策略(QA测试重点关注)
开放式 Web 应用程序安全项目(OWASP)发布了关于大模型应用的安全风险,这些风险不仅包括传统的沙盒逃逸、代码执行和鉴权不当等安全风险,还涉及提示注入、对话数据泄露和数据投毒等人工智能特有的安全风险。 帮助开发者和测试同学更…...
【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量
ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】: let 变量名: 类型 值 let:是定义变量的关键字类型: 值数据类型, 常用的数据类型 字符型(string)、数字型(number…...
UI自动化测试示例:python+pytest+selenium+allure
重点应用是封装、参数化: 比如在lib文件夹下,要存储封装好的方法和必要的环境变量(指网址等) 1.cfg.py:封装网址和对应的页面 SMP_ADDRESS http://127.0.0.1:8234SMP_URL_LOGIN f{SMP_ADDRESS}/login.html SMP_URL_DE…...
C/C++ 编程小工具
编写了 tools.h 和 tools.cpp,用于 Debug、性能测试、打印日志。 tools.h #ifndef TOOLS_H #define TOOLS_H#include <time.h> #include <fstream> #include <iostream> #include <random> #include <chrono> #include <vector&…...
第四十二章 使用 WS-ReliableMessaging
文章目录 第四十二章 使用 WS-ReliableMessaging从 Web 客户端发送一系列消息 第四十二章 使用 WS-ReliableMessaging IRIS 支持 WS-ReliableMessaging 规范的部分内容,如简介中所述。此规范提供了一种按顺序可靠地传递一系列消息的机制。本页介绍如何手动使用可靠…...
利士策分享,婚姻为何被称为大事?
利士策分享,婚姻为何被称为大事? 在历史的长河中,婚姻一直被视为人生中的头等大事,这一观念跨越时空,深深植根于各种文化和社会结构中。 古人为何将婚姻称为“大事”,这背后蕴含着丰富的社会、文化和心理寓…...
malloc源码分析之 ----- 你想要啥chunk
文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码,这里以glibc-2.29为例: void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…...
软考系统分析师知识点五:数据通信与计算机网络
前言 今年报考了11月份的软考高级:系统分析师。 考试时间为:11月9日。 倒计时:32天。 目标:优先应试,其次学习,再次实践。 复习计划第一阶段:扫平基础知识点,仅抽取有用信息&am…...
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier)
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier) 目录 SSH简述、三种网络连接特点SSH简述局域网内连接内网穿透(…...
Python 工具库每日推荐【openpyxl 】
文章目录 引言Python Excel 处理库的重要性今日推荐:openpyxl 工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:自动生成月度销售报告案例分析高级特性条件格式数据验证扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScript 设计…...
本地生活服务项目入局方案解析!本地生活服务商系统能实现怎样的作业效果?
当前,各大平台的本地生活服务业务日渐兴盛,提高创业者入局意向的同时,也让本地生活服务项目有哪些等问题也成为了多个创业者社群中的热议对象。而从目前的讨论情况来看,在创业者们所询问的众多本地生活服务项目中,通过…...
ML 系列:【13 】— Logistic 回归(第 2 部分)
文章目录 一、说明二、挤压方法三、Logistic 回归中的损失函数四、后记 一、说明 在这篇文章中,我们将深入研究 squashing 方法,这是有符号距离方法(第 12节)的一种很有前途的替代方案。squashing 方法通过提供增强的对异常值…...
45岁被裁员的程序员,何去何从?
在当今快速变化的技术行业,职业生涯的稳定性受到挑战。在45岁被裁员,对很多程序员来说,可能是一种惊慌失措的体验。然而,这个阶段也可以被视为一个重新审视和调整方向的机会。本文将对可能的出路进行全方位的分析,并提…...
云计算Openstack Neutron
OpenStack Neutron是OpenStack云计算平台中的网络服务组件,它为OpenStack提供了强大的网络连接功能。 一、基本概念 Neutron是一个网络服务项目,旨在为OpenStack提供网络连接。它允许用户创建和管理虚拟网络,包括子网、路由、安全组等&…...
PointNet++网络详解
数据集转换 数据集转换的意义在于将原本的 txt 点云文件转换为更方便运算的npy点云文件,同时,将原本的xyzrgb这 6 个维度转换为xyzrgbc,最后一个c维度代表该点云所属的类别。 for anno_path in anno_paths:print(anno_path)try:elements a…...
Java | Leetcode Java题解之第459题重复的子字符串
题目: 题解: class Solution {public boolean repeatedSubstringPattern(String s) {return kmp(s s, s);}public boolean kmp(String query, String pattern) {int n query.length();int m pattern.length();int[] fail new int[m];Arrays.fill(fa…...
【动态规划-最长公共子序列(LCS)】【hard】力扣1092. 最短公共超序列
给你两个字符串 str1 和 str2,返回同时以 str1 和 str2 作为 子序列 的最短字符串。如果答案不止一个,则可以返回满足条件的 任意一个 答案。 如果从字符串 t 中删除一些字符(也可能不删除),可以得到字符串 s &#x…...
图片编辑为底片,智能工具助力,创作精彩视觉作品
在当今数字化时代,图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具,即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”,选择这里“图片批量处理”版块页面上 2.导入保存有…...
【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...
Mac下Android Studio扫描根目录卡死问题记录
环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中,提示一个依赖外部头文件的cpp源文件需要同步,点…...
解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist
现象: android studio报错: [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决: 不要动CMakeLists.…...
解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用
在工业制造领域,无损检测(NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统,以非接触式光学麦克风技术为核心,打破传统检测瓶颈,为半导体、航空航天、汽车制造等行业提供了高灵敏…...
二维FDTD算法仿真
二维FDTD算法仿真,并带完全匹配层,输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...
