当前位置: 首页 > news >正文

使用pytorch进行迁移学习的两个步骤

1. 步骤及代码

迁移学习一般都会使用两个步骤进行训练:

  1. 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;
  2. 使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。
import os
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms as T
import numpy as np# 设置均值、方差
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]# 还原减均值除以方差之前的数据,用于可视化
def reduction_img_show(tensor, mean, std) -> None:to_img = T.ToPILImage()reduced_img = to_img(tensor * torch.tensor(std).view(3, 1, 1) + torch.tensor(mean).view(3, 1, 1))reduced_img.show()def getResNet(*, class_names: str, loadfile: str = None):if loadfile is not None:model = torchvision.models.resnet18()model.load_state_dict(torch.load('resnet18-f37072fd.pth'))  # 加载权重else:model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)  # 模型自动下载到C:\Users\GaryLau\.cache\torch\hub\checkpoints# 将所有的参数层冻结,设置模型除最后一层以外都不可以进行训练,使模型只针对最后一层进行微调for param in model.parameters():param.requires_grad = False# 输出全连接层信息print(model.fc)x = model.fc.in_features  # 获取全连接层输入维度model.fc = torch.nn.Linear(in_features=x, out_features=len(class_names))  # 创建新的全连接层print(model.fc)  # 输出新的全连接层return model# 定义训练函数
def train(model, device, train_loader, criterion, optimizer, epoch):model.train()all_loss = []for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()y_pred = model(data)loss = criterion(y_pred, target)loss.backward()all_loss.append(loss.item())optimizer.step()if batch_idx % 10 == 0:print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),np.mean(all_loss)))def val(model, device, val_loader, criterion):model.eval()test_loss = []correct = []with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)y_pred = model(data)test_loss.append(criterion(y_pred, target).item())pred = y_pred.argmax(dim=1, keepdim=True)correct.append(pred.eq(target.view_as(pred)).sum().item()/pred.size(0))print('-->Test: Average loss:{:.4f}, Accuracy:({:.0f}%)\n'.format(np.mean(test_loss), 100 * sum(correct) / len(correct)))# 训练,验证时的预处理
transform = {'train': T.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),T.Normalize(mean=mean, std=std)]),'val': T.Compose([T.Resize((224,224)),T.ToTensor(),T.Normalize(mean=mean, std=std)])}# 加载训练、验证数据
dataset_train = ImageFolder(r'./train', transform=transform['train'])
dataset_val = ImageFolder(r'./test', transform=transform['val'])# 类别标签
class_names = dataset_train.classes
print(dataset_train.class_to_idx)
print(dataset_val.class_to_idx)# 显示一张训练、验证图
# reduction_img_show(dataset_train[0][0], mean, std)
# reduction_img_show(dataset_val[0][0], mean, std)# 使用DataLoader遍历数据
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, sampler=None, num_workers=0,pin_memory=False, drop_last=False)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False, sampler=None, num_workers=0,pin_memory=False, drop_last=False)# 使用方式一,使用next不断获取一个batch的数据
dataiter_train = iter(dataloader_train)
imgs, labels = next(dataiter_train)
print(imgs.size())
# reduction_img_show(imgs[0], mean, std)
# reduction_img_show(imgs[1], mean, std)
multi_imgs = torchvision.utils.make_grid(imgs, nrow=10)  # 拼接一个batch的图像用于展示
# reduction_img_show(multi_imgs, mean, std)# 获取ResNet模型,并加载预训练模型权重,将最后一层(输出层)去掉,换成一个新的全连接层,新全连接层输出的节点数是新数据的类别数
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 构建模型
model = getResNet(class_names=class_names, loadfile='resnet18-f37072fd.pth')
model.to(device)# 构建损失函数
criterion = torch.nn.CrossEntropyLoss()
# 指定新加的全连接层为要更新的参数
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)  # 只需要更新最后一层fc的参数if __name__ == '__main__':### 步骤一,微调最后一层first_model = 'resnet18-f37072fd_finetune_fcLayer.pth'for epoch in range(1, 6):train(model, device, dataloader_train, criterion, optimizer, epoch)val(model, device, dataloader_val, criterion)# 仅保存了最后新添加的全连接层的参数#torch.save(model.fc.state_dict(), first_model)torch.save(model.state_dict(), first_model)### 步骤二,小学习率微调所有层second_model = 'resnet18-f37072fd_finetune_allLayer.pth'optimizer2 = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=3, gamma=0.9)# 将所有的参数层设为可训练的for param in model.parameters():param.requires_grad = Trueif os.path.exists(second_model):model.load_state_dict(torch.load(second_model))   # 加载本地模型else:model.load_state_dict(torch.load(first_model))    # 加载步骤一训练得到的本地模型print('Finetune all layers with small learning rate......')for epoch in range(1, 101):train(model, device, dataloader_train, criterion, optimizer2, epoch)if optimizer2.state_dict()['param_groups'][0]['lr'] > 0.00001:exp_lr_scheduler.step()print(f"learning rate: {optimizer2.state_dict()['param_groups'][0]['lr']}")val(model, device, dataloader_val, criterion)# 保存整个模型torch.save(model.state_dict(), second_model)print('Done.')

2. 完整资源

https://download.csdn.net/download/liugan528/89833913

相关文章:

使用pytorch进行迁移学习的两个步骤

1. 步骤及代码 迁移学习一般都会使用两个步骤进行训练: 固定预训练模型的特征提取部分,只对最后一层进行训练,使其快速收敛;使用较小的学习率,对全部模型进行训练,并对每层的权重进行细微的调节。 impor…...

ChatGPT相关参数示例

max_token 用于控制最大输出长度,若ChatGPT的回复大于max_tokens,则对输出结果进行截断。 from openai import OpenAI client OpenAI(base_url"https://api.chatanywhere.tech/v1" ) response client.chat.completions.create(model"…...

OWASP发布大模型安全风险与应对策略(QA测试重点关注)

开放式 Web 应用程序安全项目(OWASP)发布了关于大模型应用的安全风险,这些风险不仅包括传统的沙盒逃逸、代码执行和鉴权不当等安全风险,还涉及提示注入、对话数据泄露和数据投毒等人工智能特有的安全风险。 帮助开发者和测试同学更…...

【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量

ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】: let 变量名: 类型 值 let:是定义变量的关键字类型: 值数据类型, 常用的数据类型 字符型(string)、数字型(number&#xf…...

UI自动化测试示例:python+pytest+selenium+allure

重点应用是封装、参数化: 比如在lib文件夹下,要存储封装好的方法和必要的环境变量(指网址等) 1.cfg.py:封装网址和对应的页面 SMP_ADDRESS http://127.0.0.1:8234SMP_URL_LOGIN f{SMP_ADDRESS}/login.html SMP_URL_DE…...

C/C++ 编程小工具

编写了 tools.h 和 tools.cpp&#xff0c;用于 Debug、性能测试、打印日志。 tools.h #ifndef TOOLS_H #define TOOLS_H#include <time.h> #include <fstream> #include <iostream> #include <random> #include <chrono> #include <vector&…...

第四十二章 使用 WS-ReliableMessaging

文章目录 第四十二章 使用 WS-ReliableMessaging从 Web 客户端发送一系列消息 第四十二章 使用 WS-ReliableMessaging IRIS 支持 WS-ReliableMessaging 规范的部分内容&#xff0c;如简介中所述。此规范提供了一种按顺序可靠地传递一系列消息的机制。本页介绍如何手动使用可靠…...

利士策分享,婚姻为何被称为大事?

利士策分享&#xff0c;婚姻为何被称为大事&#xff1f; 在历史的长河中&#xff0c;婚姻一直被视为人生中的头等大事&#xff0c;这一观念跨越时空&#xff0c;深深植根于各种文化和社会结构中。 古人为何将婚姻称为“大事”&#xff0c;这背后蕴含着丰富的社会、文化和心理寓…...

malloc源码分析之 ----- 你想要啥chunk

文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码&#xff0c;这里以glibc-2.29为例&#xff1a; void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…...

软考系统分析师知识点五:数据通信与计算机网络

前言 今年报考了11月份的软考高级&#xff1a;系统分析师。 考试时间为&#xff1a;11月9日。 倒计时&#xff1a;32天。 目标&#xff1a;优先应试&#xff0c;其次学习&#xff0c;再次实践。 复习计划第一阶段&#xff1a;扫平基础知识点&#xff0c;仅抽取有用信息&am…...

windows客户端SSH连接ubuntu/linux服务器,三种网络连接:局域网,内网穿透(sakuraftp),虚拟局域网(zerotier)

windows客户端SSH连接ubuntu/linux服务器&#xff0c;三种网络连接&#xff1a;局域网&#xff0c;内网穿透&#xff08;sakuraftp&#xff09;&#xff0c;虚拟局域网&#xff08;zerotier&#xff09; 目录 SSH简述、三种网络连接特点SSH简述局域网内连接内网穿透&#xff08…...

Python 工具库每日推荐【openpyxl 】

文章目录 引言Python Excel 处理库的重要性今日推荐:openpyxl 工具库主要功能:使用场景:安装与配置快速上手示例代码代码解释实际应用案例案例:自动生成月度销售报告案例分析高级特性条件格式数据验证扩展阅读与资源优缺点分析优点:缺点:总结【 已更新完 TypeScript 设计…...

本地生活服务项目入局方案解析!本地生活服务商系统能实现怎样的作业效果?

当前&#xff0c;各大平台的本地生活服务业务日渐兴盛&#xff0c;提高创业者入局意向的同时&#xff0c;也让本地生活服务项目有哪些等问题也成为了多个创业者社群中的热议对象。而从目前的讨论情况来看&#xff0c;在创业者们所询问的众多本地生活服务项目中&#xff0c;通过…...

ML 系列:【13 】— Logistic 回归(第 2 部分)

文章目录 一、说明二、挤压方法三、Logistic 回归中的损失函数四、后记 一、说明 ​ 在这篇文章中&#xff0c;我们将深入研究 squashing 方法&#xff0c;这是有符号距离方法&#xff08;第 12节&#xff09;的一种很有前途的替代方案。squashing 方法通过提供增强的对异常值…...

45岁被裁员的程序员,何去何从?

在当今快速变化的技术行业&#xff0c;职业生涯的稳定性受到挑战。在45岁被裁员&#xff0c;对很多程序员来说&#xff0c;可能是一种惊慌失措的体验。然而&#xff0c;这个阶段也可以被视为一个重新审视和调整方向的机会。本文将对可能的出路进行全方位的分析&#xff0c;并提…...

云计算Openstack Neutron

OpenStack Neutron是OpenStack云计算平台中的网络服务组件&#xff0c;它为OpenStack提供了强大的网络连接功能。 一、基本概念 Neutron是一个网络服务项目&#xff0c;旨在为OpenStack提供网络连接。它允许用户创建和管理虚拟网络&#xff0c;包括子网、路由、安全组等&…...

PointNet++网络详解

数据集转换 数据集转换的意义在于将原本的 txt 点云文件转换为更方便运算的npy点云文件&#xff0c;同时&#xff0c;将原本的xyzrgb这 6 个维度转换为xyzrgbc&#xff0c;最后一个c维度代表该点云所属的类别。 for anno_path in anno_paths:print(anno_path)try:elements a…...

Java | Leetcode Java题解之第459题重复的子字符串

题目&#xff1a; 题解&#xff1a; class Solution {public boolean repeatedSubstringPattern(String s) {return kmp(s s, s);}public boolean kmp(String query, String pattern) {int n query.length();int m pattern.length();int[] fail new int[m];Arrays.fill(fa…...

【动态规划-最长公共子序列(LCS)】【hard】力扣1092. 最短公共超序列

给你两个字符串 str1 和 str2&#xff0c;返回同时以 str1 和 str2 作为 子序列 的最短字符串。如果答案不止一个&#xff0c;则可以返回满足条件的 任意一个 答案。 如果从字符串 t 中删除一些字符&#xff08;也可能不删除&#xff09;&#xff0c;可以得到字符串 s &#x…...

‌图片编辑为底片,智能工具助力,创作精彩视觉作品

在当今数字化时代&#xff0c;图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具&#xff0c;即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”&#xff0c;选择这里“图片批量处理”版块页面上 2.导入保存有…...

机器学习/数据分析--用通俗语言讲解时间序列自回归(AR)模型,并用其预测天气,拟合度98%+

时间序列在回归预测的领域的重要性&#xff0c;不言而喻&#xff0c;在数学建模中使用及其频繁&#xff0c;但是你真的了解ARIMA、AR、MA么&#xff1f;ACF图你会看么&#xff1f;&#xff1f; 时间序列数据如何构造&#xff1f;&#xff1f;&#xff1f;&#xff0c;我打过不少…...

回溯算法之值子集和问题详细解读(附带Java代码解读)

子集和问题&#xff08;Subset Sum Problem&#xff09; 是一个经典的组合优化问题。问题可以这样描述&#xff1a; 给定一个整数集合和一个目标整数 target&#xff0c;我们需要从集合中选出若干个整数&#xff0c;使它们的和等于 target。如果这样的子集存在&#xff0c;返回…...

mysql游标的使用

说明&#xff1a; 虽然我们也可以通过筛选条件 WHERE 和 HAVING&#xff0c;或者是限定返回记录的关键字 LIMIT 返回一条记录&#xff0c;但是&#xff0c;却无法在结果集中像指针一样&#xff0c;向前定位一条记录、向后定位一条记录&#xff0c;或者是 随意定位到某一条记录 …...

linux udev详解

1.概念介绍 1.1sysfs文件系统 Linux 2.6以后的内核引入了sysfs文件系统&#xff0c;sysfs被看成是与proc、devfs和devpty同类别的文件系统&#xff0c;该文件系统是一个虚拟的文件系统&#xff0c;它可以产生一个包括所有系统硬件的层级视图&#xff0c;与提供进程和状态信息…...

EventSource和websocket该用哪种技术

EventSource&#xff08;也称为Server-Sent Events, SSE&#xff09;和WebSocket都是实现实时通信的技术&#xff0c;但是它们的设计目的和使用场景有所不同。在选择使用哪种技术时&#xff0c;需要根据具体的应用需求来决定。下面是一些关键点&#xff0c;可以帮助你做出选择&…...

通信工程学习:什么是三网融合

三网融合 三网融合&#xff0c;又称“三网合一”&#xff0c;是指电信网、广播电视网、互联网在高层业务应用上的深度融合。这一概念在近年来随着信息技术的快速发展而逐渐受到重视&#xff0c;并成为推动信息化社会建设的重要力量。以下是对三网融合的详细解释&#xff1a; 一…...

自定义类型结构体(上)

目录 结构体类型的声明结构体的概念结构体的声明特殊的声明结构的自引用 结构体变量的创建和初始化结构成员访问操作符 结构体类型的声明 结构体的概念 结构体是一些值的集合&#xff0c;这些值称为成员变量。结构的每个成员可以是不同类型的变量 举个例子:杰克的英语只考了6…...

b站-湖科大教书匠】4 网络层 - 计算机网络微课堂

【b站-湖科大教书匠】4 网络层 - 计算机网络微课堂_湖科大的计算机网络网课-CSDN博客...

国际 Android WPS Office v18.13 解锁版

WPS Office 移动版&#xff0c;设计不断优化&#xff0c;性能再次提升&#xff01;融入Google Android最新设计标准&#xff0c;Material Design设计风格&#xff0c;完美支持沉浸式&#xff01;简化文档操作&#xff0c;全新移动办公力作。全新界面更清晰舒适&#xff0c;功能…...

【中间件学习】Git的命令和企业级开发

一、Git命令 1.1 创建Git本地仓库 仓库是进行版本控制的一个文件目录。我们要想对文件进行版本控制&#xff0c;就必须创建出一个仓库出来。创建一个Git本地仓库对应的命令是 git init &#xff0c;注意命令要在文件目录下执行。 hrxlavm-1lzqn7w2w6:~/gitcode$ pwd /home/hr…...