扩散模型实战(八):微调扩散模型
推荐阅读列表:
扩散模型实战(一):基本原理介绍
扩散模型实战(二):扩散模型的发展
扩散模型实战(三):扩散模型的应用
扩散模型实战(四):从零构建扩散模型
扩散模型实战(五):采样过程
扩散模型实战(六):Diffusers DDPM初探
扩散模型实战(七):Diffusers蝴蝶图像生成实战
微调在LLM中并不是新鲜的概念,从头开始训练一个扩散模型需要很长的时间,特别是使用高分辨率图像训练。那么其实我们可以在已经训练好的”去噪“扩散模型基础上使用微调数据集进行二次微调训练。
本文将介绍基于蝴蝶数据集上微调人脸生成的扩散模型:
一、环境准备
1.1 安装相关库
!pip install -qq diffusers datasets accelerate wandb open-clip-torch
1.2 登录Huggingface Hub
如果需要开源微调好的模型到Huggingface Hub上,那么需要使用如下代码登录,否则可忽略此步骤:
from huggingface_hub import notebook_loginnotebook_login()
1.3 导入相关库
import numpy as npimport torchimport torch.nn.functional as Fimport torchvisionfrom datasets import load_datasetfrom diffusers import DDIMScheduler, DDPMPipelinefrom matplotlib import pyplot as pltfrom PIL import Imagefrom torchvision import transformsfrom tqdm.auto import tqdmdevice = ("mps"if torch.backends.mps.is_available()else "cuda"if torch.cuda.is_available()else "cpu")
二、导入预训练的扩散模型
下面我们导入人脸生成的扩散模型,观察一下生成的效果,代码如下:
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")image_pipe.to(device);
查看生成的图像,代码如下:
images = image_pipe().imagesimages[0]

生成的效果虽然不错,但是速度稍微有点慢,其实有更快的采样器可以加速这一过程,比如下面介绍的DDIM
三、DDIM-更快的采样器
在生成图像的每一步中,模型都会接收一个带有噪声的输入,并且需要预测这个噪声,以此来估计没有噪声的完整图像是什么。这个过程被称为采样过程,在Diffusers库中,采样通过调度器控制的,之前的文章中介绍过DDPMScheduler调度器,本文介绍的DDIMScheduler可以通过更少的迭代周期来产生很好的采样样本(1000多步采样不是必须的)。
# 创建一个新的调度器并设置推理迭代次数scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
# 输出tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725,700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425,400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150,125, 100, 75, 50, 25, 0])
下面使用4幅随机噪声图像进行循环采样,并观察每一步的输入与输出的”去噪“图像,代码如下:
# 从随机噪声开始x = torch.randn(4, 3, 256, 256).to(device)# batch size为4,三通道,长、宽均为256像素的一组图像# 循环一整套时间步for i, t in tqdm(enumerate(scheduler.timesteps)):# 准备模型输入:给“带躁”图像加上时间步信息model_input = scheduler.scale_model_input(x, t)# 预测噪声with torch.no_grad():noise_pred = image_pipe.unet(model_input, t)["sample"]# 使用调度器计算更新后的样本应该是什么样子scheduler_output = scheduler.step(noise_pred, t, x)# 更新输入图像x = scheduler_output.prev_sample# 时不时看一下输入图像和预测的“去噪”图像if i % 10 == 0 or i == len(scheduler.timesteps) - 1:fig, axs = plt.subplots(1, 2, figsize=(12, 5))grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)axs[0].set_title(f"Current x (step {i})")pred_x0 = (scheduler_output.pred_original_sample)grid = torchvision.utils.make_grid(pred_x0, nrow=4).permute(1, 2, 0)axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)axs[1].set_title(f"Predicted denoised images (step {i})")plt.show()





第二步生成图像的采样器是DDPMScheduler,我们可以使用新的DDIMScheduler来代替DDPMScheduler看看image_pipe生成的效果是否有提升,代码如下:
image_pipe.scheduler = schedulerimages = image_pipe(num_inference_steps=40).imagesimages[0]

上述介绍了生成人脸的扩散模型以及生成的效果,也介绍了更快的采样器DDIMScheduler,下面我们使用蝴蝶数据集来微调人脸生成扩散模型:
四、微调人脸生成扩散模型
4.1 加载蝴蝶数据集
dataset_name = "huggan/smithsonian_butterflies_subset"dataset = load_dataset(dataset_name, split="train")image_size = 256batch_size = 4preprocess = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])def transform(examples):images = [preprocess(image.convert("RGB")) for image inexamples["image"]]return {"images": images}dataset.set_transform(transform)train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
输出4幅蝴蝶图像,便于观察
print("Previewing batch:")batch = next(iter(train_dataloader))grid = torchvision.utils.make_grid(batch["images"], nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

4.2 微调人脸生成扩散模型
num_epochs = 2lr = 1e-5grad_accumulation_steps = 2optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr)losses = []for epoch in range(num_epochs):for step, batch in tqdm(enumerate(train_dataloader),total=len(train_dataloader)):clean_images = batch["images"].to(device)# 随机生成一个噪声,稍后加到图像上noise = torch.randn(clean_images.shape).to(clean_images.device)bs = clean_images.shape[0]# 随机选取一个时间步timesteps = torch.randint(0,image_pipe.scheduler.num_train_timesteps,(bs,),device=clean_images.device,).long()# 根据选中的时间步和确定的幅值,在干净图像上添加噪声# 此处为前向扩散过程noisy_images = image_pipe.scheduler.add_noise(clean_images,noise, timesteps)# 使用“带噪”图像进行网络预测noise_pred = image_pipe.unet(noisy_images, timesteps,return_dict=False)[0]# 对真正的噪声和预测的结果进行比较,注意这里是预测噪声loss = F.mse_loss(noise_pred, noise)# 保存损失值losses.append(loss.item())# 根据损失值更新梯度loss.backward()# 进行梯度累积,在累积到一定步数后更新模型参数if (step + 1) % grad_accumulation_steps == 0:optimizer.step()optimizer.zero_grad()print(f"Epoch {epoch} average loss: {sum(losses[-len(train_dataloader):])/len(train_dataloader)}")# 画出损失曲线,效果如图所示plt.plot(losses)

4.3 使用微调好的模型生成图像
x = torch.randn(8, 3, 256, 256).to(device)for i, t in tqdm(enumerate(scheduler.timesteps)):model_input = scheduler.scale_model_input(x, t)with torch.no_grad():noise_pred = image_pipe.unet(model_input, t)["sample"]x = scheduler.step(noise_pred, t, x).prev_samplegrid = torchvision.utils.make_grid(x, nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

从图中可以看出生成的图像有蝴蝶数据的风格。
4.4 保持微调好的扩散模型,并且上传到Huggingface Hub中
image_pipe.save_pretrained("my-finetuned-model")
from huggingface_hub import HfApi, ModelCard, create_repo, get_full_repo_name# 配置Hugging Face Hub,上传文件model_name = "ddpm-celebahq-finetuned-butterflies-2epochs"# 使用@param 脚本程序对上传到# Hugging Face Hub的文件进行命名local_folder_name = "my-finetuned-model" # @param脚本程序生成的名字,# 你也可以通过 image_pipe.save_pretrained('savename')自行指定description = "Describe your model here" # @paramhub_model_id = get_full_repo_name(model_name)create_repo(hub_model_id)api = HfApi()api.upload_folder(folder_path=f"{local_folder_name}/scheduler",path_in_repo="",repo_id=hub_model_id )api.upload_folder(folder_path=f"{local_folder_name}/unet", path_in_repo="",repo_id=hub_model_id )api.upload_file(path_or_fileobj=f"{local_folder_name}/model_index.json",path_in_repo="model_index.json",repo_id=hub_model_id,)# 添加一个模型卡片,这一步虽然不是必需的,但可以给他人提供一些模型描述信息content = f"""---license: mittags:- pytorch- diffusers- unconditional-image-generation- diffusion-models-class---# 用法from diffusers import DDPMPipelinepipeline = DDPMPipeline.from_pretrained(' {hub_model_id}')image = pipeline().images[0]image'''"""card = ModelCard(content)card.push_to_hub(hub_model_id)
微调Trick:
- 设置合适的batch_size,batch_size要在不超过GPU显存的前提下,尽量大一些,这样可以提高GPU计算效果;如果特别小,可以采用梯度累积的方式来更新模型参数,达到和大batch_size类似的效果,也就是多运行几次loss.backward(),再调用optimizer.step()和optimizer.zero_grad();
- 训练过程中,要时不时生成一些图像样本来观察模型性能;
- 训练过程中,可以把损失值、生成的图像样本等信息记录在日志中,可以使用Weights and Biases、TensorBoard等工具;
相关文章:
扩散模型实战(八):微调扩散模型
推荐阅读列表: 扩散模型实战(一):基本原理介绍 扩散模型实战(二):扩散模型的发展 扩散模型实战(三):扩散模型的应用 扩散模型实战(四…...
Android 全局控件属性设置
一 使用需求: 如 设置全局字体、全局文本属性设置 二 实现方式: 在App使用的主题中,添加属性及属性值 如给所有的文本设置属性,注释部分作用是设置应用全局字体 <style name"Theme.AppDemo" parent"Base.Theme.AppDemo&q…...
下面是实践百度飞桨上面的pm2.5分类项目_logistic regression相关
part1:数据的引入,和前一个linear regression基本是一样 part2:数据解析——也就是数据的“规格化” 首先,打算用dataMat[]和labelMat[]数据存储feature和label,并且文件变量fr 然后,是这个for line in fr.readlines()循环&#…...
阿里云误删Python后域yum报错解决方案
阿里云误删Python后域yum报错解决方案 1:找回所有依赖 这里依赖可能很多,也搞不清楚有哪些,建议买一台临时服务器,系统选择跟你当前的系统一致的,配置选最低就行 2:登录临时服务器,创建临时文件夹 mkdir /usr/local/yum-fix cd /usr/local/yum-fix3:查找并下载所有云依赖 r…...
unordered-------Hash
✅<1>主页:我的代码爱吃辣📃<2>知识讲解:数据结构——哈希表☂️<3>开发环境:Visual Studio 2022💬<4>前言:哈希是一种映射的思想,哈希表即使利用这种思想,…...
数据仓库总结
1.为什么要做数仓建模 数据仓库建模的目标是通过建模的方法更好的组织、存储数据,以便在性能、成本、效率和数据质量之间找到最佳平衡点。 当有了适合业务和基础数据存储环境的模型(良好的数据模型),那么大数据就能获得以下好处&…...
hadoop学习:mapreduce入门案例二:统计学生成绩
这里相较于 wordcount,新的知识点在于学生实体类的编写以及使用 数据信息: 1. Student 实体类 import org.apache.hadoop.io.WritableComparable;import java.io.DataInput; import java.io.DataOutput; import java.io.IOException;public class Stude…...
自学TypeScript-基础、编译、类型
自学TypeScript-基础、编译、类型 TS 编译为 JS类型支持类型注解基础类型typeof 运算符高级类型class 类构造函数和实例方法继承可见性只读 类型兼容性交叉类型泛型泛型约束多个泛型泛型接口泛型类泛型工具 索引签名类型映射类型索引查询(访问)类型 类型声明文件 TypeScript 是…...
nginx配置https
1.安装nginx 安装完成后检查 nginx -V2.申请证书与上传 阿里云申请免费的证书 然后上传到某个目录 3.修改nginx配置 #user nobody; worker_processes 1;#error_log logs/error.log; #error_log logs/error.log notice; #error_log logs/error.log info;#pid …...
windows Etcd的安装与使用
一、简介 etcd是一个分布式一致性键值存储,其主要用于分布式系统的共享配置和服务发现。 etcd由Go语言编写 二、下载并安装 1.下载地址: https://github.com/coreos/etcd/releases 解压后的目录如下:其中etcd.exe是服务端,e…...
【py】为什么用 import tkinter 不能运行
为什么用 import tkinter 不能运行 ━━━━━━━━━━━━━━━━━━━━━━ 要显示一个信息框,为什么用 import tkinter 不能运行,改成from tkinter import messagebox 就可以运行了? 可能是因为您的代码中只使用了 messagebox 这个模…...
【深度学习】实验04 交叉验证
文章目录 交叉验证划分自定义划分K折交叉验证留一交叉验证留p交叉验证随机排列交叉验证分层K折交叉验证分层随机交叉验证 分割组 k-fold分割留一组分割留 P 组分割随机分割时间序列分割 交叉验证 # 导入相关库# 交叉验证所需函数 from sklearn.model_selection import train_t…...
whisper语音识别部署及WER评价
1.whisper部署 详细过程可以参照:🏠 创建项目文件夹 mkdir whisper cd whisper conda创建虚拟环境 conda create -n py310 python3.10 -c conda-forge -y 安装pytorch pip install --pre torch torchvision torchaudio --extra-index-url 下载whisper p…...
java太卷了,怎么办?
忧虑: 马上就到30岁了,最近对于自己职业生涯的规划甚是焦虑。在网站论坛上,可谓是哀鸿遍野,大家纷纷叙述着自己被裁后求职的艰辛路程,这更加加深了我的忧虑,于是在各大论坛开始“求医问药”,想…...
android多屏触摸相关的详解方案-安卓framework开发手机车载车机系统开发课程
背景 直播免费视频课程地址:https://www.bilibili.com/video/BV1hN4y1R7t2/ 在做双屏相关需求开发过程中,经常会有对两个屏幕都要求可以正确触摸的场景。但是目前我们模拟器默认创建的双屏其实是没有办法进行触摸的 修改方案1 静态修改方案 使用命令…...
微信小程序 实时日志
目录 实时日志 背景 如何使用 如何查看日志 注意事项 实时日志 背景 为帮助小程序开发者快捷地排查小程序漏洞、定位问题,我们推出了实时日志功能。从基础库2.7.1开始,开发者可通过提供的接口打印日志,日志汇聚并实时上报到小程序后台…...
Spring AOP基于注解方式实现和细节
目录 一、Spring AOP底层技术 二、初步实现AOP编程 三、获取切点详细信息 四、 切点表达式语法 五、重用(提取)切点表达式 一、Spring AOP底层技术 SpringAop的核心在于动态代理,那么在SpringAop的底层的技术是依靠了什么技术呢&#x…...
CVPR2023论文及代码合集来啦~
以下内容由马拉AI整理汇总。 下载:点我跳转。 狂肝200小时的良心制作,529篇最新CVPR2023论文及其Code,汇总成册,制作成《CVPR 2023论文代码检索目录》,包括以下方向: 1、2D目标检测 2、视频目标检测 3、…...
基于ETLCloud的自定义规则调用第三方jar包实现繁体中文转为简体中文
背景 前面曾体验过通过零代码、可视化、拖拉拽的方式快速完成了从 MySQL 到 ClickHouse 的数据迁移,但是在实际生产环境,我们在迁移到目标库之前还需要做一些过滤和转换工作;比如,在诗词数据迁移后,发现原来 MySQL 中…...
TDesign在按钮上加入图标组件
在实际开发中 我们经常会遇到例如 添加或者查询 我们需要在按钮上加入图标的操作 TDesign自然也有预备这样的操作 首先我们打开文档看到图标 例如 我们先用某些图标 就可以点开下面的代码 可以看到 我们的图标大部分都是直接用tdesign-icons-vue 导入他的组件就可以了 而我…...
Java 语言特性(面试系列1)
一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...
【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例
文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖
在Vuzix M400 AR智能眼镜的助力下,卢森堡罗伯特舒曼医院(the Robert Schuman Hospitals, HRS)凭借在无菌制剂生产流程中引入增强现实技术(AR)创新项目,荣获了2024年6月7日由卢森堡医院药剂师协会࿰…...
C++.OpenGL (20/64)混合(Blending)
混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...
08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险
C#入门系列【类的基本概念】:开启编程世界的奇妙冒险 嘿,各位编程小白探险家!欢迎来到 C# 的奇幻大陆!今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类!别害怕,跟着我,保准让你轻松搞…...
解读《网络安全法》最新修订,把握网络安全新趋势
《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...
掌握 HTTP 请求:理解 cURL GET 语法
cURL 是一个强大的命令行工具,用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中,cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...
如何应对敏捷转型中的团队阻力
应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...
