使用pytorch进行迁移学习的两个步骤
1. 步骤及代码
迁移学习一般都会使用两个步骤进行训练:
- 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
- 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth')) # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1) # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names)) # 创建新的全连接层print(model.fc) # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10) # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001) # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model)) # 加载本地模型else:model.load_state_dict(torch.load(first_model)) # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')
2. 完整资源
https://download.csdn.net/download/liugan528/89833913
相关文章:
使用pytorch进行迁移学习的两个步骤
1. 步骤及代码 迁移学习一般都会使用两个步骤进行训练: 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。 impor…...
ChatGPT相关参数示例
max_token 用于控制最大输出长度,若ChatGPT的回复大于max_tokens,则对输出结果进行截断。 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" ) response client.chat.completions.create(model"…...
OWASP发布大模型安全风险与应对策略(QA测试重点关注)
开放式 Web 应用程序安全项目(OWASP)发布了关于大模型应用的安全风险,这些风险不仅包括传统的沙盒逃逸、代码执行和鉴权不当等安全风险,还涉及提示注入、对话数据泄露和数据投毒等人工智能特有的安全风险。 帮助开发者和测试同学更…...
【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量
ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】: let 变量名: 类型 值 let:是定义变量的关键字类型: 值数据类型, 常用的数据类型 字符型(string)、数字型(number…...
UI自动化测试示例:python+pytest+selenium+allure
重点应用是封装、参数化: 比如在lib文件夹下,要存储封装好的方法和必要的环境变量(指网址等) 1.cfg.py:封装网址和对应的页面 SMP_ADDRESS http://127.0.0.1:8234SMP_URL_LOGIN f{SMP_ADDRESS}/login.html SMP_URL_DE…...
C/C++ 编程小工具
编写了 tools.h 和 tools.cpp,用于 Debug、性能测试、打印日志。 tools.h #ifndef TOOLS_H #define TOOLS_H#include <time.h> #include <fstream> #include <iostream> #include <random> #include <chrono> #include <vector&…...
第四十二章 使用 WS-ReliableMessaging
文章目录 第四十二章 使用 WS-ReliableMessaging从 Web 客户端发送一系列消息 第四十二章 使用 WS-ReliableMessaging IRIS 支持 WS-ReliableMessaging 规范的部分内容,如简介中所述。此规范提供了一种按顺序可靠地传递一系列消息的机制。本页介绍如何手动使用可靠…...
利士策分享,婚姻为何被称为大事?
利士策分享,婚姻为何被称为大事? 在历史的长河中,婚姻一直被视为人生中的头等大事,这一观念跨越时空,深深植根于各种文化和社会结构中。 古人为何将婚姻称为“大事”,这背后蕴含着丰富的社会、文化和心理寓…...
malloc源码分析之 ----- 你想要啥chunk
文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码,这里以glibc-2.29为例: void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…...
软考系统分析师知识点五:数据通信与计算机网络
前言 今年报考了11月份的软考高级:系统分析师。 考试时间为:11月9日。 倒计时:32天。 目标:优先应试,其次学习,再次实践。 复习计划第一阶段:扫平基础知识点,仅抽取有用信息&am…...
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier)
windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier) 目录 SSH简述、三种网络连接特点SSH简述局域网内连接内网穿透(…...
Python 工具库每日推荐【openpyxl 】
文章目录 引言Python Excel 处理库的重要性今日推荐:openpyxl 工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:自动生成月度销售报告案例分析高级特性条件格式数据验证扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScript 设计…...
本地生活服务项目入局方案解析!本地生活服务商系统能实现怎样的作业效果?
当前,各大平台的本地生活服务业务日渐兴盛,提高创业者入局意向的同时,也让本地生活服务项目有哪些等问题也成为了多个创业者社群中的热议对象。而从目前的讨论情况来看,在创业者们所询问的众多本地生活服务项目中,通过…...
ML 系列:【13 】— Logistic 回归(第 2 部分)
文章目录 一、说明二、挤压方法三、Logistic 回归中的损失函数四、后记 一、说明 在这篇文章中,我们将深入研究 squashing 方法,这是有符号距离方法(第 12节)的一种很有前途的替代方案。squashing 方法通过提供增强的对异常值…...
45岁被裁员的程序员,何去何从?
在当今快速变化的技术行业,职业生涯的稳定性受到挑战。在45岁被裁员,对很多程序员来说,可能是一种惊慌失措的体验。然而,这个阶段也可以被视为一个重新审视和调整方向的机会。本文将对可能的出路进行全方位的分析,并提…...
云计算Openstack Neutron
OpenStack Neutron是OpenStack云计算平台中的网络服务组件,它为OpenStack提供了强大的网络连接功能。 一、基本概念 Neutron是一个网络服务项目,旨在为OpenStack提供网络连接。它允许用户创建和管理虚拟网络,包括子网、路由、安全组等&…...
PointNet++网络详解
数据集转换 数据集转换的意义在于将原本的 txt 点云文件转换为更方便运算的npy点云文件,同时,将原本的xyzrgb这 6 个维度转换为xyzrgbc,最后一个c维度代表该点云所属的类别。 for anno_path in anno_paths:print(anno_path)try:elements a…...
Java | Leetcode Java题解之第459题重复的子字符串
题目: 题解: class Solution {public boolean repeatedSubstringPattern(String s) {return kmp(s s, s);}public boolean kmp(String query, String pattern) {int n query.length();int m pattern.length();int[] fail new int[m];Arrays.fill(fa…...
【动态规划-最长公共子序列(LCS)】【hard】力扣1092. 最短公共超序列
给你两个字符串 str1 和 str2,返回同时以 str1 和 str2 作为 子序列 的最短字符串。如果答案不止一个,则可以返回满足条件的 任意一个 答案。 如果从字符串 t 中删除一些字符(也可能不删除),可以得到字符串 s &#x…...
图片编辑为底片,智能工具助力,创作精彩视觉作品
在当今数字化时代,图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具,即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”,选择这里“图片批量处理”版块页面上 2.导入保存有…...
第五章作业
233817310313 文章目录图1:单位数码管显示7图2:单位数码管轮播0-9图3:6位数码管显示9图1:单位数码管显示7 #include <reg52.h>#define uchar unsigned char #define uint unsigned int// 定义锁存器控制引脚 sbit LE P2^7;…...
OpenClaw社交媒体管理:Gemma-3-12b-it自动回复评论与生成周报
OpenClaw社交媒体管理:Gemma-3-12b-it自动回复评论与生成周报 1. 为什么选择OpenClaw管理社交媒体 去年运营个人技术账号时,我每天要花1小时手动回复评论和整理周报。直到发现OpenClaw这个开源自动化框架,配合Gemma-3-12b-it模型࿰…...
M24LR64E-R双接口NFC标签驱动与嵌入式集成指南
1. 项目概述NFC Tag M24LR6E 是一款面向嵌入式系统的 Arduino 兼容库,专为驱动 Seeed Studio 推出的 Grove - NFC Tag 模块而设计。该模块核心芯片为 STMicroelectronics 的 M24LR64E-R,是一款高度集成的双接口(IC RF)近场通信标…...
解锁Windows效率提升:免费工具Winhance-zh_CN全功能指南
解锁Windows效率提升:免费工具Winhance-zh_CN全功能指南 【免费下载链接】Winhance-zh_CN A Chinese version of Winhance. C# application designed to optimize and customize your Windows experience. 项目地址: https://gitcode.com/gh_mirrors/wi/Winhance-…...
效率倍增:用快马AI生成服务器批量管理工具,告别重复劳动
最近在团队里负责服务器运维工作,经常需要同时管理几十台服务器。每次登录、执行重复命令、检查状态都要耗费大量时间,直到发现了用InsCode(快马)平台快速搭建批量管理工具的方法,效率直接翻倍。今天就把这个自动化管理方案分享给大家。 痛点…...
2026年最好的AI创业机会,就藏在你压根看不上的角落里
还在焦虑AI会替代你?抢你饭碗?你根本不知道,现在有一群人,正在用AI给自己“印钞票”他们不是搞什么ChatGPT插件,也不是训练大模型,他们就盯着那些看着不起眼,甚至你压根看不上的小事。利用这些小…...
基于深度学习的FasterRCNN水下图像复原
项目概述:Waternet_FasterRCNN 本项目旨在结合深度学习技术进行水下图像的还原与分析,综合应用 WaterNet 和 Faster R-CNN 来完成以下功能: 水下图像还原:利用 WaterNet 修复和增强水下图像质量。色板检测与提取:通过 …...
无线网络实战:从零配置AP与SSID,打通设备互联
1. 无线网络基础概念扫盲 刚接触无线网络时,我经常被各种专业术语搞得晕头转向。其实搭建一个简单的办公网络并不复杂,我们先来理清几个关键概念。**AP(接入点)**就像无线网络中的"信号中转站",负责把有线网…...
3大核心功能突破JSON可视化难题:vue-json-pretty革新前端数据展示体验
3大核心功能突破JSON可视化难题:vue-json-pretty革新前端数据展示体验 【免费下载链接】vue-json-pretty A JSON tree view component that is easy to use and also supports data selection. 项目地址: https://gitcode.com/gh_mirrors/vu/vue-json-pretty …...
Zynq Linux FPGA Manager实战:5分钟搞定PL配置(含bit转bin避坑指南)
Zynq Linux FPGA Manager实战:5分钟搞定PL配置(含bit转bin避坑指南) 第一次在Zynq开发板上尝试配置PL逻辑时,我盯着Vivado生成的.bit文件发愁——官方文档里提到的PCAP、ICAP协议像天书一样,而网上各种教程要么步骤不全…...
