当前位置: 首页 > article >正文

PyTorch参数管理详解:从访问到初始化与共享

本文通过实例代码讲解如何在PyTorch中管理神经网络参数,包括参数访问、多种初始化方法、自定义初始化以及参数绑定技术。所有代码可直接运行,适合深度学习初学者进阶学习。


1. 定义网络与参数访问

1.1 定义单隐藏层多层感知机

import torch
from torch import nn# 定义单隐藏层多层感知机
net1 = nn.Sequential(nn.Linear(4, 8),  # 输入层4维,隐藏层8维nn.ReLU(),nn.Linear(8, 1)   # 输出层1维
)
x = torch.rand(2, 4)  # 随机生成2个4维输入向量
net1(x)                # 前向传播

1.2 访问网络参数

# 访问第二层(索引2)的参数(权重和偏置)
print(net1[2].state_dict())# 查看参数类型、数据和梯度
print(type(net1[2].bias))    # 类型:Parameter
print(net1[2].bias)          # 参数值(含梯度信息)
print(net1[2].bias.data)     # 参数数据(张量)
print(net1[2].bias.grad)     # 梯度(未反向传播时为None)

1.3 批量访问参数

# 访问第一层的参数名称和形状
print(*[(name, param.shape) for name, param in net1[0].named_parameters()])# 访问整个网络的参数
print(*[(name, param.shape) for name, param in net1.named_parameters()])# 通过state_dict直接访问参数数据
print(net1.state_dict()['2.bias'].data)

2. 参数初始化方法

2.1 内置初始化

# 正态分布初始化权重,偏置置零
def init_normal(model):if isinstance(model, nn.Linear):nn.init.normal_(model.weight, mean=0, std=0.01)nn.init.zeros_(model.bias)net1.apply(init_normal)
print(net1[0].weight.data[0], net1[0].bias.data[0])# 常数初始化(权重为1,偏置为0)
def init_constant(model):if isinstance(model, nn.Linear):nn.init.constant_(model.weight, 1)nn.init.zeros_(model.bias)net1.apply(init_constant)
print(net1[0].weight.data[0], net1[0].bias.data[0])

2.2 分层初始化

# 对第一层使用Xavier初始化,第二层使用常数42初始化
def xavier(model):if isinstance(model, nn.Linear):nn.init.xavier_uniform_(model.weight)def init_42(model):if isinstance(model, nn.Linear):nn.init.constant_(model.weight, 42)net1[0].apply(xavier)
net1[2].apply(init_42)
print(net1[0].weight.data[0])
print(net1[2].weight.data)

2.3 自定义初始化

# 自定义初始化:权重在[-10,10]均匀分布,并过滤绝对值小于5的值
def my_init(model):if isinstance(model, nn.Linear):print(f'init weight {model.weight.shape}')nn.init.uniform_(model.weight, -10, 10)model.weight.data *= (model.weight.abs() >= 5)net1.apply(my_init)
print(net1[0].weight.data[:2])  # 显示前两行权重

3. 参数绑定与共享

3.1 直接修改参数

# 直接操作参数数据
net1[0].weight.data[:] += 1     # 所有权重+1
net1[0].weight.data[0, 0] = 42  # 修改特定位置权重
print(net1[0].weight.data[0])   # 输出第一行权重

3.2 参数共享

# 共享线性层参数
shared_layer = nn.Linear(8, 8)
net3 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),shared_layer, nn.ReLU(),     # 第2层shared_layer, nn.ReLU(),     # 第4层(共享参数)nn.Linear(8, 1)
)# 验证参数共享
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 输出全True
net3[2].weight.data[0, 0] = 100
print(net3[2].weight.data[0] == net3[4].weight.data[0])  # 修改后仍为True

4. 嵌套网络结构

# 构建嵌套网络
def model1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())def model2():net = nn.Sequential()for i in range(4):net.add_module(f'model{i}', model1())return netrgnet = nn.Sequential(model2(), nn.Linear(4, 1))
print(rgnet)  # 打印网络结构

总结

本文演示了PyTorch中参数管理的核心操作,包括:

  • 通过state_dictnamed_parameters访问参数

  • 使用内置初始化方法(正态分布、常数、Xavier)

  • 自定义初始化逻辑

  • 参数的直接修改与共享

  • 复杂嵌套网络的定义

掌握这些技能可以更灵活地设计和优化神经网络模型。建议读者在实践中结合具体任务调整初始化策略,并注意参数共享时的梯度传播特性。


提示:以上代码需要在PyTorch环境中运行,建议使用Jupyter Notebook逐步调试以观察中间结果。

相关文章:

PyTorch参数管理详解:从访问到初始化与共享

本文通过实例代码讲解如何在PyTorch中管理神经网络参数,包括参数访问、多种初始化方法、自定义初始化以及参数绑定技术。所有代码可直接运行,适合深度学习初学者进阶学习。 1. 定义网络与参数访问 1.1 定义单隐藏层多层感知机 import torch from torch…...

页面简单传参

#简单的情景:你需要在帖子主页传递参数给帖子详情页面,携带在主页获得的帖子ID。你有以下几种传递方法# #使用Vue3 TS# 1. 通过 URL 参数传递(Query 参数) 这是最简单、最常用的方法,ID 会显示在 URL 中的 ? 后面…...

nginx路径匹配的优先级

在 Nginx 配置中,当请求 /portal/agent/sse 时,会匹配 location ~* /sse$ 规则,而不是 location /portal。原因如下: 匹配规则解析 location ~* /sse$ ~* 表示 不区分大小写的正则匹配/sse$ 表示以 /sse 结尾的路径匹配结果&#…...

一周学会Pandas2 Python数据处理与分析-Pandas2一维数据结构-Series

锋哥原创的Pandas2 Python数据处理与分析 视频教程: 2025版 Pandas2 Python数据处理与分析 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili Pandas提供Series和DataFrame作为数组数据的存储框架。 Series(系列、数列、序列)是一个带有…...

DApp实战篇:前端技术栈一览

前言 在前面一系列内容中,我们由浅入深地了解了DApp的组成,从本小节开始我将带领大家如何完成一个完整的DApp。 本小节则先从前端开始。 前端技术栈 在前端开发者速入:DApp中的前端要干些什么?文中我说过,即便是在…...

leetcode6.Z字形变换

题目说是z字形变化&#xff0c;但其实模拟更像n字形变化&#xff0c;找到字符下标规律就逐个拼接就能得到答案 class Solution {public String convert(String s, int numRows) {if(numRows1)return s;StringBuilder stringBuilder new StringBuilder();for (int i 0; i <…...

HarmonyOS应用开发者高级-编程题-001

题目一:跨设备分布式数据同步 需求描述 开发一个分布式待办事项应用,要求: 手机与平板登录同一华为账号时,自动同步任务列表任一设备修改任务状态(完成/删除),另一设备实时更新任务数据在设备离线时能本地存储,联网后自动同步实现方案 // 1. 定义分布式数据模型 imp…...

鸿蒙开发者高级认证编程题库

题目一:跨设备分布式数据同步 需求描述 开发一个分布式待办事项应用,要求: 手机与平板登录同一华为账号时,自动同步任务列表任一设备修改任务状态(完成/删除),另一设备实时更新任务数据在设备离线时能本地存储,联网后自动同步实现方案 // 1. 定义分布式数据模型 imp…...

Ubuntu(CentOS、Rockylinux等)快速进入深度学习pytorch环境

这里写自定义目录标题 安装进入系统&#xff08;如Ubuntu22.04&#xff09;安装anacondapip、conda换源pip换源conda换源 安装nvidia安装pytorch环境针对于wsl的优化 安装进入系统&#xff08;如Ubuntu22.04&#xff09; docker 、 wsl 、 双系统 、服务器系统 推荐 Ubuntu 20…...

[实战] 天线阵列波束成形原理详解与仿真实战(完整代码)

天线阵列波束成形原理详解与仿真实战 1. 引言 在无线通信、雷达和声学系统中&#xff0c;波束成形&#xff08;Beamforming&#xff09;是一种通过调整天线阵列中各个阵元的信号相位和幅度&#xff0c;将电磁波能量集中在特定方向的技术。其核心目标是通过空间滤波增强目标方…...

Android开发okhttp添加头部参数

Android开发okhttp添加头部参数或者是头文件 private static class RequestHeaderInterceptor implements Interceptor {Overridepublic Response intercept(Chain chain) throws IOException {Request original chain.request();//添加头部信息Request request original.new…...

Halcon图像采集

Halcon是一款强大的机器视觉软件&#xff0c;结合C#可以开发出功能完善的视觉应用程序。 基本设置 确保已经安装了Halcon和Halcon的.NET库&#xff08;HalconDotNet&#xff09;。 1. 添加引用 在C#项目中&#xff0c;需要添加对HalconDotNet.dll的引用&#xff1a; 右键点…...

自动提取pdf公式 ➕ 输出 LaTeX

# 创建打包脚本的主内容 script_content """ from doc2x.extract_formula import extract_formula_imgs from pix2text import Pix2Text from PIL import Image import osdef main():pdf_path "your_file.pdf" # 将你的PDF命名为 your_file.pdf 并…...

(十)安卓开发中的Activity之间的通信使用详解

在 Android 开发中&#xff0c;Activity 之间的通信是非常常见且核心的功能之一&#xff0c;常见的方式包括&#xff1a; 使用显式 Intent 传递数据使用隐式 Intent 实现跨组件调用使用 startActivityForResult&#xff08;或新版 Activity Result API&#xff09;回传数据传递…...

python 浅拷贝copy与深拷贝deepcopy 理解

一 浅拷贝与深拷贝 1. 浅拷贝 浅拷贝只复制了对象本身&#xff08;即c中的引用&#xff09;。 2. 深拷贝 深拷贝创建一个新的对象&#xff0c;同时也会创建所有子对象的副本&#xff0c;因此新对象与原对象之间完全独立。 二 代码理解 1. 案例一 a 10 b a b 20 print…...

基于neo4j存储知识树-mac

1、安装jdk21 for mac(jdk-21_macos-aarch64_bin.dmg) 2、安装neo4j for mac(neo4j-community-5.26.0-unix.tar.gz) 3、使用默认neo4j/neo4j登录http://localhost:7474 修改登录密码&#xff0c;可以使用生成按钮生成密码&#xff0c;连接数据库&#xff0c;默认设置为neo4j…...

Tiktok 关键字 视频及评论信息爬虫(1) [2025.04.07]

&#x1f64b;‍♀️Tiktok APP的基于关键字检索的视频及评论信息爬虫共分为两期&#xff0c;希望对大家有所帮助。 第一期见下文。 第二期&#xff1a;基于视频URL的评论信息爬取 1. Node.js环境配置 首先配置 JavaScript 运行环境&#xff08;如 Node.js&#xff09;&#x…...

基于人工智能的高中教育评价体系重构研究

基于人工智能的高中教育评价体系重构研究 一、引言 1.1 研究背景 在科技飞速发展的当下&#xff0c;人工智能技术已广泛渗透至各个领域&#xff0c;教育领域亦不例外。人工智能凭借其强大的数据处理能力、智能分析能力和个性化服务能力&#xff0c;为教育评价体系的创新与发…...

【学习笔记】文件上传漏洞--二次渲染、.htaccess、变异免杀

目录 第十二关 远程包含地址转换 第十三关 突破上传删除 条件竞争 第十四关 二次渲染 第十五关 第十六关 第十七关 .htaccess 第十八关 后门免杀 第十九关 日志包含 第十二关 远程包含地址转换 延续第十一关&#xff0c;加一个文件头&#xff0c;上传成功&#xff0c…...

C++ 基础进阶

C 基础进阶 内容概述&#xff1a; 函数重载&#xff1a;int add(int x, inty);&#xff0c;long long add(long long x, long long y);&#xff0c;double add(double x, double y);模板函数&#xff1a;template<typename T> 或 template<class T>结构体&#x…...

【OS】Process Management(3)

《计算机操作系统&#xff08;第三版&#xff09;》&#xff08;汤小丹&#xff09;学习笔记 文章目录 5、进程通信&#xff08;Inter-Process Communication&#xff09;5.1、进程通信的类型5.1.1、共享存储器系统&#xff08;Shared Memory System&#xff09;5.1.2、消息传递…...

单reactor实战

前言&#xff1a;reactor作为一种高性能的范式&#xff0c;值得我们学习 本次目标 实现一个基于的reactor 具备echo功能的服务器 核心组件 Reactor本身是靠一个事件驱动的框架,无疑引出一个类似于moduo的"EventLoop "以及boost.asio中的context而言&#xff0c;不断…...

初阶C++笔记第一篇:C++基础语法

虽然以下大多数知识点都在C语言中学过&#xff0c;但还是有一些知识点和C语言不同&#xff0c;比如&#xff1a;代码格式、头文件、关键字、输入输出、字符串类型等... 1. 初识C 1.1 第一个C程序 编写C分为4个步骤&#xff1a; 创建项目创建文件编写代码运行程序 C的第一条…...

java基础 流(Stream)

Stream Stream 的核心概念核心特点 Stream 的操作分类中间操作&#xff08;Intermediate Operations&#xff09;终止操作&#xff08;Terminal Operations&#xff09; Stream 的流分类顺序流&#xff08;Sequential Stream&#xff09;并行流&#xff08;Parallel Stream&…...

【AI】prompt engineering

prompt engineering ## prompt engineering ## prompt engineering ## prompt engineering 一、定义 Prompt 工程&#xff08;Prompt Engineering&#xff09;是指在使用语言模型&#xff08;如 ChatGPT、文心一言等&#xff09;等人工智能工具时&#xff0c;设计和优化输入提…...

无需libpacp库,BPF指令高效捕获指定数据包

【环境】无libpacp库的Linux服务器 【要求】高效率读取数据包&#xff0c;并过滤指定端口和ip 目前遇到两个问题 一是手写BPF&#xff0c;难以兼容&#xff0c;有些无法正常过滤二是性能消耗问题&#xff0c;尽可能控制到1% 大方向&#xff1a;过滤数据包要在内核层处理&…...

LeetCode算法题(Go语言实现)_36

题目 给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始&#xff0c;也不需要在叶子节点结束&#xff0c;但是路径方向必须是向下的&#xff08;只能从父节点到子节点…...

react实现上传图片到阿里云OSS以及问题解决(保姆级)

一、优势 提高上传速度&#xff1a;前端直传利用了浏览器与 OSS 之间的直接连接&#xff0c;能够充分利用用户的网络带宽。相比之下&#xff0c;后端传递文件时&#xff0c;文件需要经过后端服务器的中转&#xff0c;可能会受到后端服务器网络环境和处理能力的限制&#xff0c;…...

无法看到新安装的 JDK 17

在 Linux 系统中使用 update-alternatives --config java 无法看到新安装的 JDK 17&#xff0c;可能是由于 JDK 未正确注册到系统备选列表中。 一、原因分析 JDK 未注册到 update-alternatives update-alternatives 工具需要手动注册 JDK 路径后才能识别新版本。如果仅安装 JDK…...

LeetCode 3396.使数组元素互不相同所需的最少操作次数:O(n)一次倒序遍历

【LetMeFly】3396.使数组元素互不相同所需的最少操作次数&#xff1a;O(n)一次倒序遍历 力扣题目链接&#xff1a;https://leetcode.cn/problems/minimum-number-of-operations-to-make-elements-in-array-distinct/ 给你一个整数数组 nums&#xff0c;你需要确保数组中的元素…...