当前位置: 首页 > news >正文

Jax(Random、Numpy)常用函数

目录

Jax

 vmap

Array

reshape

Random

PRNGKey

uniform

normal

split

 choice

Numpy

expand_dims

linspace

jax.numpy.linalg[pkg]

dot

matmul

arange

interp 

tile

reshape


Jax

jit

jax.jit(funin_shardings=UnspecifiedValueout_shardings=UnspecifiedValuestatic_argnums=Nonestatic_argnames=Nonedonate_argnums=Nonedonate_argnames=Nonekeep_unused=Falsedevice=Nonebackend=Noneinline=Falseabstracted_axes=None)[source]

注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。

>>>jax.jit(lambda x,y : x + y)
<PjitFunction of <function <lambda> at 0x7ea7b402f130>>
>>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
Array(3, dtype=int32, weak_type=True)
>>>@jax.jitdef fun(x,y):return x + y
>>>fun
<PjitFunction of <function fun at 0x7ea7b402f5b0>>
>>>fun(1,2)
Array(3, dtype=int32, weak_type=True)

 vmap

jax.vmap(funin_axes=0out_axes=0axis_name=Noneaxis_size=Nonespmd_axis_name=None)[source]

注:对函数进行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)

>>>f_xy = lambda x,y : x + y
>>>x = jax.numpy.array([[1, 2], [3, 4]])  # shape (2, 2)
>>>y = jax.numpy.array([[5, 6], [7, 8]])  # shape (2, 2)# in this x and y array, axis 0 is row , axis 1 is col, ref shape index
# in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]>>>jax.vmap(f_xy,in_axes=(0,0))(x,y)      # default out_axes = 0,row ouput
# x row + y row , need x row dim equal y row dim
Array([[ 6,  8],[10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
Array([[ 6,  8],[10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1))(x,y) 
# x row + y col , need x row's dim equal y col's dim
Array([[ 6,  9],[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col 
Array([[ 6,  9],[ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
# x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
# return shape(y_dim_2,x_dim_1,x_dim2)
Array([[[ 6,  8],[ 8, 10]],[[ 8, 10],[10, 12]]], dtype=int32)

ref:Learning about JAX :axes in vmap()

Array

reshape

abstract Array.reshape(*argsorder='C')[source]

注:Array对象的实例方法,引用jax.numpy.reshape函数

Random

PRNGKey

jax.random.PRNGKey(seed*impl=None)[source]#

注:创建一个 PRNG key,作为生成随机数的种子Seed

eg:       

>>>jax.random.PRNGKey(0)
Array([0, 0], dtype=uint32)

uniform

jax.random.uniform(keyshape=()dtype=<class 'float'>minval=0.0maxval=1.0)[source]

注:在给定的形状(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.uniform(k,shape=(1,))
Array([0.41845703], dtype=float32)

normal

normal(keyshape=()dtype=<class 'float'>)[source]

注:在给定的形状shape和浮点数据类型dtype下,采样标准正态分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.normal(k,shape=(1,))
Array([-0.20584226], dtype=float32)

split

jax.random.split(keynum=2)[source]

注:用于生成伪随机数生成器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中生成多个新的状态,从而实现随机数的可重复性和并行性。 

>>>k = jax.random.PRNGKey(1)
>>>k1,k2 = jax.random.split(k)
>>>k1
Array([2441914641, 1384938218], dtype=uint32)
>>>k2
Array([3819641963, 2025898573], dtype=uint32)

 choice

jax.random.choice(keyashape=()replace=Truep=Noneaxis=0)[source]

注:从给定数组a中按shape生成随机样本,区别于numpy.random.choice函数。default choice one elem。

>>>k = jax.random.PRNGKey(0)
>>>a = jax.numpy.array([1,2,3,4,5,6,7,8,9,0])
>>>jax.random.choice(k,a,(10,)) # random no seq
Array([9, 6, 8, 7, 8, 4, 1, 2, 3, 3], dtype=int32)
>>>jax.random.choice(k,a,(2,5))
Array([[9, 6, 8, 7, 8],[4, 1, 2, 3, 3]], dtype=int32)

Numpy

expand_dims

expand_dims(aaxis)[source]

注:为数组a的维度axis增加1维度

>>>arr = jax.numpy.array([1,2,3])
>>>arr.shape
(3,)
>>>jax.numpy.expand_dims(arr,axis=0)
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=0).shape
(1, 3)
>>>jax.numpy.expand_dims(arr,axis=1)
Array([[1],[2],[3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=1).shape
(3, 1)

linspace

linspace(start: ArrayLikestop: ArrayLikenum: int = 50endpoint: bool = Trueretstep: Literal[False] = Falsedtype: DTypeLike | None = Noneaxis: int = 0*device: xc.Device | Sharding | None = None) → Array[source]

注:在给定区间[start,stop]内返回均匀间隔的数字

>>>jax.numpy.linspace(0,1,5)
Array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)

jax.numpy.linalg[pkg]

jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现

        jax.numpy.linalg.cholesky(a*upper=False)[source]

        注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5定义的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵L^{T}

>>> d = jax.numpy.array([[2. , 1.],[1. , 2.]])
>>>jax.numpy.linalg.cholesky(d)
Array([[1.4142135 , 0.        ],[0.70710677, 1.2247449 ]], dtype=float32)>>>L = jax.numpy.linalg.cholesky(d)
>>>L@L.T
Array([[1.9999999 , 0.99999994],[0.99999994, 2.        ]], dtype=float32)

dot

dot(ab*precision=Nonepreferred_element_type=None)[source]

注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上进行求和

  • 对于一维数组(向量)numpy.dot(a, b) 计算的是向量 a 和 b 的点积,结果是一个标量。
  • 对于二维数组(矩阵)numpy.dot(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。结果是一个新的矩阵。
  • 对于更高维度的数组numpy.dot() 可以进行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),2)
Array([2, 4, 6], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2,3],[4,5,6]]),jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2],[4,5]]),jax.numpy.array([[1,2],[4,5]]))
Array([[ 9, 12],[24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.dot(a,b).shape
(1, 3, 1, 4) #matmul ret (1,3,4)

matmul

matmul(ab*precision=Nonepreferred_element_type=None)[source]#

注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并保留其他轴

  • 对于一维数组(向量)numpy.matmul(a, b) 通常不被定义为向量之间的运算,除非 a 是一个二维数组(表示多个向量)的单个行或列,并且 b 的形状与之兼容。
  • 对于二维数组(矩阵)numpy.matmul(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。这与 numpy.dot() 对于二维数组的行为相同。
  • 对于更高维度的数组numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在不同维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地进行乘法运算。
>>>jax.numpy.matmul(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2,3],[4,5,6]]),jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2],[4,5]]),jax.numpy.array([[1,2],[4,5]]))
Array([[ 9, 12],[24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.matmul(a,b).shape
(1, 3, 4) #dot ret (1,3,1,4)

arange

jax.numpy.arange(startstop=Nonestep=Nonedtype=None*device=None)[source]

注:default step 为1,在区间[start,stop)生成步长为1的数组,类似range函数

>>>jax.numpy.arange(0,10,1)
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

interp 

interp(xxpfpleft=Noneright=Noneperiod=None)[source]

注:在xp点列表中线性插值x,线性插值满足y=y_{i}+\frac{y_{i+1}-y_{i}}{x_{i+1}-x_{i}}(x-x_{i}),x\epsilon [x_{i},x_{i+1}),xi和xi+1表示xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量

>>>xp = jax.numpy.arange(0,10,1)
>>>fp = jax.numpy.array(range(0,10,1)) * 2
>>>x = jax.numpy.array([1,2,3])
>>>jax.numpy.interp(x,xp,fp)
Array([2., 4., 6.], dtype=float32)

tile

jax.numpy.tile(Areps)[source]

注:将A数组按reps重复化生成新Array

a = jax.numpy.array([1,2,3])
>>>jax.numpy.tile(a,2)
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(1,1))
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
Array([[1, 2, 3],[1, 2, 3]], dtype=int32)

reshape

jax.numpy.reshape(ashape=Noneorder='C'*newshape=Deprecatedcopy=None)[source]

注:从定义Array a的shape形状为shape元组(),支持-1,推断dim数值

>>>a = jax.numpy.array([[1, 2, 3],[4, 5, 6]])
>>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,-1) # equal reshape(a,6)  -1 is inferred to be 3
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
Array([[1, 2],[3, 4],[5, 6]], dtype=int32)
>>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

相关文章:

Jax(Random、Numpy)常用函数

目录 Jax vmap Array reshape Random PRNGKey uniform normal split choice Numpy expand_dims linspace jax.numpy.linalg[pkg] dot matmul arange interp tile reshape Jax jit jax.jit(fun, in_shardingsUnspecifiedValue, out_shardingsUnspecifiedVa…...

python-pptx 中 placeholder 和 shape 有什么区别?

在 python-pptx 库中&#xff0c;placeholder 和 shape 是两个核心概念。虽然它们看起来相似&#xff0c;但在功能和作用上存在显著的区别。为了更好地理解这两个概念&#xff0c;我们可以通过它们的定义、使用场景以及实际代码示例来剖析其差异。 Python-pptx 的官网链接&…...

王者农药更新版

一、启动文件配置 二、GPIO使用 2.1基本步骤 1.配置GPIO&#xff0c;所以RCC开启APB2时钟 2.GPIO初始化&#xff08;结构体&#xff09; 3.给GPIO引脚设置高/低电平&#xff08;WriteBit&#xff09; 2.2Led循环点亮&#xff08;GPIO输出&#xff09; 1.RCC开启APB2时钟。…...

各省份消费差距(城乡差距)数据(2005-2022年)

消费差距&#xff0c;特别是城乡消费差距&#xff0c;是衡量一个国家或地区经济发展均衡性的重要指标。 2005年-2022年各省份消费差距&#xff08;城乡差距&#xff09;数据&#xff08;大数据&#xff09;.zip资源-CSDN文库https://download.csdn.net/download/2401_84585615/…...

[Linux] 进程创建、退出和等待

标题&#xff1a;[Linux] 进程创建、退出和等待 个人主页水墨不写bug &#xff08;图片来源于AI&#xff09; 目录 一、进程创建fork&#xff08;&#xff09; 1&#xff09; fork的返回值&#xff1a; 2&#xff09;写时拷贝 ​编辑3&#xff09;fork常规用法 4&#xff…...

微软推出针对个人的 “AI伴侣” Copilot 会根据用户的行为模式、习惯自动进化

微软推出了为每个人提供的“AI伴侣”Copilot&#xff0c;它不仅能够理解用户的需求&#xff0c;还能根据用户的日常习惯和偏好进行适应和进化。帮助处理各种任务和复杂的日常生活场景。 它能够根据用户的生活背景提供帮助和建议&#xff0c;保护用户的隐私和数据安全。Copilot…...

【QT】QT入门

个人主页~ QT入门 一、简述QT1、什么是QT2、QT的优势3、应用场景 二、QT的基本使用1、新建项目&#xff08;1&#xff09;选择项目模版&#xff08;2&#xff09;选择项目路径&#xff08;3&#xff09;选择构建系统&#xff08;4&#xff09;填写类信息设置界面&#xff08;5&…...

Linux 6.11版本发布

Linux 6.11版本的发布是Linux社区的一个重要里程碑&#xff0c;它不仅在实时计算、性能优化方面取得了显著进展&#xff0c;还在安全性上迈出了关键一步。 一、实时计算与性能优化 1.io_uring子系统支持 Linux 6.11引入了io_uring子系统的增强功能&#xff0c;特别是支持了b…...

CSS 参考手册

CSS 参考手册 概述 CSS(层叠样式表)是一种用于描述HTML或XML文档样式的样式表语言。它用于控制网页的布局和外观,使网页设计更加美观和响应式。CSS可以定义文本颜色、字体、布局、响应式设计等,是网页设计和开发中不可或缺的一部分。 基础语法 CSS的基本语法由选择器和…...

数据采集工具sqoop介绍

文章目录 什么是sqoop?一、Sqoop的起源与发展二、Sqoop的主要功能三、Sqoop的工作原理四、Sqoop的使用场景五、Sqoop的优势六、Sqoop的安装与配置 sqoop命令行一、Sqoop简介与架构二、Sqoop特点三、Sqoop常用命令及参数四、使用示例五、注意事项 什么是sqoop? Sqoop是一款开…...

扫盲:写给UI设计师的SCADA系统知识点

一、SCADA是什么&#xff0c;及其组成。 SCADA&#xff08;Supervisory Control And Data Acquisition&#xff0c;监控与数据采集系统&#xff09;是一种用于实时监控、控制和数据采集的自动化系统。 SCADA的组成部分&#xff1a; - 人机界面&#xff08;HMI*&#xff1a;提…...

类的特殊成员函数——三之法则、五之法则、零之法则

系统中的动态资源、文件句柄&#xff08;socket描述符、文件描述符&#xff09;是有限的&#xff0c;在类中若涉及对此类资源的操作&#xff0c;但是未做到妥善的管理&#xff0c;常会造成资源泄露问题&#xff0c;严重的可能造成资源不可用。或引发未定义行为&#xff0c;进而…...

计算机毕业设计 智慧物业服务系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…...

Python软体中使用SpaCy进行命名实体识别

Python软体中使用SpaCy进行命名实体识别 命名实体识别(Named Entity Recognition,NER)是自然语言处理(NLP)中的一个重要任务,它涉及识别文本中的命名实体,例如人名、地名、组织名等。SpaCy是一种流行的NLP库,提供了高效的NER功能。在本文中,我们将介绍如何使用SpaCy进…...

华为云技术深度解析:以系统性创新加速智能化升级

华为云技术深度解析&#xff1a;以系统性创新加速智能化升级 在当今数字化转型的浪潮中&#xff0c;云计算作为关键的基础设施&#xff0c;正以前所未有的速度推动着各行各业的智能化升级。作为全球领先的云服务提供商&#xff0c;华为云凭借其深厚的技术积累和创新实力&#…...

推理攻击-Python案例

1、本文通过推理攻击的方式来估计训练集中每个类别的样本数量、某样本是否在训练集中。 2、一种简单的实现方法&#xff1a;用模型对训练数据标签进行拟合&#xff0c;拟合结果即推理为训练集中的情况。 3、了解这些案例可以帮助我们更好的保护数据隐私。 推理攻击&#xff08;…...

find_box_3d

参数 &#xff08;ObjectModel3DScene, SideLen1, SideLen2, SideLen3, MinScore, GenParam : GrippingPose, Score, ObjectModel3DBox, BoxInformation) 入参介绍 1&#xff0c;ObjectModel3DScene&#xff0c; 输入的3d模型&#xff0c;这个模型最好是由xyx三通道点…...

Visual Studio2017编译GDAL3.0.2源码过程

一、编译环境 操作系统&#xff1a;Windows 10企业版 编译工具&#xff1a;Visual Studio 2017旗舰版 源码版本&#xff1a;gdal3.0.2 二、生成解决方案 打开Visual Studio 2017的x64本机生成工具&#xff0c;切换到gdal3.0.2源码根目录&#xff1b;执行generate_vcxproj.b…...

计算机网络——email

pop3拉出来 超出ASCII码范围就不让传了 这样就可以传更大的文件...

【Linux】信号知识三把斧——信号的产生、保存和处理

目录​​​​​​​ 1、关于信号的前置知识 1.1.什么是信号&#xff1f; 1.2.为什么要学习信号&#xff1f; 1.3.如何学习信号&#xff1f; 1.4.一些常见的信号 1.5.信号的处理方式 1.6.为什么每一个进程都可以系统调用&#xff1f; 2.信号的产生 2.1.kill命令产生信号…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”&#xff0c;无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息&#xff1a; 关注测试号&#xff1a;扫二维码关注测试号。 发送模版消息&#xff1a; import requests da…...

XCTF-web-easyupload

试了试php&#xff0c;php7&#xff0c;pht&#xff0c;phtml等&#xff0c;都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接&#xff0c;得到flag...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】

1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件&#xff08;System Property Definition File&#xff09;&#xff0c;用于声明和管理 Bluetooth 模块相…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数&#xff08;函数作为参数、返回值&#xff09; 三、匿名函数与闭包1. 匿名函数&#xff08;Lambda函…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...