手撕扩散模型(一)| 训练部分——前向扩散,反向预测代码全解析
文章目录
- 1 直接使用 核心代码
- 2 工程代码实现
- 2.1 DDPM
- 2.2 训练
三大模型VAE,GAN, DIffusion扩散模型 是生成界的重要模型,但是最近一段时间扩散模型被用到的越来越多的,最近爆火的OpenAI的
Sora文生视频模型其实也是用了这种的方式,因而我打算系统回顾扩散系列知识,并注重代码的分析,感兴趣可以关注这一系列的博客,先介绍基础版本的,之后介绍扩散进阶的相关知识。
扩散模型很多的讲解上来会讲解很多的数学,会让人望而却步,但其实扩散在实际使用的时候并不复杂,我会先从代码的角度告诉大家怎么实操,再介绍数学推理
扩散要弄明白训练和推理两个过程~这节主要分析训练过程
1 直接使用 核心代码
基础版本的扩散核心就两句话
(1) DDPM前向扩散得到加噪后的图片
得到标记,对应一个核心公式**
(2) DDPM反向利用Unet网络预测加的噪声
实际上抽象一下,忽略细节,训练部分代码就主要以下部分
import torch
from torch import nn
n_steps=1000#假设我们最大的加噪步数是1000
x0=torch.ones(128,1,28,28) #模拟输入,1个batch有128张图片,通道数1,宽度高度为28
eta = torch.randn_like(x0) #生成初始随机噪声,形状和模拟输入一样
t= torch.randint(0, n_steps, (128,))#t是加噪时间,注意这里的t是随机生成的0到1000的128个随机数
noisy_imgs = ddpm(x0, t, eta) #前向加噪 输入原始输入图片和随机的t,得到128个加噪后的图像,扩散模型核心的第一句话
eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) #反向预测,给定图和t,得到预测噪声,扩散模型核心的第二句话
loss = nn.mse(eta_theta, eta) #计算噪声和实际的噪声之间的差异作为损失
optim.zero_grad()
loss.backward()
optim.step()
2 工程代码实现
当然上面是一个简略版本,实际中肯定要考虑较多的细节问题~
先来实现DDPM
2.1 DDPM
我们申明一个这样一个类MyDDPM
class MyDDPM(nn.Module):def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)):super(MyDDPM, self).__init__()self.n_steps = n_steps #扩散时间总步数 self.device = device self.image_chw = image_chw #image_chw 用于表示图像的通道数、高度和宽度。这里通道数1,宽度高度为28self.network = network.to(device)self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device) # beta预先算出来了self.alphas = 1 - self.betas #alphas也预先算出来self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device) #alphas_bars也预先算出来了 前i个乘积def forward(self, x0, t, eta=None):n, c, h, w = x0.shape #[批大小,通道数,图片高,图片宽]a_bar = self.alpha_bars[t] #t的大小和批大小相等if eta is None:eta = torch.randn(n, c, h, w).to(self.device)noisy_img = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * etareturn noisy_imgdef backward(self, x, t):# Run each image through the network for each timestep t in the vector t.# The network returns its estimation of the noise that was added.return self.network(x, t)
这段代码定义了一个名为MyDDPM的类,它是nn.Module的子类。
在MyDDPM类的构造函数__init__中,有以下几个重要的属性和操作:
n_steps:扩散时间总步数,表示模型在每个输入上进行的扩散步数。device:设备,表示模型在哪个设备上运行(如CPU或GPU)。image_chw:图像通道数、高度和宽度的元组,用于表示图像的形状。在这里,通道数为1,高度和宽度为28。network:神经网络模型,用于估计添加的噪声。betas:通过使用torch.linspace函数在min_beta和max_beta之间生成n_steps个均匀间隔的值,得到一个表示扩散系数的张量。alphas:通过将1减去betas得到的张量,表示衰减系数。alpha_bars:通过计算alphas的前i+1个元素的乘积,得到一个表示衰减系数累积乘积的张量。
MyDDPM类还定义了两个方法:
forward方法用于前向传播。它接受输入x0、时间步t和可选的噪声eta作为参数。在该方法中,首先获取输入x0的形状,并根据时间步t获取对应的衰减系数a_bar。如果未提供噪声eta,则使用torch.randn函数生成一个与输入形状相同的噪声张量。然后,根据衰减系数和噪声,计算得到带有噪声的图像张量,并返回该张量作为输出。backward方法用于反向传播。它接受输入x和时间步t作为参数,并通过调用network模型对每个时间步t的输入x进行处理,得到估计的添加噪声。最后,返回估计的噪声张量作为输出。
2.2 训练
有了DDPM我们就可以进行训练了(实际上这里的network我们先当做一个黑盒,在下一节讲解结构,network实现的效果就是输入某一时刻的t,和该时刻加噪后的图像,输出预测的噪声结果,该结果和前向生成的噪声做损失函数~优化参数)
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):mse = nn.MSELoss()best_loss = float("inf")n_steps = ddpm.n_stepsfor epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):# Loading datax0 = batch[0].to(device) #[128,1,1,28]n = len(x0)# Picking some noise for each of the images in the batch, a timestep and the respective alpha_barseta = torch.randn_like(x0).to(device)t = torch.randint(0, n_steps, (n,)).to(device) #注意这里的t是随机生成的# Computing the noisy image based on x0 and the time-step (forward process)noisy_imgs = ddpm(x0, t, eta) #经过前向过程 y一次得到一个批次的# Getting model estimation of noise based on the images and the time-stepeta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) loss = mse(eta_theta, eta) #预测噪声和给出的噪声之间的差异optim.zero_grad()loss.backward()optim.step()epoch_loss += loss.item() * len(x0) / len(loader.dataset)# Display images generated at this epochif display:show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"# Storing the modelif best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever (stored)"print(log_string)
-
函数定义:
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):- 这个函数接受多个参数:
ddpm是一个对象,loader是一个数据加载器,n_epochs是训练的轮数,optim是优化器,device是设备(如CPU或GPU),display是一个布尔值,用于控制是否显示生成的图像,store_path是模型存储的路径。 - 函数没有返回值。
- 这个函数接受多个参数:
-
导入模块:
mse = nn.MSELoss()- 这里导入了
nn模块,并创建了一个MSELoss的实例对象mse。
- 这里导入了
-
初始化变量:
best_loss = float("inf") n_steps = ddpm.n_stepsbest_loss被初始化为正无穷大,用于跟踪最佳损失值。n_steps从ddpm对象中获取,表示模型的步数。
-
训练循环:
for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):# Loading datax0 = batch[0].to(device) #[128,1,1,28]n = len(x0)...- 外部循环是训练的轮数,使用
range(n_epochs)生成一个迭代器,并使用tqdm函数包装,以显示训练进度条。 - 内部循环是对数据加载器中的批次进行迭代,使用
enumerate函数包装,并使用tqdm函数包装,以显示每个批次的进度条。 - 在每个批次中,首先从批次中加载数据,并将其移动到指定的设备上。
x0是批次中的第一个元素,表示输入数据。n是批次的大小。
- 外部循环是训练的轮数,使用
-
数据处理和模型训练:
eta = torch.randn_like(x0).to(device) t = torch.randint(0, n_steps, (n,)).to(device) noisy_imgs = ddpm(x0, t, eta) eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) loss = mse(eta_theta, eta) optim.zero_grad() loss.backward() optim.step()eta是一个与x0形状相同的随机张量,用于添加噪声。t是一个随机生成的整数张量,表示时间步骤。noisy_imgs是通过将x0和t作为输入,使用ddpm对象进行前向传播得到的噪声图像。eta_theta是通过将noisy_imgs和t进行反向传播,使用ddpm对象得到的噪声估计。loss是通过计算eta_theta和eta之间的均方误差(MSE)得到的损失。optim.zero_grad()用于清除优化器的梯度。loss.backward()用于计算损失相对于模型参数的梯度。optim.step()用于更新模型参数。
-
显示生成的图像和存储模型:
if display:show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}") ... if best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever (stored)" ... print(log_string)- 如果
display为True,则调用show_images函数显示生成的图像。 generate_new_images函数用于生成新的图像样本。- 如果当前轮的损失比之前的最佳损失更低,则将模型参数保存到指定的路径。
- 最后,打印训练日志字符串。
- 如果
相关文章:
手撕扩散模型(一)| 训练部分——前向扩散,反向预测代码全解析
文章目录 1 直接使用 核心代码2 工程代码实现2.1 DDPM2.2 训练 三大模型VAE,GAN, DIffusion扩散模型 是生成界的重要模型,但是最近一段时间扩散模型被用到的越来越多的,最近爆火的OpenAI的 Sora文生视频模型其实也是用了这种的方…...
linux 防火墙
防火墙分类 按保护范围划分 主机防火墙:服务服务为当前一台主机 网络防火墙:服务服务为防火墙一侧的局域网 按实现方式分类划分 硬件防火墙:在专用硬件级别实现部分功能的防火墙;另一部分基于软件的实现 如:华为&#…...
Go应用性能分析实战
Go很适合用来开发高性能网络应用,但仍然需要借助有效的工具进行性能分析,优化代码逻辑。本文介绍了如何通过go test benchmark和pprof进行性能分析,从而实现最优的代码效能。原文: Profiling Go Applications in the Right Way with Examples…...
MySQL的索引类型
目录 1. 主键索引 (PRIMARY KEY) 2. 唯一索引 (UNIQUE) 3. 普通索引 (INDEX) 4. 全文索引 (FULLTEXT) 5. 空间索引 (SPATIAL) 6. 组合索引 (COMPOSITE INDEX) 7. 前缀索引 (PREFIX INDEX) 8. 覆盖索引 (COVERING INDEX) 1. 主键索引 (PRIMARY KEY) 描述:表…...
picker选择器-年月日选择
从底部弹起的滚动选择器。支持五种选择器,通过mode来区分,分别是普通选择器,多列选择器,时间选择器,日期选择器,省市区选择器,默认是普通选择器。 学习一下日期选择器 平台差异说明 日期选择默…...
【LeetCode-494】目标和(回溯动归)
目录 LeetCode494.目标和 题目描述 解法1:回溯法 代码实现 解法2:动态规划 代码实现 LeetCode494.目标和 题目链接 题目描述 给定一个非负整数数组,a1, a2, ..., an, 和一个目标数,S。现在你有两个符号 和 -。对于数组中…...
力扣 188. 买卖股票的最佳时机 IV
题目来源:https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-iv/description/ C题解:动态规划 思路同力扣 123. 买卖股票的最佳时机 III-CSDN博客,只是把最高2次换成k次。如果思路不清晰,可以将k从0写到4等找找规律…...
【Go语言】Go项目工程管理
GO 项目工程管理(Go Modules) Go 1.11 版本开始,官方提供了 Go Modules 进行项目管理,Go 1.13开始,Go项目默认使用 Go Modules 进行项目管理。 使用 Go Modules的好处是不再需要依赖 GOPATH,可以在任意位…...
美容小程序:让预约更简单,服务更贴心
在当今繁忙的生活节奏中,美容预约常常令人感到繁琐和疲惫。为了解决这个问题,许多美容院和SPA中心已经开始采用美容小程序来简化预约流程,并提供更加贴心的服务。在这篇文章中,我们将引导您了解如何制作一个美容小程序,…...
【递归】:原理、应用与案例解析 ,助你深入理解递归核心思想
递归 1.基础简介 递归在计算机科学中,递归是一种解决计算问题的方法,其中解决方案取决于同一类问题的更小子集 例如 递归遍历环形链表 基本情况(Base Case):基本情况是递归函数中最简单的情况,它们通常是递…...
【 Maven 】花式玩法之多模块项目
目录 一、认识Maven多模块项目 二、maven如何定义项目的发布策略 2.1 版本管理 2.2 构建配置 2.3 部署和发布 2.4 依赖管理 2.5 发布流程 三、使用Jenkins持续集成Maven项目 四、总结 如果你有一个多模块项目,并且想将这些模块发布到不同的仓库或目标位置&…...
LeetCode 热题 100 Day01
哈希模块 哈希结构: 哈希结构,即hash table,哈希表|散列表结构。 图摘自《代码随想录》 哈希表本质上表示的元素和索引的一种映射关系。 若查找某个数组中第n个元素,有两种方法: 1.从头遍历,复杂度…...
[vscode]vue js部分结尾加分号
设置中寻找 semicolons确定在TypeScript的这个扩展中设置选项为insert...
友点CMS image_upload.php 文件上传漏洞复现
0x01 产品简介 友点CMS是一款高效且灵活的网站管理系统,它为用户提供了简单易用的界面和丰富的功能。无论是企业还是个人,都能通过友点CMS快速搭建出专业且美观的网站。该系统支持多种内容类型和自定义模板,方便用户按需调整。同时,它具备强大的SEO功能,能提升网站在搜索…...
C语言—指针(3)
嘿嘿嘿嘿,你看我像指针吗? 不会写,等我啥时候会写了再说吧,真的累了,倦了 1.面试题 1)定义整形变量i; 2)p为指向整形变量的指针变量; 3)定…...
【八股文】面向对象基础
【八股文】面向对象基础 面向对象和面向过程的区别 面向过程把解决问题的过程拆成一个个方法,通过一个个方法的执行解决问题。面向对象会先抽象出对象,然后用对象执行方法的方式解决问题。 创建一个对象用什么运算符?对象实体与对象引用有何不同? …...
Day49 647 回文子串 516 最长回文子序列
647 回文子串 给定一个字符串,你的任务是计算这个字符串中有多少个回文子串。 具有不同开始位置或结束位置的子串,即使是由相同的字符组成,也会被视作不同的子串。 方法一:动态规划: 采用一个二维的dp数组…...
探秘GNU/Linux Shell:命令行的魔法世界
GNU/Linux的Shell是一种特殊的交互式工具,为用户提供了强大的控制和管理Linux系统的方式。在这个博客中,我们将深入了解Shell的基本概念、功能以及不同类型的Shell。 Shell的本质 Shell的核心是命令行提示符,它是用户与Linux系统进行交互的…...
基于STM32F407的coreJSON使用教程
目录 概述 工程建立 代码集成 函数介绍 使用示例 概述 coreJSON是FreeRTOS中的一个组件库,支持key查找的解析器,他只是一个解析器,不能生成json数据。同时严格执行 ECMA-404 JSON 标准。该库用 C 语言编写,设计符合 ISO C90…...
keepalived双主模式测试
文章目录 环境准备部署安装keepavlived配置启动测试模拟Nginx宕机重新启动问题分析 环境准备 测试一下keepalived的双主模式,所谓双主模式就是两个keepavlied节点各持有一个/组虚IP,默认情况下,二者互为主备,同时对外提供服务&am…...
Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
HBuilderX安装(uni-app和小程序开发)
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...
算法岗面试经验分享-大模型篇
文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer (1)资源 论文&a…...
在Ubuntu24上采用Wine打开SourceInsight
1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...
浪潮交换机配置track检测实现高速公路收费网络主备切换NQA
浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...
IP如何挑?2025年海外专线IP如何购买?
你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...
uniapp手机号一键登录保姆级教程(包含前端和后端)
目录 前置条件创建uniapp项目并关联uniClound云空间开启一键登录模块并开通一键登录服务编写云函数并上传部署获取手机号流程(第一种) 前端直接调用云函数获取手机号(第三种)后台调用云函数获取手机号 错误码常见问题 前置条件 手机安装有sim卡手机开启…...
日常一水C
多态 言简意赅:就是一个对象面对同一事件时做出的不同反应 而之前的继承中说过,当子类和父类的函数名相同时,会隐藏父类的同名函数转而调用子类的同名函数,如果要调用父类的同名函数,那么就需要对父类进行引用&#…...
