当前位置: 首页 > news >正文

14、保存与加载PyTorch训练的模型和超参数

文章目录

  • 1. state_dict
  • 2. 模型保存
  • 3. check_point
  • 4. 详细保存
  • 5. Docker
  • 6. 机器学习常用库

1. state_dict

nn.Module 类是所有神经网络构建的基类,即自己构建一个深度神经网络也是需要继承自nn.Module类才行,并且nn.Module中的state_dict包含神经网络中的权重weight ,偏置bias,过程量buffer,举例说明:

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :NN_Embedding.py
# @Time      :2024/11/26 22:50
# @Author    :Jason Zhang
import torch
from torch import nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear1 = nn.Linear(3, 4)self.relu = nn.ReLU()self.linear2 = nn.Linear(4, 5)self.batch_norm = nn.BatchNorm2d(4)def forward(self, x):x = self.linear1(x)x = self.relu(x)y = self.linear2(x)return yif __name__ == "__main__":my_test = MyModel()my_keys = my_test.state_dict().keys()print(f"my_keys={my_keys}")
  • 结果:
    从结果中看出,跟说明的一样,不仅存的是weight,bias ,还有buffer
y_keys=odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'batch_norm.weight', 'batch_norm.bias', 'batch_norm.running_mean', 'batch_norm.running_var', 'batch_norm.num_batches_tracked'])

2. 模型保存

保存和加载

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :torch_save.py
# @Time      :2024/11/27 21:33
# @Author    :Jason Zhang
import torch
import torchvision.models as modelsif __name__ == "__main__":run_code = 0model = models.vgg16(weights='IMAGENET1K_V1')torch.save(model.state_dict(), 'model_weights.pth')model.load_state_dict(torch.load('model_weights.pth', weights_only=True))model.eval()torch.save(model, 'model.pth')

3. check_point

# Define model
import torch
from torch import nn
from torch import optim
import torch.nn.functional as Fclass TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

4. 详细保存

在训练过程中,我们希望详细保存,以至于我们可以在中断训练中恢复训练。
保存模型

5. Docker

关于Docker方式搭建深度神经网络环境和配置
在这里插入图片描述

6. 机器学习常用库

在这里插入图片描述

相关文章:

14、保存与加载PyTorch训练的模型和超参数

文章目录 1. state_dict2. 模型保存3. check_point4. 详细保存5. Docker6. 机器学习常用库 1. state_dict nn.Module 类是所有神经网络构建的基类,即自己构建一个深度神经网络也是需要继承自nn.Module类才行,并且nn.Module中的state_dict包含神经网络中…...

【前端开发】JS+Vuew3请求列表数据并分页

应用技术:原生JavaScript Vue3 $(function () {ini(); });function ini() {const { createApp, ref, onMounted } Vue;createApp({setup() {const data ref({studentList: [],page: 1,pageSize: 10,});const getStudentList async (page, key) > {window.ons…...

Trimble X12助力电力管廊数据采集,为机器人巡视系统提供精准导航支持

地下电缆是一个城市重要的基础设施,它不仅具有规模大、范围广、空间分布复杂等特点,更重要的是它还承担着信息传输、能源输送等与人们生活息息相关的重要功能,也是一个城市赖以生存和发展的物质基础。 01、项目概述 本次项目是对某区域2公里左…...

Docker 清理镜像策略详解

文章目录 前言一、删除 Docker 镜像1. 查看当前镜像2. 删除单个镜像3. 删除多个镜像4. 删除所有未使用的镜像5. 删除悬空的 Docker 镜像6. 根据模式删除镜像7. 删除所有镜像 二、删除 Docker 容器1. 查找容器2. 删除一个或多个特定容器3. 退出时删除容器4. 删除所有已退出的容器…...

【Linux】TCP网络编程

目录 V1_Echo_Server V2_Echo_Server多进程版本 V3_Echo_Server多线程版本 V3-1_多线程远程命令执行 V4_Echo_Server线程池版本 V1_Echo_Server TcpServer的上层调用如下,和UdpServer几乎一样: 而在InitServer中,大部分也和UDP那里一样&…...

排序学习整理(2)

上集回顾 排序学习整理(1)-CSDN博客 2.3 交换排序 交换排序的基本思想是:根据序列中两个记录键值的比较结果,交换这两个记录在序列中的位置。 特点: 通过比较和交换操作,将键值较大的记录逐步移动到序列…...

AI蛋白质设计与人工智能药物设计

AI蛋白质设计与人工智能药物设计 AI蛋白质设计 一、蛋白质相关的深度学习简介 1.基础概念 1.1.机器学习简介:从手写数字识别到大语言模型 1.2.蛋白质结构预测与设计回顾 1.3.Linux简介 1.4.代码环境:VS code和Jupyter notebook* 1.5.Python关键概…...

IOS ARKit进行图像识别

先讲一下基础控涧,资源的话可以留言,抽空我把它传到GitHub上,这里没写收积分,竟然充值才能下载,我下载也要充值,牛! ARSCNView 可以理解画布或者场景 1 配置 ARWorldTrackingConfiguration AR追…...

初级数据结构——二叉搜索树

目录 前言一、定义二、基本操作三、时间复杂度分析四、变体五、动态图解六、代码模版七、经典例题[1.——700. 二叉搜索树中的搜索](https://leetcode.cn/problems/search-in-a-binary-search-tree/)代码题解 [2.——938. 二叉搜索树的范围和](https://leetcode.cn/problems/ra…...

C++设计模式之组合模式中如何实现同一层部件的有序性

在组合模式中,为了实现同一层上部件的有序性,可以采取以下几种设计方法: 1. 使用有序集合 使用有序集合(如 std::list、std::vector 或其他有序容器)来存储和管理子部件。这种方法可以确保子部件按照特定顺序排列&am…...

duxapp RN 端使用AppUpgrade 进行版本更新

版本更新包含了组件和工具的组合 注册 下面这是 duxcms 入口文件检查更新的注册方法,注册的同时会检查更新 import {request,updateApp,userConfig } from ./utils// 检查app更新 setTimeout(async () > {if (process.env.TARO_ENV rn) {// eslint-disable-n…...

【计网】自定义序列化反序列化(三) —— 实现网络版计算器【下】

🌎实现网络版计算器【下】 本次序列化与反序列化所用到的代码,Tcp服务自定义序列化反序列化实现网络版计算器。 文章目录: 实实现网络版计算器【下】 客户端实现     基于守护进程的改写 🚀客户端实现 在这之前&#xff0c…...

神经网络中的优化方法(一)

目录 摘要Abstract1. 与纯优化的区别1.1 经验风险最小化1.2 代理损失函数1.3 批量算法和小批量算法 2. 神经网络中优化的挑战2.1 病态2.2 局部极小值2.3 高原、鞍点和其他平坦区域2.4 悬崖和梯度爆炸2.5 长期依赖2.6 非精确梯度2.7 局部和全局结构间的弱对应 3. 基本算法3.1 随…...

Linux 计算机网络基础概念

目录 0.前言 1.计算机网络背景 1.1 独立模式 1.2 网络互联 1.3 局域网(Local Area Network,LAN) 1.4 广域网(Wide Area Network,WAN) 2.协议 2.1什么是协议 2.2协议分层和软件分层 2.3 OSI七层网络模型 2.3…...

qt QGraphicsEllipseItem详解

1、概述 QGraphicsEllipseItem是Qt框架中QGraphicsItem的一个子类,它提供了一个可以添加到QGraphicsScene中的椭圆项。QGraphicsEllipseItem表示一个带有填充和轮廓的椭圆,也可以用于表示椭圆段(通过startAngle()和spanAngle()方法&#xff…...

Python websocket

router.websocket(/chat/{flow_id}) 接口代码,并了解其工作流程、涉及的组件以及如何基于此实现你的新 WebSocket 接口。以下内容将分为几个部分进行讲解: 接口整体概述代码逐行解析关键组件和依赖关系如何基于此实现新功能示例:创建一个新的…...

【MySQL-5】MySQL的内置函数

目录 1. 整体学习的思维导图 2. 日期函数 ​编辑 2.1 current_date() 2.2 current_time() 2.3 current_timestamp() 2.4 date(datetime) 2.5 now() 2.6 date_add() 2.7 date_sub() 2.8 datediff() 2.9 案例 2.9.1 创建一个出生日期登记簿 2.9.2 创建一个留言版 3…...

深度学习笔记之BERT(三)RoBERTa

深度学习笔记之RoBERTa 引言回顾:BERT的预训练策略RoBERTa训练过程分析静态掩码与动态掩码的比较模型输入模式与下一句预测使用大批量进行训练使用Byte-pair Encoding作为子词词元化算法更大的数据集和更多的训练步骤 RoBERTa配置 引言 本节将介绍一种基于 BERT \t…...

C++知识点总结(59):背包型动态规划

背包型动态规划 一、背包 dp1. 01 背包(限量)2. 完全背包(不限量)3. 口诀 二、例题1. 和是质数的子集数2. 黄金的太阳3. 负数子集和4. NASA的⻝物计划 一、背包 dp 1. 01 背包(限量) 假如有这几个物品&am…...

C++:反向迭代器的实现

反向迭代器的实现与 stack 、queue 相似&#xff0c;是通过适配器模式实现的。通过传入不同类型的迭代器来实现其反向迭代器。 正向迭代器中&#xff0c;begin() 指向第一个位置&#xff0c;end() 指向最后一个位置的下一个位置。 代码实现&#xff1a; template<class I…...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式&#xff08;Singleton Pattern&#…...

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日&#xff0c;国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解&#xff0c;“超级…...

基于Java+MySQL实现(GUI)客户管理系统

客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息&#xff0c;对客户进行统一管理&#xff0c;可以把所有客户信息录入系统&#xff0c;进行维护和统计功能。可通过文件的方式保存相关录入数据&#xff0c;对…...

三分算法与DeepSeek辅助证明是单峰函数

前置 单峰函数有唯一的最大值&#xff0c;最大值左侧的数值严格单调递增&#xff0c;最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值&#xff0c;最小值左侧的数值严格单调递减&#xff0c;最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

yaml读取写入常见错误 (‘cannot represent an object‘, 117)

错误一&#xff1a;yaml.representer.RepresenterError: (‘cannot represent an object’, 117) 出现这个问题一直没找到原因&#xff0c;后面把yaml.safe_dump直接替换成yaml.dump&#xff0c;确实能保存&#xff0c;但出现乱码&#xff1a; 放弃yaml.dump&#xff0c;又切…...

WEB3全栈开发——面试专业技能点P4数据库

一、mysql2 原生驱动及其连接机制 概念介绍 mysql2 是 Node.js 环境中广泛使用的 MySQL 客户端库&#xff0c;基于 mysql 库改进而来&#xff0c;具有更好的性能、Promise 支持、流式查询、二进制数据处理能力等。 主要特点&#xff1a; 支持 Promise / async-await&#xf…...

Docker、Wsl 打包迁移环境

电脑需要开启wsl2 可以使用wsl -v 查看当前的版本 wsl -v WSL 版本&#xff1a; 2.2.4.0 内核版本&#xff1a; 5.15.153.1-2 WSLg 版本&#xff1a; 1.0.61 MSRDC 版本&#xff1a; 1.2.5326 Direct3D 版本&#xff1a; 1.611.1-81528511 DXCore 版本&#xff1a; 10.0.2609…...