当前位置: 首页 > news >正文

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频:
Link1

B站DDPM公式推导以及代码实现:
Link2

这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。

文章目录

    • 扩散模型概念:
    • Diffusion Model工作原理:
    • 影像生成模型本质上的共同目标
    • B站简单示例代码讲解

扩散模型概念:

就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
在这里插入图片描述
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
在这里插入图片描述

比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
在这里插入图片描述

Diffusion Model工作原理:

VAE和Diffusion的区别
在这里插入图片描述
先看整个训练过程:
在这里插入图片描述

实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。在这里插入图片描述
推断时刻:theat是带有参数的网络。
在这里插入图片描述

影像生成模型本质上的共同目标

通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
在这里插入图片描述
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。 在这里插入图片描述
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
在这里插入图片描述
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
在这里插入图片描述

在这里插入图片描述
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
在这里插入图片描述
两者对比
在这里插入图片描述
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
在这里插入图片描述
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
经过一番推导之后得到:
在这里插入图片描述
之后计算最下面三项:
在这里插入图片描述
通过以下推导:
在这里插入图片描述
之后通过X0,Xt可以得到Xt-1的分布。
在这里插入图片描述
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
在这里插入图片描述
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
在这里插入图片描述
经过推导:
在这里插入图片描述
最终得到下图:
在这里插入图片描述
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。

B站简单示例代码讲解

# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torchs_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0print("shape of s:",np.shape(s_curve))data = s_curve.Tfig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');ax.axis('off')dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)、确定扩散过程任意时刻的采样值#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):"""可以基于x[0]得到任意时刻t的x[t]"""noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):j = i//10k = i%10q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')axs[j,k].set_axis_off()axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')

在这里插入图片描述

# 5 编写拟合逆扩散过程高斯分布的模型import torch
import torch.nn as nn
​
class MLPDiffusion(nn.Module):def __init__(self,n_steps,num_units=128):super(MLPDiffusion,self).__init__()self.linears = nn.ModuleList([nn.Linear(2,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,2),])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),])def forward(self,x,t):
#         x = x_0for idx,embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2*idx](x)x += t_embeddingx = self.linears[2*idx+1](x)x = self.linears[-1](x)return x

loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
在这里插入图片描述

# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):"""对任意时刻t进行采样计算loss"""batch_size = x_0.shape[0]#对一个batchsize样本生成随机的时刻tt = torch.randint(0,n_steps,size=(batch_size//2,))t = torch.cat([t,n_steps-1-t],dim=0)t = t.unsqueeze(-1)#x0的系数a = alphas_bar_sqrt[t]#eps的系数aml = one_minus_alphas_bar_sqrt[t]#生成随机噪音epse = torch.randn_like(x_0)#构造模型的输入x = x_0*a+e*aml#送入模型,得到t时刻的随机噪声预测值output = model(x,t.squeeze(-1))#与真实噪声一起计算误差,求平均值return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""cur_x = torch.randn(shape)x_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)x_seq.append(cur_x)return x_seq
​
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):"""从x[T]采样t时刻的重构值"""t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]eps_theta = model(x,t)mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))z = torch.randn_like(x)sigma_t = betas[t].sqrt()sample = mean + sigma_t * zreturn (sample)
# 8、开始训练模型,打印loss及中间重构效果seed = 1234class EMA():"""构建一个参数平滑器"""def __init__(self,mu=0.01):self.mu = muself.shadow = {}def register(self,name,val):self.shadow[name] = val.clone()def __call__(self,name,x):assert name in self.shadownew_average = self.mu * x + (1.0-self.mu)*self.shadow[name]self.shadow[name] = new_average.clone()return new_averageprint('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
​
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)for t in range(num_epoch):for idx,batch_x in enumerate(dataloader):loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.)optimizer.step()if(t%100==0):print(loss)x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)fig,axs = plt.subplots(1,10,figsize=(28,3))for i in range(1,11):cur_x = x_seq[i*10].detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

最后的演示

动画演示扩散过程和逆扩散过程import io
from PIL import Image
​
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)
mg = Image.open(img_buf)reverse.append(img)
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)
​
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)

相关文章:

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频: Link1 B站DDPM公式推导以及代码实现: Link2 这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。 文章目录 扩散模型概念:Diffusion Model工作原理:影像生成模型本质上的共同目标B站…...

SpringBoot隐藏文件

1.设置 2.输入file Types 3.点击忽略文件或者文件夹 4.成功...

常见数据库介绍对比之SQL关系型数据库

1.关系型数据库介绍 关系型数据库是一种基于关系模型的数据库,它使用表格来组织和存储数据。下面是一些常见的关系型数据库: 1.1. MySQL MySQL是一种开源的关系型数据库管理系统(RDBMS),广泛用于Web应用程序和企业级…...

OLED透明屏模块:引领未来显示技术的突破

OLED透明屏模块作为一项引领未来显示技术的突破,以其独特的特点和卓越的画质在市场上引起了广泛关注。 根据行业报告,预计到2025年,OLED透明屏模块将占据智能手机市场的20%份额,并在汽车导航系统市场中占据30%以上份额。 那么&am…...

Python_操作记录

1、Pandas读取数据文件(以文本文件作为示例),sep表示间隔,headerNone表示无标题行 df pd.read_table("data/youcans3.dat", sep"\t", headerNone) 2、线性规划问题求解 1)问题定义,…...

常用激活函数整理

最近一边应付工作,一边在补足人工智能的一些基础知识,这个方向虽然新兴,但已是卷帙浩繁,有时不知从何入手,幸亏有个适合基础薄弱的人士学习的网站,每天学习一点,积跬步以至千里吧。有像我一样学…...

uniapp 地图跳转到第三方导航软件 直接打包成apk

// 判断是否存在导航软件judgeHasExistNavignation() {let navAppParam [{pname: com.baidu.BaiduMap,action: baidumap://}, // 百度{pname: com.autonavi.minimap,action: iosamap://}, // 高德{pname: com.tencent.map,action: tencentmap://}, // 腾讯];return navAppPara…...

CentOS 8 通过YUM方式升级最新内核

CentOS 8 通过YUM方式升级最新内核 查看当前内核 uname -r 4.18.0-193.6.3.el8_2.x86_64导入 ELRepo 仓库的公钥: rpm --import https://www.elrepo.org/RPM-GPG-KEY-elrepo.org安装升级内核相关的yum源仓库(安装 ELRepo 仓库的 yum 源) yum install https://www…...

java 版本企业招标投标管理系统源码+功能描述+tbms+及时准确+全程电子化

功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查看所…...

Python爬虫数据存哪里|数据存储到文件的几种方式

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 爬虫请求解析后的数据,需要保存下来,才能进行下一步的处理,一般保存数据的方式有如下几种: 文件:txt、csv、excel、json等,保存数据量小。 关系型数据库…...

软件测试/测试开发丨Web自动化 测试用例流程设计

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27173 一、测试用例通用结构回顾 1.1、现有测试用例存在的问题 可维护性差可读性差稳定性差 1.2、用例结构设计 测试用例的编排测试用例的项目结构 1…...

git撤销修改命令

要撤销Git中尚未提交的所有修改,可以使用以下几种方法: 1、使用git checkout命令丢弃工作目录的修改,重置工作目录中所有文件的修改。 git checkout . 2、使用git reset命令重置暂存区和工作目录, 重置暂存区和工作目录,回到最后一次提交后的状态。 …...

EOCR-AR电机保护器自动复位的启用条件说明

为适用不同的现场使用需求,施耐德韩国公司推出了带有自动复位功能的模拟型电动机保护器-EOCR-AR。EOCR-AR电机保护器具有过电流、缺相、堵转保护功能,还可根据实际需要设置自动复位时间。 EOCR-AR自动复位的设置方法 如上图,R-TIME旋钮是自动…...

Apache nginx解析漏洞复现

文章目录 空字节漏洞安装环境漏洞复现 背锅解析漏洞安装环境漏洞复现 空字节漏洞 安装环境 将nginx解压后放到c盘根目录下: 运行startup.bat启动环境: 在HTML文件夹下有它的主页文件: 漏洞复现 nginx在遇到后缀名有php的文件时,…...

.NET之后,再无大创新

回想起来,2001年发布的.NET已经是距离最近的一次软件开发技术的整体创新了,后续的新技术就没有在各个端都这么成功的了。.NET是Windows平台下软件开发技术的巨大变革。在此之前,有VB、C(MFC)、JSP,在此之后…...

【大麦小米学量化】什么是量化交易?哪些人适合做量化交易?

系列文章目录 文章目录 系列文章目录学霸的梦想前言一、什么是量化交易?二、哪些人适合做量化交易?三、量化交易都需要掌握哪些技术和方法?总结 学霸的梦想 小米支棱着迷糊的眼睛,一脸懵逼的问大麦:“我说大麦哥哥&…...

计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程。 要理解卷积神经网络中图像特征提取的全过程,我们可以将其比喻为人脑对视觉信息的处理过程。就像…...

el-table中点击跳转到详情页的两种方法

跳转的两种写法: 1.使用keep-alive使组件缓存,防止刷新时参数丢失 keep-alive 组件用于缓存和保持组件的状态,而不是路由参数。它可以在组件切换时保留组件的状态,从而避免重新渲染和加载数据。 keep-alive 主要用于提高页面性能和用户体验,而…...

RT-DETR个人整理向理解

一、前言 在开始介绍RT-DETR这个网络之前,我们首先需要先了解DETR这个系列的网络与我们常提及的anchor-base以及anchor-free存在着何种差异。 首先我们先简单讨论一下anchor-base以及anchor-free两者的差异与共性: 1、两者差异:顾名思义&…...

易点易动库存管理系统与ERP系统打通,帮助企业实现低值易耗品管理

现今,企业管理日趋复杂,无论是核心经营还是辅助环节,都需要依靠信息化手段来提升效率。而低值易耗品作为企业日常运营中的必需品,其管理也面临诸多挑战。传统做法效率低下,容易出错。如何通过信息化手段实现低值易耗品的高效管理,成为许多企业必顾及的一个课题。 易点易动作为…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

2024年赣州旅游投资集团社会招聘笔试真

2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展&#xff0c;光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域&#xff0c;IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选&#xff0c;但在长期运行中&#xff0c;例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、&#x1f44b;&#x1f3fb;前言 二、&#x1f608;sinx波动的基本原理 三、&#x1f608;波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、&#x1f30a;波动优化…...