【学习记录】pytorch载入模型的部分参数
需要从PointNet网络框架中提取encoder部分的参数,然后赋予自己的模型。因此,需要从一个已有的.pth文件读取部分参数,加载到自定义模型上面。做了一些尝试,记录如下。
关于模型保存与载入
torch.save(): 使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,可以保存各种对象,包括模型、张量和字典等。
torch.load(): 使用pickle unpickle工具将pickle的对象文件反序列化为内存。
可以看出,pth文件本质上是一个序列化的dict。
我们在save时,代码如下:
state = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}
然后以下代码load进来:
checkpoint = torch.load(args.model_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
查看checkpoint,可以看到包含的就是自己保存时的3个dict,分别是epoch,model_state_dict,和optimizer信息。

这里我们重点关注 model_state_dict,数据类型是一个 OrderedDict,有序字典。展开如下:

可以看到里面包含了自己定义的encoder,bn1-3,mlp 1-4层,以及每个层对应的参数(权重、bias,对于bn层还有mean, var等)。
这个Dict的顺序就是在Model中我们定义的顺序,这个和模型是一致的。
因此,如果载入时的模型和保存模型完全一致,直接用load_state_dict()就可以按顺序把数据载入进来。但如,如果定义不同怎么办?这就需要手动载入。
方法1:手动载入指定层的参数
从debug的断点可以看到,每个参数就是存在dict中的一个tensor。因此,我们只要读取对应的dict即可。
例如,encoder的conv1的权重,就是 checkpoint['model_state_dict']['encoder.conv1.weight'],那么我们在自己的模型对应的位置读取这个dict即可。
具体载入方式如下:
# 定义模型
model = MyPointNetSegmentation(channel=3, get_feature=True, batch_size=1)
model.to('cpu')# 载入其他模型的参数
checkpoint = torch.load(model_file, map_location='cpu')
model_dict = checkpoint['model_state_dict']# 将其他模型的参数,赋值给自己模型对应参数
model.encoder.conv1.weight.data.copy_(model_dict['encoder.conv1.weight'])
model.encoder.conv1.bias.data.copy_(model_dict['encoder.conv1.bias'])
把所有有用的参数都赋值过来就好,但要注意参数对应的tensor维度是一样的。

方法2:一次性载入key值相同的参数
如果说两个model的某些key值相同,可以用python的字典推导方式,将名称相关的参数提取出来。例如:
def load_dict_from_pointnet(model : Point2VoxelNet, checkpoint):my_model_dict = model.state_dict()pretrained_dict = checkpoint['model_state_dict']# 只将pretraind_dict中那些在model_dict中的参数,提取出来state_dict = {k:v for k,v in pretrained_dict.items() if k in my_model_dict .keys()}my_model_dict.update(state_dict) # 注意要更新state的变量,如果直接赋值,会出现某些key没有定义,导致运行失败model.load_state_dict(my_model_dict)# 对比参数是否一致print(f"{checkpoint['model_state_dict']['feat.stn.conv1.weight'][1]}")print(f"{model.feat.stn.conv1.weight[1]}")return model
看到这里,可以知道如果自己的模型改了名称,例如.pth的参数是:feat.stn.conv1,我这边叫做了 encoder.stn.conv1,那么是无法直接赋值的。可以用方法1,一个个载入,但是太慢了。另一种方式,是做一个键值映射,如果读到的是 feat.xxx,则赋予自定义模型中的 encoder.xxx ,简单处理即可。
注意事项
- conv层需要载入的参数有:weight 和 bias
- BN层涉及的参数有:
- weight,bias
- running_mean,running_var:这两个参数用于归一化的均值和方差, 因此也需要载入
- num_batches_tracked:在训练时需要载入,在test时不需要载入
- 载入参数后,如果用于测试,需要调用
eval()。注意不能在载入参数前调用 eval。eval 会将 bn 层的training参数设置为 false ,这样在测试时 batch_size 时如果是 1 也能够正常运行。
测试
用默认方式载入参数,以及手动方式载入后的两个模型,预测结果一致。
相关文章:
【学习记录】pytorch载入模型的部分参数
需要从PointNet网络框架中提取encoder部分的参数,然后赋予自己的模型。因此,需要从一个已有的.pth文件读取部分参数,加载到自定义模型上面。做了一些尝试,记录如下。 关于模型保存与载入 torch.save(): 使用Python的pickle实用程…...
Ubuntu Wayland启动腾讯会议并实现原生屏幕共享
Intro 众所周知,长期以来,由于腾讯会议项目组的尸位素餐、极度不作为,在Wayland成为Ubuntu 24.04 LTS的默认窗口环境下,仍然选择摆烂,甚至还“贴心”地在启动脚本下增加检测Wayland退出的代码;并且即使使用…...
写Prompt的技巧和基本原则
一.基本原则 1.一定要描述清晰你需要大模型做的事情,不要模棱两可 2.告诉大模型需要它做什么,不需要做什么 改写前: 请帮我推荐一些电影 改写后: 请帮我推荐2025年新出的10部评分比较高的喜剧电影,不要问我个人喜好等其他问题ÿ…...
前端Material-UI面试题及参考答案
目录 Material-UI 的设计理念与 Material Design 规范的关系是什么? 如何通过 npm/yarn/pnpm 安装 Material-UI 的核心依赖? Material-UI 的默认主题系统如何实现全局样式管理? 如何在项目中配置自定义字体和颜色方案? 什么是 emotion 和 styled-components,它们在 Ma…...
29、web前端开发之CSS3(六)
13. 多列布局(Multi-column Layout) 多列布局(Multi-column Layout)是一种通过CSS实现的布局方式,允许将内容组织成多列,类似于报纸或杂志的排版方式。这种布局方法能够有效地利用页面空间,提升…...
Go 语言语法精讲:从 Java 开发者的视角全面掌握
《Go 语言语法精讲:从 Java 开发者的视角全面掌握》 一、引言1.1 为什么选择 Go?1.2 适合 Java 开发者的原因1.3 本文目标 二、Go 语言环境搭建2.1 安装 Go2.2 推荐 IDE2.3 第一个 Go 程序 三、Go 语言基础语法3.1 变量与常量3.1.1 声明变量3.1.2 常量定…...
MySQL 复制与主从架构(Master-Slave)
MySQL 复制与主从架构(Master-Slave) MySQL 复制与主从架构是数据库高可用和负载均衡的重要手段。通过复制数据到多个从服务器,既可以实现数据冗余备份,又能分担查询压力,提升系统整体性能与容错能力。本文将详细介绍…...
水下成像机理分析
一般情况下, 水下环境泛指浸入到人工水体 (如水库、人工湖等)或自然水体(如海洋、河流、湖 泊、含水层等)中的区域。在水下环境中所拍摄 的图像由于普遍受到光照、波长、水中悬浮颗粒物 等因素的影响,导致生成的水下图像出现模糊、退 化、偏色等现象,图像…...
腾讯云智测试开发面经
1、投递时间线 2.20投递简历,3.11第一轮面试,3.30第二轮面试,4.4第三轮面试,4.10第四轮面试,4.11offer意向书 2、第一轮面试 第一轮面试技术面,面试官是导师,面试时长40多分钟 1)自我介绍 2)数组和列表的区别 3)了解哪些数据库 4)进程和线程的区别 5)了解哪…...
JVM类加载器详解
文章目录 1.类与类加载器2.类加载器加载规则3.JVM 中内置的三个重要类加载器为什么 获取到 ClassLoader 为null就是 BootstrapClassLoader 加载的呢? 4.自定义类加载器什么时候需要自定义类加载器代码示例 5.双亲委派模式类与类加载器双亲委派模型双亲委派模型的执行…...
@ComponentScan注解详解:Spring组件扫描的核心机制
ComponentScan注解详解:Spring组件扫描的核心机制 一、ComponentScan注解概述 ComponentScan是Spring框架中的一个核心注解,用于自动扫描和注册指定包及其子包下的Spring组件。它是Spring实现依赖注入和自动装配的基础机制之一。 Retention(Retention…...
rust Send Sync 以及对象安全和对象不安全
开头:菜鸟小明的疑惑 小明: “李哥,我最近学 Rust,感觉它超级严谨,啥 Send、Sync、对象安全、静态分发、动态分发的,我都搞晕了!为啥 Rust 要设计得这么复杂啊?” 小李࿰…...
从一到无穷大 #44:AWS Glue: Data integration + Catalog
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 本作品 (李兆龙 博文, 由 李兆龙 创作),由 李兆龙 确认,转载请注明版权。 文章目录 引言Glue的历史,设计原则与挑战Serverless ETL 功能设计Glue StudioGlue …...
【Redis】如何处理缓存穿透、击穿、雪崩
Redis 缓存穿透、击穿和雪崩是高并发场景下的典型问题,以下是详细解决方案和最佳实践: 一、缓存穿透(Cache Penetration) 问题:恶意请求不存在的数据(如不存在的ID),绕过缓存直接访…...
区块链技术如何重塑金融衍生品市场?
区块链技术如何重塑金融衍生品市场? 金融衍生品市场一直是全球金融体系的重要组成部分,其复杂性和风险性让许多投资者望而却步。然而,随着区块链技术的兴起,这一领域正在经历一场深刻的变革。区块链以其去中心化、透明和不可篡改…...
实战打靶集锦-35-GitRoot
文章目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查5. 系统提权6. 写在最后 靶机地址:https://download.vulnhub.com/gitroot/GitRoot.ova 1. 主机发现 目前只知道目标靶机在192.168.56.xx网段,通过如下的命令,看看这个网段上在线的主机…...
Vue3 + Element Plus + AntV X6 实现拖拽树组件
Vue3 Element Plus AntV X6 实现拖拽树组件 介绍 在本篇文章中,我们将介绍如何使用 Vue 3 和 Element Plus 结合 antv/x6 实现树形结构的拖拽功能。用户可以将树节点拖拽到图形区域,自动创建相应的节点。我们将会通过简单的示例来一步步讲解实现过程…...
从零开始跑通3DGS教程:介绍
写在前面 本文内容 本文所属《从零开始跑通3DGS教程》系列文章,将实现从原始图像(有序、无序)数据开始,经过处理(视频抽帧成有序),SFM,3DGS训练、编辑、渲染等步骤,完整地呈现从原始图像到新视角合成的全部流程&#x…...
聊聊Spring AI的Chat Model
序 本文主要研究一下Spring AI的Chat Model Model spring-ai-core/src/main/java/org/springframework/ai/model/Model.java public interface Model<TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> {/*** Executes a method call to …...
将mysql配置成服务的方法
第一步:配置环境变量 1)新建MYSQL_HOME变量,并配置:C:\Program Files\MySQL\MySQL Server 5.6 MYSQL_HOME:C:\Program Files\MySQL\MySQL Server 5.6 2)编辑path系统变量,将%MYSQL_HOME%\bin添加到path变量后。配置path环境变量…...
GaussDB(for PostgreSQL) 存储引擎:ASTORE 与 USTORE 详细对比
GaussDB(for PostgreSQL) 存储引擎:ASTORE 与 USTORE 详细对比 1. 背景说明 GaussDB(for PostgreSQL) 是华为基于 PostgreSQL 开发的企业级分布式数据库,其存储引擎分为 ASTORE 和 USTORE 两种类型,分别针对不同场景优化。 2. 核心对比 (1)…...
英语口语 -- 常用 1368 词汇
英语口语 -- 常用 1368 词汇 介绍常用单词List1 (96 个)时间类气候类自然类植物类动物类昆虫类其他生物地点类 List2 (95 个)机构类声音类食品类餐饮类蔬菜类水果类食材类饮料类营养类疾病类房屋类家具类服装类首饰类化妆品类 Lis…...
SpringBoot+Vue 中 WebSocket 的使用
WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,它使得客户端和服务器之间可以进行实时数据传输,打破了传统 HTTP 协议请求 - 响应模式的限制。 下面我会展示在 SpringBoot Vue 中,使用WebSocket进行前后端通信。 后端 1、引入 j…...
关于依赖注入框架VContainer DIIOC 的学习记录
文章目录 前言一、VContainer核心概念1.DI(Dependency Injection(依赖注入))2.scope(域,作用域) 二、练习例子1.Hello,World!步骤一,编写一个底类。HelloWorldService步骤二,编写使用低类的类。GamePresenter步骤三&am…...
LRU缓存是什么
LRU缓存是什么 LRU(Least Recently Used)即最近最少使用,是一种缓存淘汰策略。在缓存空间有限的情况下,当新的数据需要存入缓存,而缓存已满时,LRU 策略会优先淘汰最近最少使用的数据,以此保证缓存中存储的是最近最常使用的数据。 LRU缓存的工作原理 LRU 缓存的核心思…...
Qt常用控件第一部分
1.控件概述 Widget 是 Qt 中的核⼼概念. 英⽂原义是 "⼩部件", 我们此处也把它翻译为 "控件" . 控件是构成⼀个图形化界⾯的基本要素. 像上述⽰例中的, 按钮, 列表视图, 树形视图, 单⾏输⼊框, 多⾏输⼊框, 滚动条, 下拉框等, 都可以称为 "控件"…...
docker存储卷及dockers容器源码部署httpd
1. COW机制 Docker镜像由多个只读层叠加而成,启动容器时,Docker会加载只读镜像层并在镜像栈顶部添加一个读写层。 如果运行中的容器修改了现有的一个已经存在的文件,那么该文件将会从读写层下面的只读层复制到读写层,该文件的只读版本依然存在,只是已经被读写层中该文件…...
JMeter接口自动化发包与示例
前言 JMeter接口自动化发包与示例 近期需要完成对于接口的测试,于是了解并简单做了个测试示例,看了看这款江湖上声名远播的强大的软件-Jmeter靠不靠谱。 官网:Apache JMeter - Apache JMeter™ 1简介 Apache-Jmeter是一个使用java语言编写且开源&…...
INFINI Console 极限控制台密码忘记了,如何重置?
在使用 INFINI Console(极限控制台)时,可能会遇到忘记密码的情况,这对于管理员来说是一个常见但棘手的问题。 本文将详细介绍如何处理 INFINI Console 密码忘记的情况,并提供两种可能的解决方案,帮助您快速…...
Python运算符的理解及简单运用
免责声明 如有异议请在评论区友好交流,或者私信 内容纯属个人见解,仅供学习参考 如若从事非法行业请勿食用 如有雷同纯属巧合 版权问题请直接联系本人进行删改 前言 提示:这里可以添加本文要记录的大概内容: 提示:以…...
