使用pytorch进行迁移学习的两个步骤
1. 步骤及代码
迁移学习一般都会使用两个步骤进行训练:
- 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
- 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth')) # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names)) # 创建新的全连接层print(model.fc) # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10) # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001) # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model)) # 加载本地模型else:model.load_state_dict(torch.load(first_model)) # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')
2. 完整资源
https://download.csdn.net/download/liugan528/89833913
相关文章:
使用pytorch进行迁移学习的两个步骤
1. 步骤及代码 迁移学习一般都会使用两个步骤进行训练: 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。 impor…...
ChatGPT相关参数示例
max_token 用于控制最大输出长度,若ChatGPT的回复大于max_tokens,则对输出结果进行截断。 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" ) response client.chat.completions.create(model"…...

OWASP发布大模型安全风险与应对策略(QA测试重点关注)
开放式 Web 应用程序安全项目(OWASP)发布了关于大模型应用的安全风险,这些风险不仅包括传统的沙盒逃逸、代码执行和鉴权不当等安全风险,还涉及提示注入、对话数据泄露和数据投毒等人工智能特有的安全风险。 帮助开发者和测试同学更…...

【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量
ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】: let 变量名: 类型 值 let:是定义变量的关键字类型: 值数据类型, 常用的数据类型 字符型(string)、数字型(number…...
UI自动化测试示例:python+pytest+selenium+allure
重点应用是封装、参数化: 比如在lib文件夹下,要存储封装好的方法和必要的环境变量(指网址等) 1.cfg.py:封装网址和对应的页面 SMP_ADDRESS http://127.0.0.1:8234SMP_URL_LOGIN f{SMP_ADDRESS}/login.html SMP_URL_DE…...
C/C++ 编程小工具
编写了 tools.h 和 tools.cpp,用于 Debug、性能测试、打印日志。 tools.h #ifndef TOOLS_H #define TOOLS_H#include <time.h> #include <fstream> #include <iostream> #include <random> #include <chrono> #include <vector&…...
第四十二章 使用 WS-ReliableMessaging
文章目录 第四十二章 使用 WS-ReliableMessaging从 Web 客户端发送一系列消息 第四十二章 使用 WS-ReliableMessaging IRIS 支持 WS-ReliableMessaging 规范的部分内容,如简介中所述。此规范提供了一种按顺序可靠地传递一系列消息的机制。本页介绍如何手动使用可靠…...

利士策分享,婚姻为何被称为大事?
利士策分享,婚姻为何被称为大事? 在历史的长河中,婚姻一直被视为人生中的头等大事,这一观念跨越时空,深深植根于各种文化和社会结构中。 古人为何将婚姻称为“大事”,这背后蕴含着丰富的社会、文化和心理寓…...

malloc源码分析之 ----- 你想要啥chunk
文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码,这里以glibc-2.29为例: void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…...

软考系统分析师知识点五:数据通信与计算机网络
前言 今年报考了11月份的软考高级:系统分析师。 考试时间为:11月9日。 倒计时:32天。 目标:优先应试,其次学习,再次实践。 复习计划第一阶段:扫平基础知识点,仅抽取有用信息&am…...

windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier)
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier) 目录 SSH简述、三种网络连接特点SSH简述局域网内连接内网穿透(…...

Python 工具库每日推荐【openpyxl 】
文章目录 引言Python Excel 处理库的重要性今日推荐:openpyxl 工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:自动生成月度销售报告案例分析高级特性条件格式数据验证扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScript 设计…...

本地生活服务项目入局方案解析!本地生活服务商系统能实现怎样的作业效果?
当前,各大平台的本地生活服务业务日渐兴盛,提高创业者入局意向的同时,也让本地生活服务项目有哪些等问题也成为了多个创业者社群中的热议对象。而从目前的讨论情况来看,在创业者们所询问的众多本地生活服务项目中,通过…...

ML 系列:【13 】— Logistic 回归(第 2 部分)
文章目录 一、说明二、挤压方法三、Logistic 回归中的损失函数四、后记 一、说明 在这篇文章中,我们将深入研究 squashing 方法,这是有符号距离方法(第 12节)的一种很有前途的替代方案。squashing 方法通过提供增强的对异常值…...

45岁被裁员的程序员,何去何从?
在当今快速变化的技术行业,职业生涯的稳定性受到挑战。在45岁被裁员,对很多程序员来说,可能是一种惊慌失措的体验。然而,这个阶段也可以被视为一个重新审视和调整方向的机会。本文将对可能的出路进行全方位的分析,并提…...

云计算Openstack Neutron
OpenStack Neutron是OpenStack云计算平台中的网络服务组件,它为OpenStack提供了强大的网络连接功能。 一、基本概念 Neutron是一个网络服务项目,旨在为OpenStack提供网络连接。它允许用户创建和管理虚拟网络,包括子网、路由、安全组等&…...

PointNet++网络详解
数据集转换 数据集转换的意义在于将原本的 txt 点云文件转换为更方便运算的npy点云文件,同时,将原本的xyzrgb这 6 个维度转换为xyzrgbc,最后一个c维度代表该点云所属的类别。 for anno_path in anno_paths:print(anno_path)try:elements a…...

Java | Leetcode Java题解之第459题重复的子字符串
题目: 题解: class Solution {public boolean repeatedSubstringPattern(String s) {return kmp(s s, s);}public boolean kmp(String query, String pattern) {int n query.length();int m pattern.length();int[] fail new int[m];Arrays.fill(fa…...
【动态规划-最长公共子序列(LCS)】【hard】力扣1092. 最短公共超序列
给你两个字符串 str1 和 str2,返回同时以 str1 和 str2 作为 子序列 的最短字符串。如果答案不止一个,则可以返回满足条件的 任意一个 答案。 如果从字符串 t 中删除一些字符(也可能不删除),可以得到字符串 s &#x…...

图片编辑为底片,智能工具助力,创作精彩视觉作品
在当今数字化时代,图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具,即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”,选择这里“图片批量处理”版块页面上 2.导入保存有…...

龙虎榜——20250610
上证指数放量收阴线,个股多数下跌,盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型,指数短线有调整的需求,大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的:御银股份、雄帝科技 驱动…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...

ESP32读取DHT11温湿度数据
芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility
Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...

AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...

若依登录用户名和密码加密
/*** 获取公钥:前端用来密码加密* return*/GetMapping("/getPublicKey")public RSAUtil.RSAKeyPair getPublicKey() {return RSAUtil.rsaKeyPair();}新建RSAUti.Java package com.ruoyi.common.utils;import org.apache.commons.codec.binary.Base64; im…...

PH热榜 | 2025-06-08
1. Thiings 标语:一套超过1900个免费AI生成的3D图标集合 介绍:Thiings是一个不断扩展的免费AI生成3D图标库,目前已有超过1900个图标。你可以按照主题浏览,生成自己的图标,或者下载整个图标集。所有图标都可以在个人或…...

Linux-进程间的通信
1、IPC: Inter Process Communication(进程间通信): 由于每个进程在操作系统中有独立的地址空间,它们不能像线程那样直接访问彼此的内存,所以必须通过某种方式进行通信。 常见的 IPC 方式包括&#…...

Tauri2学习笔记
教程地址:https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引:https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多,我按照Tauri1的教程来学习&…...