当前位置: 首页 > news >正文

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频:
Link1

B站DDPM公式推导以及代码实现:
Link2

这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。

文章目录

    • 扩散模型概念:
    • Diffusion Model工作原理:
    • 影像生成模型本质上的共同目标
    • B站简单示例代码讲解

扩散模型概念:

就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
在这里插入图片描述
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
在这里插入图片描述

比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
在这里插入图片描述

Diffusion Model工作原理:

VAE和Diffusion的区别
在这里插入图片描述
先看整个训练过程:
在这里插入图片描述

实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。在这里插入图片描述
推断时刻:theat是带有参数的网络。
在这里插入图片描述

影像生成模型本质上的共同目标

通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
在这里插入图片描述
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。 在这里插入图片描述
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
在这里插入图片描述
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
在这里插入图片描述

在这里插入图片描述
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
在这里插入图片描述
两者对比
在这里插入图片描述
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
在这里插入图片描述
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
经过一番推导之后得到:
在这里插入图片描述
之后计算最下面三项:
在这里插入图片描述
通过以下推导:
在这里插入图片描述
之后通过X0,Xt可以得到Xt-1的分布。
在这里插入图片描述
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
在这里插入图片描述
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
在这里插入图片描述
经过推导:
在这里插入图片描述
最终得到下图:
在这里插入图片描述
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。

B站简单示例代码讲解

# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torchs_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0print("shape of s:",np.shape(s_curve))data = s_curve.Tfig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');ax.axis('off')dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)、确定扩散过程任意时刻的采样值#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):"""可以基于x[0]得到任意时刻t的x[t]"""noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):j = i//10k = i%10q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')axs[j,k].set_axis_off()axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')

在这里插入图片描述

# 5 编写拟合逆扩散过程高斯分布的模型import torch
import torch.nn as nn
​
class MLPDiffusion(nn.Module):def __init__(self,n_steps,num_units=128):super(MLPDiffusion,self).__init__()self.linears = nn.ModuleList([nn.Linear(2,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,2),])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),])def forward(self,x,t):
#         x = x_0for idx,embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2*idx](x)x += t_embeddingx = self.linears[2*idx+1](x)x = self.linears[-1](x)return x

loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
在这里插入图片描述

# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):"""对任意时刻t进行采样计算loss"""batch_size = x_0.shape[0]#对一个batchsize样本生成随机的时刻tt = torch.randint(0,n_steps,size=(batch_size//2,))t = torch.cat([t,n_steps-1-t],dim=0)t = t.unsqueeze(-1)#x0的系数a = alphas_bar_sqrt[t]#eps的系数aml = one_minus_alphas_bar_sqrt[t]#生成随机噪音epse = torch.randn_like(x_0)#构造模型的输入x = x_0*a+e*aml#送入模型,得到t时刻的随机噪声预测值output = model(x,t.squeeze(-1))#与真实噪声一起计算误差,求平均值return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""cur_x = torch.randn(shape)x_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)x_seq.append(cur_x)return x_seq
​
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):"""从x[T]采样t时刻的重构值"""t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]eps_theta = model(x,t)mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))z = torch.randn_like(x)sigma_t = betas[t].sqrt()sample = mean + sigma_t * zreturn (sample)
# 8、开始训练模型,打印loss及中间重构效果seed = 1234class EMA():"""构建一个参数平滑器"""def __init__(self,mu=0.01):self.mu = muself.shadow = {}def register(self,name,val):self.shadow[name] = val.clone()def __call__(self,name,x):assert name in self.shadownew_average = self.mu * x + (1.0-self.mu)*self.shadow[name]self.shadow[name] = new_average.clone()return new_averageprint('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
​
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)for t in range(num_epoch):for idx,batch_x in enumerate(dataloader):loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.)optimizer.step()if(t%100==0):print(loss)x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)fig,axs = plt.subplots(1,10,figsize=(28,3))for i in range(1,11):cur_x = x_seq[i*10].detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

最后的演示

动画演示扩散过程和逆扩散过程import io
from PIL import Image
​
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)
mg = Image.open(img_buf)reverse.append(img)
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)
​
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)

相关文章:

【扩散模型 李宏毅B站教学以及基础代码运用】

李宏毅教学视频: Link1 B站DDPM公式推导以及代码实现: Link2 这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。 文章目录 扩散模型概念:Diffusion Model工作原理:影像生成模型本质上的共同目标B站…...

SpringBoot隐藏文件

1.设置 2.输入file Types 3.点击忽略文件或者文件夹 4.成功...

常见数据库介绍对比之SQL关系型数据库

1.关系型数据库介绍 关系型数据库是一种基于关系模型的数据库,它使用表格来组织和存储数据。下面是一些常见的关系型数据库: 1.1. MySQL MySQL是一种开源的关系型数据库管理系统(RDBMS),广泛用于Web应用程序和企业级…...

OLED透明屏模块:引领未来显示技术的突破

OLED透明屏模块作为一项引领未来显示技术的突破,以其独特的特点和卓越的画质在市场上引起了广泛关注。 根据行业报告,预计到2025年,OLED透明屏模块将占据智能手机市场的20%份额,并在汽车导航系统市场中占据30%以上份额。 那么&am…...

Python_操作记录

1、Pandas读取数据文件(以文本文件作为示例),sep表示间隔,headerNone表示无标题行 df pd.read_table("data/youcans3.dat", sep"\t", headerNone) 2、线性规划问题求解 1)问题定义,…...

常用激活函数整理

最近一边应付工作,一边在补足人工智能的一些基础知识,这个方向虽然新兴,但已是卷帙浩繁,有时不知从何入手,幸亏有个适合基础薄弱的人士学习的网站,每天学习一点,积跬步以至千里吧。有像我一样学…...

uniapp 地图跳转到第三方导航软件 直接打包成apk

// 判断是否存在导航软件judgeHasExistNavignation() {let navAppParam [{pname: com.baidu.BaiduMap,action: baidumap://}, // 百度{pname: com.autonavi.minimap,action: iosamap://}, // 高德{pname: com.tencent.map,action: tencentmap://}, // 腾讯];return navAppPara…...

CentOS 8 通过YUM方式升级最新内核

CentOS 8 通过YUM方式升级最新内核 查看当前内核 uname -r 4.18.0-193.6.3.el8_2.x86_64导入 ELRepo 仓库的公钥: rpm --import https://www.elrepo.org/RPM-GPG-KEY-elrepo.org安装升级内核相关的yum源仓库(安装 ELRepo 仓库的 yum 源) yum install https://www…...

java 版本企业招标投标管理系统源码+功能描述+tbms+及时准确+全程电子化

功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查看所…...

Python爬虫数据存哪里|数据存储到文件的几种方式

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 爬虫请求解析后的数据,需要保存下来,才能进行下一步的处理,一般保存数据的方式有如下几种: 文件:txt、csv、excel、json等,保存数据量小。 关系型数据库…...

软件测试/测试开发丨Web自动化 测试用例流程设计

点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27173 一、测试用例通用结构回顾 1.1、现有测试用例存在的问题 可维护性差可读性差稳定性差 1.2、用例结构设计 测试用例的编排测试用例的项目结构 1…...

git撤销修改命令

要撤销Git中尚未提交的所有修改,可以使用以下几种方法: 1、使用git checkout命令丢弃工作目录的修改,重置工作目录中所有文件的修改。 git checkout . 2、使用git reset命令重置暂存区和工作目录, 重置暂存区和工作目录,回到最后一次提交后的状态。 …...

EOCR-AR电机保护器自动复位的启用条件说明

为适用不同的现场使用需求,施耐德韩国公司推出了带有自动复位功能的模拟型电动机保护器-EOCR-AR。EOCR-AR电机保护器具有过电流、缺相、堵转保护功能,还可根据实际需要设置自动复位时间。 EOCR-AR自动复位的设置方法 如上图,R-TIME旋钮是自动…...

Apache nginx解析漏洞复现

文章目录 空字节漏洞安装环境漏洞复现 背锅解析漏洞安装环境漏洞复现 空字节漏洞 安装环境 将nginx解压后放到c盘根目录下: 运行startup.bat启动环境: 在HTML文件夹下有它的主页文件: 漏洞复现 nginx在遇到后缀名有php的文件时,…...

.NET之后,再无大创新

回想起来,2001年发布的.NET已经是距离最近的一次软件开发技术的整体创新了,后续的新技术就没有在各个端都这么成功的了。.NET是Windows平台下软件开发技术的巨大变革。在此之前,有VB、C(MFC)、JSP,在此之后…...

【大麦小米学量化】什么是量化交易?哪些人适合做量化交易?

系列文章目录 文章目录 系列文章目录学霸的梦想前言一、什么是量化交易?二、哪些人适合做量化交易?三、量化交易都需要掌握哪些技术和方法?总结 学霸的梦想 小米支棱着迷糊的眼睛,一脸懵逼的问大麦:“我说大麦哥哥&…...

计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程。 要理解卷积神经网络中图像特征提取的全过程,我们可以将其比喻为人脑对视觉信息的处理过程。就像…...

el-table中点击跳转到详情页的两种方法

跳转的两种写法: 1.使用keep-alive使组件缓存,防止刷新时参数丢失 keep-alive 组件用于缓存和保持组件的状态,而不是路由参数。它可以在组件切换时保留组件的状态,从而避免重新渲染和加载数据。 keep-alive 主要用于提高页面性能和用户体验,而…...

RT-DETR个人整理向理解

一、前言 在开始介绍RT-DETR这个网络之前,我们首先需要先了解DETR这个系列的网络与我们常提及的anchor-base以及anchor-free存在着何种差异。 首先我们先简单讨论一下anchor-base以及anchor-free两者的差异与共性: 1、两者差异:顾名思义&…...

易点易动库存管理系统与ERP系统打通,帮助企业实现低值易耗品管理

现今,企业管理日趋复杂,无论是核心经营还是辅助环节,都需要依靠信息化手段来提升效率。而低值易耗品作为企业日常运营中的必需品,其管理也面临诸多挑战。传统做法效率低下,容易出错。如何通过信息化手段实现低值易耗品的高效管理,成为许多企业必顾及的一个课题。 易点易动作为…...

多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度​

一、引言:多云环境的技术复杂性本质​​ 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时,​​基础设施的技术债呈现指数级积累​​。网络连接、身份认证、成本管理这三大核心挑战相互嵌套:跨云网络构建数据…...

【WiFi帧结构】

文章目录 帧结构MAC头部管理帧 帧结构 Wi-Fi的帧分为三部分组成:MAC头部frame bodyFCS,其中MAC是固定格式的,frame body是可变长度。 MAC头部有frame control,duration,address1,address2,addre…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

P3 QT项目----记事本(3.8)

3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...

让AI看见世界:MCP协议与服务器的工作原理

让AI看见世界&#xff1a;MCP协议与服务器的工作原理 MCP&#xff08;Model Context Protocol&#xff09;是一种创新的通信协议&#xff0c;旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天&#xff0c;MCP正成为连接AI与现实世界的重要桥梁。…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

2023赣州旅游投资集团

单选题 1.“不登高山&#xff0c;不知天之高也&#xff1b;不临深溪&#xff0c;不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信

文章目录 Linux C语言网络编程详细入门教程&#xff1a;如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket&#xff08;服务端和客户端都要&#xff09;2. 绑定本地地址和端口&#x…...

基于 TAPD 进行项目管理

起因 自己写了个小工具&#xff0c;仓库用的Github。之前在用markdown进行需求管理&#xff0c;现在随着功能的增加&#xff0c;感觉有点难以管理了&#xff0c;所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD&#xff0c;需要提供一个企业名新建一个项目&#…...

return this;返回的是谁

一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请&#xff0c;不同级别的经理有不同的审批权限&#xff1a; // 抽象处理者&#xff1a;审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...