当前位置: 首页 > article >正文

Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程

Pytorch图像去噪实战十二DDPM图像去噪完整训练流程构建可复现扩散模型工程一、问题场景扩散模型能跑但工程代码很容易写乱上一篇我们从最小实现理解了 Diffusion 的核心逻辑。但如果真正放到项目里会很快遇到问题beta schedule 写在训练脚本里后续不好改采样逻辑和训练逻辑混在一起模型保存与恢复不规范训练参数不可复现后续无法扩展 DDIM、条件去噪、彩色图像很多人学扩散模型时能写出一个 demo但很难整理成工程。这一篇我们重点做一件事把 DDPM 图像去噪流程整理成一个可复现、可扩展的工程结构。二、DDPM核心训练目标DDPM训练目标仍然是预测噪声epsilon_theta(x_t, t) ≈ epsilon训练时从数据集中取 clean image x0随机采样时间步 t根据 t 给 x0 加噪得到 xt模型输入 xt 和 t模型预测 noise使用 MSELoss 训练三、推荐工程结构ddpm_denoise/ ├── configs/ │ └── train_config.py ├── data/ │ └── train/ ├── models/ │ └── unet.py ├── diffusion/ │ └── ddpm.py ├── dataset.py ├── train.py ├── sample.py └── utils.py这个结构相比简单 demo 有几个好处模型独立扩散过程独立配置独立训练和采样分离后续扩展方便四、配置文件configs/train_config.pyclassTrainConfig:image_size64channels1batch_size32num_workers4epochs100lr2e-4timesteps1000beta_start1e-4beta_end0.02save_interval10data_dirdata/trainsave_dircheckpoints配置单独抽出来最大的好处是实验参数不会散落在代码里。后面复现实验时非常重要。五、数据集代码dataset.pyimportosfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassImageFolderDataset(Dataset):def__init__(self,root_dir,image_size64,channels1):self.paths[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((.jpg,.jpeg,.png))]ifchannels1:self.modeLelse:self.modeRGBself.transformtransforms.Compose([transforms.Resize((image_size,image_size)),transforms.ToTensor()])def__len__(self):returnlen(self.paths)def__getitem__(self,index):imgImage.open(self.paths[index]).convert(self.mode)returnself.transform(img)六、DDPM扩散类封装diffusion/ddpm.pyimporttorchclassDDPM:def__init__(self,timesteps1000,beta_start1e-4,beta_end0.02,devicecuda):self.timestepstimesteps self.devicedevice self.betastorch.linspace(beta_start,beta_end,timesteps).to(device)self.alphas1.0-self.betas self.alpha_barstorch.cumprod(self.alphas,dim0)self.sqrt_alpha_barstorch.sqrt(self.alpha_bars)self.sqrt_one_minus_alpha_barstorch.sqrt(1.0-self.alpha_bars)defq_sample(self,x0,t,noiseNone):ifnoiseisNone:noisetorch.randn_like(x0)sqrt_alpha_barself.sqrt_alpha_bars[t].view(-1,1,1,1)sqrt_one_minusself.sqrt_one_minus_alpha_bars[t].view(-1,1,1,1)xtsqrt_alpha_bar*x0sqrt_one_minus*noisereturnxt,noisetorch.no_grad()defp_sample(self,model,x,t):betaself.betas[t]alphaself.alphas[t]alpha_barself.alpha_bars[t]batch_ttorch.full((x.size(0),),t,devicex.device,dtypetorch.long)pred_noisemodel(x,batch_t)mean(1/torch.sqrt(alpha))*(x-(beta/torch.sqrt(1.0-alpha_bar))*pred_noise)ift0:noisetorch.randn_like(x)returnmeantorch.sqrt(beta)*noisereturnmean七、UNet噪声预测模型models/unet.pyimporttorchimporttorch.nnasnnclassTimeEmbedding(nn.Module):def__init__(self,dim):super().__init__()self.netnn.Sequential(nn.Linear(1,dim),nn.SiLU(),nn.Linear(dim,dim))defforward(self,t):tt.float().view(-1,1)/1000.0returnself.net(t)classResidualBlock(nn.Module):def__init__(self,in_channels,out_channels,time_dim):super().__init__()self.conv1nn.Conv2d(in_channels,out_channels,3,padding1)self.conv2nn.Conv2d(out_channels,out_channels,3,padding1)self.time_projnn.Linear(time_dim,out_channels)self.shortcutnn.Identity()ifin_channels!out_channels:self.shortcutnn.Conv2d(in_channels,out_channels,1)self.actnn.SiLU()defforward(self,x,t_emb):hself.act(self.conv1(x))timeself.time_proj(t_emb).view(x.size(0),-1,1,1)hhtime hself.conv2(self.act(h))returnhself.shortcut(x)classDDPMUNet(nn.Module):def__init__(self,channels1,base64,time_dim128):super().__init__()self.time_mlpTimeEmbedding(time_dim)self.down1ResidualBlock(channels,base,time_dim)self.down2ResidualBlock(base,base*2,time_dim)self.poolnn.MaxPool2d(2)self.midResidualBlock(base*2,base*2,time_dim)self.upnn.ConvTranspose2d(base*2,base,2,2)self.up_blockResidualBlock(base*2,base,time_dim)self.outnn.Conv2d(base,channels,3,padding1)defforward(self,x,t):t_embself.time_mlp(t)d1self.down1(x,t_emb)d2self.down2(self.pool(d1),t_emb)midself.mid(d2,t_emb)uself.up(mid)utorch.cat([u,d1],dim1)uself.up_block(u,t_emb)returnself.out(u)八、训练脚本train.pyimportosimporttorchfromtorch.utils.dataimportDataLoaderfromconfigs.train_configimportTrainConfigfromdatasetimportImageFolderDatasetfrommodels.unetimportDDPMUNetfromdiffusion.ddpmimportDDPMdeftrain():cfgTrainConfig()os.makedirs(cfg.save_dir,exist_okTrue)devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)datasetImageFolderDataset(root_dircfg.data_dir,image_sizecfg.image_size,channelscfg.channels)loaderDataLoader(dataset,batch_sizecfg.batch_size,shuffleTrue,num_workerscfg.num_workers)modelDDPMUNet(channelscfg.channels).to(device)diffusionDDPM(timestepscfg.timesteps,beta_startcfg.beta_start,beta_endcfg.beta_end,devicedevice)optimizertorch.optim.AdamW(model.parameters(),lrcfg.lr)criteriontorch.nn.MSELoss()forepochinrange(1,cfg.epochs1):model.train()total_loss0forx0inloader:x0x0.to(device)ttorch.randint(0,cfg.timesteps,(x0.size(0),),devicedevice)xt,noisediffusion.q_sample(x0,t)pred_noisemodel(xt,t)losscriterion(pred_noise,noise)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_lossloss.item()avg_losstotal_loss/len(loader)print(fEpoch [{epoch}/{cfg.epochs}], Loss:{avg_loss:.6f})ifepoch%cfg.save_interval0:pathos.path.join(cfg.save_dir,fddpm_epoch_{epoch}.pth)torch.save(model.state_dict(),path)if__name____main__:train()九、采样脚本sample.pyimporttorchimporttorchvision.utilsasvutilsfromconfigs.train_configimportTrainConfigfrommodels.unetimportDDPMUNetfromdiffusion.ddpmimportDDPMtorch.no_grad()defsample():cfgTrainConfig()devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelDDPMUNet(channelscfg.channels).to(device)model.load_state_dict(torch.load(checkpoints/ddpm_epoch_100.pth,map_locationdevice))model.eval()diffusionDDPM(timestepscfg.timesteps,beta_startcfg.beta_start,beta_endcfg.beta_end,devicedevice)xtorch.randn(16,cfg.channels,cfg.image_size,cfg.image_size).to(device)fortinreversed(range(cfg.timesteps)):xdiffusion.p_sample(model,x,t)xtorch.clamp(x,0.0,1.0)vutils.save_image(x.cpu(),ddpm_sample.png,nrow4)if__name____main__:sample()十、为什么要做工程拆分很多扩散模型代码一开始写在一个文件里能跑但很难维护。工程拆分带来的好处diffusion类可复用UNet可替换config方便调参train和sample互不干扰后续DDIM可以直接扩展这也是从“能跑demo”到“能做项目”的关键一步。十一、踩坑记录坑1采样结果全是噪声常见原因模型训练不够时间步输入错误beta schedule太激进采样公式写错建议先用小数据集验证过拟合能力。坑2loss下降但采样效果差DDPM的loss下降不代表马上能生成好图。采样质量通常需要更多训练轮数。坑3训练太慢DDPM采样慢是正常现象因为要从 T 逐步采样。后续可以使用 DDIM 或减少 timesteps。十二、适合收藏总结DDPM工程化流程配置文件管理参数Dataset加载图像DDPM类负责加噪和采样UNet预测噪声train.py训练模型sample.py生成结果避坑清单不要把所有代码写一个文件时间步必须正确传入beta schedule要稳定采样结果差不一定是loss问题先用小尺寸图跑通十三、优化建议后续可以继续做DDIM加速采样条件Diffusion去噪彩色图像支持EMA模型权重混合精度训练结尾总结DDPM不是一个单独模型而是一套完整的扩散训练和采样框架。如果你只是写一个demo很容易跑通但如果要长期做系列实验就必须从一开始整理好工程结构。这一篇的重点不是追求最强效果而是把DDPM搭成一个稳定可复现的项目骨架。下一篇预告Pytorch图像去噪实战十三DDIM加速采样让扩散模型去噪从1000步降到50步

相关文章:

Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程

Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程一、问题场景:扩散模型能跑,但工程代码很容易写乱 上一篇我们从最小实现理解了 Diffusion 的核心逻辑。 但如果真正放到项目里…...

电子制造企业设施升级与产能优化实践

1. 电子制造企业的设施升级战略解析当我在电子制造行业深耕十五年后,深刻认识到一个真理:生产线上的每一寸空间都是利润的战场。最近研究Epec公司的设施升级案例时,发现这个投资50万美元的改造项目完美诠释了现代电子制造企业的升级逻辑——不…...

CANoe硬件过滤实战:用VN5000给车载以太网测试‘减负’,避开数据丢失坑

CANoe硬件过滤实战:用VN5000给车载以太网测试‘减负’,避开数据丢失坑 当车载以太网测试遇到每秒数千帧的ADAS数据洪流,或是持续数小时的OTA刷写压力测试时,工程师们常常面临一个两难选择:要么忍受卡顿的实时分析体验&…...

手机号查QQ号终极指南:3分钟学会逆向查询技术

手机号查QQ号终极指南:3分钟学会逆向查询技术 【免费下载链接】phone2qq 项目地址: https://gitcode.com/gh_mirrors/ph/phone2qq 你是否曾经需要快速查询手机号对应的QQ号?手机号查QQ工具正是为你量身打造的Python解决方案!这个开源…...

XUnity.AutoTranslator完整指南:5分钟掌握Unity游戏实时翻译的终极解决方案

XUnity.AutoTranslator完整指南:5分钟掌握Unity游戏实时翻译的终极解决方案 【免费下载链接】XUnity.AutoTranslator 项目地址: https://gitcode.com/gh_mirrors/xu/XUnity.AutoTranslator 你是否曾经因为语言障碍而无法畅玩心爱的日系RPG或欧美独立游戏&am…...

通过审计日志追溯团队成员的模型API调用记录与安全事件

通过审计日志追溯团队成员的模型API调用记录与安全事件 1. 企业API调用管理的核心需求 在企业环境中使用大模型API时,管理员通常面临三个关键挑战:权限管控颗粒度不足、异常调用难追溯、成本归属不透明。传统方案需要自行搭建日志系统或依赖厂商分散的…...

保姆级避坑指南:在Jetson Orin NX上搞定Pixhawk 6X飞控固件编译与烧写(附IMU频率修改)

保姆级避坑指南:在Jetson Orin NX上搞定Pixhawk 6X飞控固件编译与烧写(附IMU频率修改) 当你手头只有一台Jetson Orin NX,却需要完成Pixhawk 6X飞控的固件编译、修改和烧写全流程时,传统的QGroundControl方案突然变得不…...

医疗大模型问答合规性断崖式失效?——Dify 0.12.0+新合规插件包(含GDPR/《个人信息保护法》双模校验器)首次深度拆解

更多请点击: https://intelliparadigm.com 第一章:医疗大模型问答合规性断崖式失效的根源诊断 医疗大模型在临床辅助决策场景中频繁出现合规性“断崖式”失效——即模型在训练/测试阶段表现稳健,但上线后短期内迅速产出违反《互联网诊疗监管…...

从行政区划代码到地图可视化:教你用ECharts快速生成中国省市区层级关系图

从行政区划代码到地图可视化:用ECharts构建中国省市区层级关系图实战指南 1. 行政区划数据的前期处理 行政区划代码作为国家标准编码体系,是地理信息系统的基础数据。但在实际可视化应用中,原始代码表需要经过结构化转换才能被ECharts等工具识…...

【PHP Swoole × LLM长连接实战权威指南】:20年架构师亲授零丢包、低延迟、万级并发配置全流程

更多请点击: https://intelliparadigm.com 第一章:Swoole LLM长连接架构全景与核心挑战 Swoole 作为高性能异步协程 PHP 扩展,与大语言模型(LLM)服务结合时,天然适配流式响应、低延迟会话维持与高并发连接…...

Transformer中斜杠主导注意力头的形成机制研究

1. 项目背景与核心问题在自然语言处理领域,Transformer架构已经成为事实上的标准模型框架。随着模型规模的不断扩大,研究者们逐渐发现了一个有趣的现象:某些特定的注意力头(Attention Head)会自发地形成一种特殊的行为…...

ARM NEON指令集:浮点倒数与平方根优化实践

1. ARM NEON指令集概述 NEON是ARM架构下的SIMD(单指令多数据)扩展指令集,主要应用于Cortex-A系列处理器。它通过128位寄存器同时操作多个数据元素,显著提升多媒体编解码、数字信号处理、图形处理等计算密集型任务的性能。NEON技术…...

Dreambooth微调Stable Diffusion:精准定制AI图像生成

1. 项目概述:Dreambooth微调Stable Diffusion的核心价值去年当Stable Diffusion首次开源时,整个AI绘图领域为之震动。但很快我们就发现,虽然它能生成各种风格的图像,却很难精确还原特定人物、物体或艺术风格的特征。这正是Dreambo…...

保姆级教程:用Realsense D435i和YOLOv5s实现物体三维坐标实时测量(附完整代码)

从零实现Realsense D435i与YOLOv5的物体三维坐标测量实战指南 当机械臂需要精准抓取传送带上的零件,或是AR应用要在真实场景中叠加虚拟物体时,获取目标物体的三维位置信息就成了关键。Intel Realsense D435i深度相机与YOLOv5目标检测算法的组合&#xff…...

《数术原本》(卷一 正统典藏定本)

《数术原本》(卷一 正统典藏定本) 作者:乖乖数学(20260501)《数术原本》(卷一_正统典藏定本)。文档中并未包含具体指令,因此,我将依据文档内容,为您提供一份详…...

Thinking with Visual Primitives【用视觉原语思考】

Thinking with Visual Primitives 用视觉原语思考 Ruijie Lu1,2,∗\mathrm { L u ^ { 1 , 2 , * } }Lu1...

告别蒙圈!用Python手搓Sarsa与Q-learning,搞懂时序差分TD算法的核心差异

从零实现Sarsa与Q-learning:揭秘时序差分算法的本质差异 在强化学习领域,时序差分(Temporal Difference, TD)算法如同一位隐形的导师,它不需要等待完整的学习过程结束,就能在每一步给予我们反馈和指导。想象一下,你正在…...

数独AI求解器:从回溯算法到LLM推理的技术实现

1. 项目概述:当数独遇上AI,一场关于逻辑与推理的深度对话如果你和我一样,对数独这项经典的逻辑游戏抱有浓厚的兴趣,同时又对人工智能如何“思考”充满好奇,那么“Keyoku-ai/keyoku”这个项目绝对值得你花时间深入研究。…...

PHP 9.0 + RAG + Async Streams全栈部署,支撑万级并发AI会话的5大核心配置,你漏了第3个?

更多请点击: https://intelliparadigm.com 第一章:PHP 9.0 RAG Async Streams全栈AI会话架构全景 PHP 9.0(预发布版)原生支持协程级异步 I/O 与结构化并发,结合 RAG(Retrieval-Augmented Generation&…...

江西省人民医院红谷滩分院电话0791-87720770 / 87720771打不通,什么原因?

◆◆ 预约方式◆◆(一)扫描微信二维码或支付宝二维码预约(二)预约电话:0791-87720770 / 87720771据了解,红谷滩院区是院本部优质医疗业务的同质拓展和延伸,占地约126亩,建筑总面积约…...

STM32H7B0VBT6驱动SHT40温湿度传感器:硬件I2C配置与HAL库实战避坑

STM32H7B0VBT6硬件I2C驱动SHT40温湿度传感器全流程解析 在嵌入式系统开发中,精确的环境监测往往离不开温湿度传感器的支持。Sensirion推出的SHT40作为第四代数字温湿度传感器,以其高精度和低功耗特性成为工业级应用的热门选择。本文将深入探讨如何基于ST…...

通过TaotokenAPI管理功能实现团队密钥分发与调用审计

通过Taotoken API管理功能实现团队密钥分发与调用审计 1. 团队API Key管理基础 在Taotoken平台上,团队管理员可以通过控制台集中管理多个API Key。每个Key可以设置独立的权限范围和使用配额,便于分配给不同成员或项目使用。登录控制台后,导航…...

为内容创作平台集成 Taotoken 实现按需调用不同风格的文案生成模型

为内容创作平台集成 Taotoken 实现按需调用不同风格的文案生成模型 1. 多模型统一接入的业务需求 内容创作平台通常需要支持多种文案风格,从正式报告到创意故事,每种风格对生成模型的要求各不相同。传统方案需要对接多个厂商的 API,分别管理…...

Taotoken 模型广场如何帮助开发者快速选型与对比不同大模型

Taotoken 模型广场如何帮助开发者快速选型与对比不同大模型 1. 模型广场的核心功能 Taotoken 模型广场作为统一入口,聚合了当前主流的大语言模型服务。开发者登录控制台后,可在「模型广场」页面查看所有可用模型的列表。每个模型卡片展示了基础信息&am…...

概率论在机器学习中的核心应用与实践

1. 概率论与机器学习的共生关系 概率论是机器学习领域最基础也最容易被低估的数学工具。我在工业界参与过的所有真实项目里,从简单的用户点击率预测到复杂的自动驾驶决策系统,概率模型都扮演着核心角色。举个实际案例:当我们为电商平台构建推…...

ARM GIC虚拟化指令陷阱机制解析与应用

1. ARM GIC虚拟化指令陷阱机制概述在ARM架构的虚拟化环境中,通用中断控制器(GIC)的虚拟化扩展是实现高效、安全中断处理的关键组件。作为系统级芯片(SoC)中负责管理和分发中断的核心模块,GIC在虚拟化场景下需要解决一个基本问题:如何让多个虚…...

基于Git与Markdown的自动化写作系统:打造高效数字工作台

1. 项目概述:一个为写作者打造的“数字工作台” 如果你经常写作,无论是技术博客、小说、学术论文还是日常笔记,大概率都经历过这样的困扰:文档散落在电脑各处,格式五花八门,想找个去年的草稿得翻半天&#…...

造纸机烘缸轴承故障预测【附代码】

✅ 博主简介:擅长数据搜集与处理、建模仿真、程序设计、仿真代码、论文写作与指导,毕业论文、期刊论文经验交流。 ✅ 如需沟通交流,扫描文章底部二维码。(1)烘缸表面温度分布特征与振动联合诊断:造纸机烘缸…...

2026年腾讯云OpenClaw/Hermes Agent集成攻略+百炼token Plan配置全览攻略

2026年腾讯云OpenClaw/Hermes Agent集成攻略百炼token Plan配置全览攻略。OpenClaw和Hermes Agent是什么?OpenClaw和Hermes Agent怎么部署?如何部署OpenClaw/Hermes Agent?2026年还在为部署OpenClaw和Hermes Agent到处找教程踩坑吗&#xff1…...

2026年Hermes Agent/OpenClaw怎么搭建?阿里云部署及token Plan配置解析

2026年Hermes Agent/OpenClaw怎么搭建?阿里云部署及token Plan配置解析。OpenClaw和Hermes Agent是什么?OpenClaw和Hermes Agent怎么部署?如何部署OpenClaw/Hermes Agent?2026年还在为部署OpenClaw和Hermes Agent到处找教程踩坑吗…...