使用pytorch进行迁移学习的两个步骤
1. 步骤及代码
迁移学习一般都会使用两个步骤进行训练:
- 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
- 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth')) # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names)) # 创建新的全连接层print(model.fc) # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10) # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001) # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model)) # 加载本地模型else:model.load_state_dict(torch.load(first_model)) # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')
2. 完整资源
https://download.csdn.net/download/liugan528/89833913
相关文章:
使用pytorch进行迁移学习的两个步骤
1. 步骤及代码 迁移学习一般都会使用两个步骤进行训练: 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。 impor…...
ChatGPT相关参数示例
max_token 用于控制最大输出长度,若ChatGPT的回复大于max_tokens,则对输出结果进行截断。 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" ) response client.chat.completions.create(model"…...

OWASP发布大模型安全风险与应对策略(QA测试重点关注)
开放式 Web 应用程序安全项目(OWASP)发布了关于大模型应用的安全风险,这些风险不仅包括传统的沙盒逃逸、代码执行和鉴权不当等安全风险,还涉及提示注入、对话数据泄露和数据投毒等人工智能特有的安全风险。 帮助开发者和测试同学更…...

【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量
ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】: let 变量名: 类型 值 let:是定义变量的关键字类型: 值数据类型, 常用的数据类型 字符型(string)、数字型(number…...
UI自动化测试示例:python+pytest+selenium+allure
重点应用是封装、参数化: 比如在lib文件夹下,要存储封装好的方法和必要的环境变量(指网址等) 1.cfg.py:封装网址和对应的页面 SMP_ADDRESS http://127.0.0.1:8234SMP_URL_LOGIN f{SMP_ADDRESS}/login.html SMP_URL_DE…...
C/C++ 编程小工具
编写了 tools.h 和 tools.cpp,用于 Debug、性能测试、打印日志。 tools.h #ifndef TOOLS_H #define TOOLS_H#include <time.h> #include <fstream> #include <iostream> #include <random> #include <chrono> #include <vector&…...
第四十二章 使用 WS-ReliableMessaging
文章目录 第四十二章 使用 WS-ReliableMessaging从 Web 客户端发送一系列消息 第四十二章 使用 WS-ReliableMessaging IRIS 支持 WS-ReliableMessaging 规范的部分内容,如简介中所述。此规范提供了一种按顺序可靠地传递一系列消息的机制。本页介绍如何手动使用可靠…...

利士策分享,婚姻为何被称为大事?
利士策分享,婚姻为何被称为大事? 在历史的长河中,婚姻一直被视为人生中的头等大事,这一观念跨越时空,深深植根于各种文化和社会结构中。 古人为何将婚姻称为“大事”,这背后蕴含着丰富的社会、文化和心理寓…...

malloc源码分析之 ----- 你想要啥chunk
文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码,这里以glibc-2.29为例: void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…...

软考系统分析师知识点五:数据通信与计算机网络
前言 今年报考了11月份的软考高级:系统分析师。 考试时间为:11月9日。 倒计时:32天。 目标:优先应试,其次学习,再次实践。 复习计划第一阶段:扫平基础知识点,仅抽取有用信息&am…...

windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier)
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier) 目录 SSH简述、三种网络连接特点SSH简述局域网内连接内网穿透(…...

Python 工具库每日推荐【openpyxl 】
文章目录 引言Python Excel 处理库的重要性今日推荐:openpyxl 工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:自动生成月度销售报告案例分析高级特性条件格式数据验证扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScript 设计…...

本地生活服务项目入局方案解析!本地生活服务商系统能实现怎样的作业效果?
当前,各大平台的本地生活服务业务日渐兴盛,提高创业者入局意向的同时,也让本地生活服务项目有哪些等问题也成为了多个创业者社群中的热议对象。而从目前的讨论情况来看,在创业者们所询问的众多本地生活服务项目中,通过…...

ML 系列:【13 】— Logistic 回归(第 2 部分)
文章目录 一、说明二、挤压方法三、Logistic 回归中的损失函数四、后记 一、说明 在这篇文章中,我们将深入研究 squashing 方法,这是有符号距离方法(第 12节)的一种很有前途的替代方案。squashing 方法通过提供增强的对异常值…...

45岁被裁员的程序员,何去何从?
在当今快速变化的技术行业,职业生涯的稳定性受到挑战。在45岁被裁员,对很多程序员来说,可能是一种惊慌失措的体验。然而,这个阶段也可以被视为一个重新审视和调整方向的机会。本文将对可能的出路进行全方位的分析,并提…...

云计算Openstack Neutron
OpenStack Neutron是OpenStack云计算平台中的网络服务组件,它为OpenStack提供了强大的网络连接功能。 一、基本概念 Neutron是一个网络服务项目,旨在为OpenStack提供网络连接。它允许用户创建和管理虚拟网络,包括子网、路由、安全组等&…...

PointNet++网络详解
数据集转换 数据集转换的意义在于将原本的 txt 点云文件转换为更方便运算的npy点云文件,同时,将原本的xyzrgb这 6 个维度转换为xyzrgbc,最后一个c维度代表该点云所属的类别。 for anno_path in anno_paths:print(anno_path)try:elements a…...

Java | Leetcode Java题解之第459题重复的子字符串
题目: 题解: class Solution {public boolean repeatedSubstringPattern(String s) {return kmp(s s, s);}public boolean kmp(String query, String pattern) {int n query.length();int m pattern.length();int[] fail new int[m];Arrays.fill(fa…...
【动态规划-最长公共子序列(LCS)】【hard】力扣1092. 最短公共超序列
给你两个字符串 str1 和 str2,返回同时以 str1 和 str2 作为 子序列 的最短字符串。如果答案不止一个,则可以返回满足条件的 任意一个 答案。 如果从字符串 t 中删除一些字符(也可能不删除),可以得到字符串 s &#x…...

图片编辑为底片,智能工具助力,创作精彩视觉作品
在当今数字化时代,图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具,即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”,选择这里“图片批量处理”版块页面上 2.导入保存有…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
React Native 开发环境搭建(全平台详解)
React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?
在建筑行业,项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升,传统的管理模式已经难以满足现代工程的需求。过去,许多企业依赖手工记录、口头沟通和分散的信息管理,导致效率低下、成本失控、风险频发。例如&#…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】
1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件(System Property Definition File),用于声明和管理 Bluetooth 模块相…...

Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
Caliper 配置文件解析:config.yaml
Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...