【扩散模型 李宏毅B站教学以及基础代码运用】
李宏毅教学视频:
Link1
B站DDPM公式推导以及代码实现:
Link2
这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。
文章目录
- 扩散模型概念:
- Diffusion Model工作原理:
- 影像生成模型本质上的共同目标
- B站简单示例代码讲解
扩散模型概念:
就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
Diffusion Model工作原理:
VAE和Diffusion的区别
先看整个训练过程:
实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。
推断时刻:theat是带有参数的网络。
影像生成模型本质上的共同目标
通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
两者对比
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
经过一番推导之后得到:
之后计算最下面三项:
通过以下推导:
之后通过X0,Xt可以得到Xt-1的分布。
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
经过推导:
最终得到下图:
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。
B站简单示例代码讲解
# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torchs_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0print("shape of s:",np.shape(s_curve))data = s_curve.Tfig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');ax.axis('off')dataset = torch.Tensor(s_curve).float()
# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5
#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)、确定扩散过程任意时刻的采样值#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):"""可以基于x[0]得到任意时刻t的x[t]"""noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')
#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):j = i//10k = i%10q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')axs[j,k].set_axis_off()axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')
# 5 编写拟合逆扩散过程高斯分布的模型import torch
import torch.nn as nn
class MLPDiffusion(nn.Module):def __init__(self,n_steps,num_units=128):super(MLPDiffusion,self).__init__()self.linears = nn.ModuleList([nn.Linear(2,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,num_units),nn.ReLU(),nn.Linear(num_units,2),])self.step_embeddings = nn.ModuleList([nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),nn.Embedding(n_steps,num_units),])def forward(self,x,t):
# x = x_0for idx,embedding_layer in enumerate(self.step_embeddings):t_embedding = embedding_layer(t)x = self.linears[2*idx](x)x += t_embeddingx = self.linears[2*idx+1](x)x = self.linears[-1](x)return x
loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):"""对任意时刻t进行采样计算loss"""batch_size = x_0.shape[0]#对一个batchsize样本生成随机的时刻tt = torch.randint(0,n_steps,size=(batch_size//2,))t = torch.cat([t,n_steps-1-t],dim=0)t = t.unsqueeze(-1)#x0的系数a = alphas_bar_sqrt[t]#eps的系数aml = one_minus_alphas_bar_sqrt[t]#生成随机噪音epse = torch.randn_like(x_0)#构造模型的输入x = x_0*a+e*aml#送入模型,得到t时刻的随机噪声预测值output = model(x,t.squeeze(-1))#与真实噪声一起计算误差,求平均值return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""cur_x = torch.randn(shape)x_seq = [cur_x]for i in reversed(range(n_steps)):cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)x_seq.append(cur_x)return x_seq
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):"""从x[T]采样t时刻的重构值"""t = torch.tensor([t])coeff = betas[t] / one_minus_alphas_bar_sqrt[t]eps_theta = model(x,t)mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))z = torch.randn_like(x)sigma_t = betas[t].sqrt()sample = mean + sigma_t * zreturn (sample)
# 8、开始训练模型,打印loss及中间重构效果seed = 1234
class EMA():"""构建一个参数平滑器"""def __init__(self,mu=0.01):self.mu = muself.shadow = {}def register(self,name,val):self.shadow[name] = val.clone()def __call__(self,name,x):assert name in self.shadownew_average = self.mu * x + (1.0-self.mu)*self.shadow[name]self.shadow[name] = new_average.clone()return new_averageprint('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
for t in range(num_epoch):for idx,batch_x in enumerate(dataloader):loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.)optimizer.step()if(t%100==0):print(loss)x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)fig,axs = plt.subplots(1,10,figsize=(28,3))for i in range(1,11):cur_x = x_seq[i*10].detach()axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');axs[i-1].set_axis_off();axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')
最后的演示
动画演示扩散过程和逆扩散过程import io
from PIL import Image
imgs = []
for i in range(100):plt.clf()q_i = q_x(dataset,torch.tensor([i]))plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);plt.axis('off');img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)imgs.append(img)
mg = Image.open(img_buf)reverse.append(img)
reverse = []
for i in range(100):plt.clf()cur_x = x_seq[i].detach()plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);plt.axis('off')img_buf = io.BytesIO()plt.savefig(img_buf,format='png')img = Image.open(img_buf)reverse.append(img)
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)
相关文章:

【扩散模型 李宏毅B站教学以及基础代码运用】
李宏毅教学视频: Link1 B站DDPM公式推导以及代码实现: Link2 这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。 文章目录 扩散模型概念:Diffusion Model工作原理:影像生成模型本质上的共同目标B站…...

SpringBoot隐藏文件
1.设置 2.输入file Types 3.点击忽略文件或者文件夹 4.成功...
常见数据库介绍对比之SQL关系型数据库
1.关系型数据库介绍 关系型数据库是一种基于关系模型的数据库,它使用表格来组织和存储数据。下面是一些常见的关系型数据库: 1.1. MySQL MySQL是一种开源的关系型数据库管理系统(RDBMS),广泛用于Web应用程序和企业级…...

OLED透明屏模块:引领未来显示技术的突破
OLED透明屏模块作为一项引领未来显示技术的突破,以其独特的特点和卓越的画质在市场上引起了广泛关注。 根据行业报告,预计到2025年,OLED透明屏模块将占据智能手机市场的20%份额,并在汽车导航系统市场中占据30%以上份额。 那么&am…...
Python_操作记录
1、Pandas读取数据文件(以文本文件作为示例),sep表示间隔,headerNone表示无标题行 df pd.read_table("data/youcans3.dat", sep"\t", headerNone) 2、线性规划问题求解 1)问题定义,…...

常用激活函数整理
最近一边应付工作,一边在补足人工智能的一些基础知识,这个方向虽然新兴,但已是卷帙浩繁,有时不知从何入手,幸亏有个适合基础薄弱的人士学习的网站,每天学习一点,积跬步以至千里吧。有像我一样学…...
uniapp 地图跳转到第三方导航软件 直接打包成apk
// 判断是否存在导航软件judgeHasExistNavignation() {let navAppParam [{pname: com.baidu.BaiduMap,action: baidumap://}, // 百度{pname: com.autonavi.minimap,action: iosamap://}, // 高德{pname: com.tencent.map,action: tencentmap://}, // 腾讯];return navAppPara…...

CentOS 8 通过YUM方式升级最新内核
CentOS 8 通过YUM方式升级最新内核 查看当前内核 uname -r 4.18.0-193.6.3.el8_2.x86_64导入 ELRepo 仓库的公钥: rpm --import https://www.elrepo.org/RPM-GPG-KEY-elrepo.org安装升级内核相关的yum源仓库(安装 ELRepo 仓库的 yum 源) yum install https://www…...

java 版本企业招标投标管理系统源码+功能描述+tbms+及时准确+全程电子化
功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查看所…...

Python爬虫数据存哪里|数据存储到文件的几种方式
前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 爬虫请求解析后的数据,需要保存下来,才能进行下一步的处理,一般保存数据的方式有如下几种: 文件:txt、csv、excel、json等,保存数据量小。 关系型数据库…...

软件测试/测试开发丨Web自动化 测试用例流程设计
点此获取更多相关资料 本文为霍格沃兹测试开发学社学员学习笔记分享 原文链接:https://ceshiren.com/t/topic/27173 一、测试用例通用结构回顾 1.1、现有测试用例存在的问题 可维护性差可读性差稳定性差 1.2、用例结构设计 测试用例的编排测试用例的项目结构 1…...
git撤销修改命令
要撤销Git中尚未提交的所有修改,可以使用以下几种方法: 1、使用git checkout命令丢弃工作目录的修改,重置工作目录中所有文件的修改。 git checkout . 2、使用git reset命令重置暂存区和工作目录, 重置暂存区和工作目录,回到最后一次提交后的状态。 …...

EOCR-AR电机保护器自动复位的启用条件说明
为适用不同的现场使用需求,施耐德韩国公司推出了带有自动复位功能的模拟型电动机保护器-EOCR-AR。EOCR-AR电机保护器具有过电流、缺相、堵转保护功能,还可根据实际需要设置自动复位时间。 EOCR-AR自动复位的设置方法 如上图,R-TIME旋钮是自动…...

Apache nginx解析漏洞复现
文章目录 空字节漏洞安装环境漏洞复现 背锅解析漏洞安装环境漏洞复现 空字节漏洞 安装环境 将nginx解压后放到c盘根目录下: 运行startup.bat启动环境: 在HTML文件夹下有它的主页文件: 漏洞复现 nginx在遇到后缀名有php的文件时,…...

.NET之后,再无大创新
回想起来,2001年发布的.NET已经是距离最近的一次软件开发技术的整体创新了,后续的新技术就没有在各个端都这么成功的了。.NET是Windows平台下软件开发技术的巨大变革。在此之前,有VB、C(MFC)、JSP,在此之后…...
【大麦小米学量化】什么是量化交易?哪些人适合做量化交易?
系列文章目录 文章目录 系列文章目录学霸的梦想前言一、什么是量化交易?二、哪些人适合做量化交易?三、量化交易都需要掌握哪些技术和方法?总结 学霸的梦想 小米支棱着迷糊的眼睛,一脸懵逼的问大麦:“我说大麦哥哥&…...

计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程
大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用12-卷积神经网络中图像特征提取的可视化研究,让大家理解特征提取的全过程。 要理解卷积神经网络中图像特征提取的全过程,我们可以将其比喻为人脑对视觉信息的处理过程。就像…...

el-table中点击跳转到详情页的两种方法
跳转的两种写法: 1.使用keep-alive使组件缓存,防止刷新时参数丢失 keep-alive 组件用于缓存和保持组件的状态,而不是路由参数。它可以在组件切换时保留组件的状态,从而避免重新渲染和加载数据。 keep-alive 主要用于提高页面性能和用户体验,而…...

RT-DETR个人整理向理解
一、前言 在开始介绍RT-DETR这个网络之前,我们首先需要先了解DETR这个系列的网络与我们常提及的anchor-base以及anchor-free存在着何种差异。 首先我们先简单讨论一下anchor-base以及anchor-free两者的差异与共性: 1、两者差异:顾名思义&…...

易点易动库存管理系统与ERP系统打通,帮助企业实现低值易耗品管理
现今,企业管理日趋复杂,无论是核心经营还是辅助环节,都需要依靠信息化手段来提升效率。而低值易耗品作为企业日常运营中的必需品,其管理也面临诸多挑战。传统做法效率低下,容易出错。如何通过信息化手段实现低值易耗品的高效管理,成为许多企业必顾及的一个课题。 易点易动作为…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

遍历 Map 类型集合的方法汇总
1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

Mac下Android Studio扫描根目录卡死问题记录
环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中,提示一个依赖外部头文件的cpp源文件需要同步,点…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...
深入理解Optional:处理空指针异常
1. 使用Optional处理可能为空的集合 在Java开发中,集合判空是一个常见但容易出错的场景。传统方式虽然可行,但存在一些潜在问题: // 传统判空方式 if (!CollectionUtils.isEmpty(userInfoList)) {for (UserInfo userInfo : userInfoList) {…...
「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案
在移动互联网营销竞争白热化的当下,推客小程序系统凭借其裂变传播、精准营销等特性,成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径,助力开发者打造具有市场竞争力的营销工具。 一、系统核心功能架构&…...
go 里面的指针
指针 在 Go 中,指针(pointer)是一个变量的内存地址,就像 C 语言那样: a : 10 p : &a // p 是一个指向 a 的指针 fmt.Println(*p) // 输出 10,通过指针解引用• &a 表示获取变量 a 的地址 p 表示…...

Linux部署私有文件管理系统MinIO
最近需要用到一个文件管理服务,但是又不想花钱,所以就想着自己搭建一个,刚好我们用的一个开源框架已经集成了MinIO,所以就选了这个 我这边对文件服务性能要求不是太高,单机版就可以 安装非常简单,几个命令就…...