当前位置: 首页 > news >正文

diffusers 源码待理解之处

一、训练DreamBooth时,相关代码的细节小计

在这里插入图片描述
**

class_labels = timesteps 时,模型的前向传播怎么走?待深入去看

**

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

利用class_prompt去生成数据,而不是instance_prompt

在这里插入图片描述

class DreamBoothDataset(Dataset):"""A dataset to prepare the instance and class images with the prompts for fine-tuning the model.It pre-processes the images and the tokenizes prompts."""def __init__(self,instance_data_root,instance_prompt,tokenizer,class_data_root=None,class_prompt=None,class_num=None,size=512,center_crop=False,encoder_hidden_states=None,class_prompt_encoder_hidden_states=None,tokenizer_max_length=None,):self.size = sizeself.center_crop = center_cropself.tokenizer = tokenizerself.encoder_hidden_states = encoder_hidden_statesself.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_statesself.tokenizer_max_length = tokenizer_max_lengthself.instance_data_root = Path(instance_data_root)if not self.instance_data_root.exists():raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")self.instance_images_path = list(Path(instance_data_root).iterdir())self.num_instance_images = len(self.instance_images_path)self.instance_prompt = instance_promptself._length = self.num_instance_imagesif class_data_root is not None:self.class_data_root = Path(class_data_root)self.class_data_root.mkdir(parents=True, exist_ok=True)self.class_images_path = list(self.class_data_root.iterdir())if class_num is not None:self.num_class_images = min(len(self.class_images_path), class_num)else:self.num_class_images = len(self.class_images_path)self._length = max(self.num_class_images, self.num_instance_images)self.class_prompt = class_promptelse:self.class_data_root = Noneself.image_transforms = transforms.Compose([transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])def __len__(self):return self._lengthdef __getitem__(self, index):example = {}instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])instance_image = exif_transpose(instance_image)if not instance_image.mode == "RGB":instance_image = instance_image.convert("RGB")example["instance_images"] = self.image_transforms(instance_image)if self.encoder_hidden_states is not None:example["instance_prompt_ids"] = self.encoder_hidden_stateselse:text_inputs = tokenize_prompt(self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length)example["instance_prompt_ids"] = text_inputs.input_idsexample["instance_attention_mask"] = text_inputs.attention_maskif self.class_data_root:class_image = Image.open(self.class_images_path[index % self.num_class_images])class_image = exif_transpose(class_image)if not class_image.mode == "RGB":class_image = class_image.convert("RGB")example["class_images"] = self.image_transforms(class_image)if self.class_prompt_encoder_hidden_states is not None:example["class_prompt_ids"] = self.class_prompt_encoder_hidden_stateselse:class_text_inputs = tokenize_prompt(self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length)example["class_prompt_ids"] = class_text_inputs.input_idsexample["class_attention_mask"] = class_text_inputs.attention_maskreturn example
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):if tokenizer_max_length is not None:max_length = tokenizer_max_lengthelse:max_length = tokenizer.model_max_lengthtext_inputs = tokenizer(prompt,truncation=True,padding="max_length",max_length=max_length,return_tensors="pt",)return text_inputs
def collate_fn(examples, with_prior_preservation=False):has_attention_mask = "instance_attention_mask" in examples[0]input_ids = [example["instance_prompt_ids"] for example in examples]pixel_values = [example["instance_images"] for example in examples]if has_attention_mask:attention_mask = [example["instance_attention_mask"] for example in examples]# Concat class and instance examples for prior preservation.# We do this to avoid doing two forward passes.if with_prior_preservation:input_ids += [example["class_prompt_ids"] for example in examples]pixel_values += [example["class_images"] for example in examples]if has_attention_mask:attention_mask += [example["class_attention_mask"] for example in examples]pixel_values = torch.stack(pixel_values)pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()input_ids = torch.cat(input_ids, dim=0)batch = {"input_ids": input_ids,"pixel_values": pixel_values,}if has_attention_mask:attention_mask = torch.cat(attention_mask, dim=0)batch["attention_mask"] = attention_maskreturn batch

Dataset和Dataloader的构成
在这里插入图片描述
为了避免模型过拟合或者是说语言漂移的情况,需要用模型去用一个普通的prompt先生成样本。

fine-tune text-encoder,但是对显存要求更高
在这里插入图片描述

二、训练text to image,相关代码的细节小计

**

1、Dataloader的构建如下,但是为啥没有attention_mask呢?训练DreamBooth时有
2、训练或者微调模型时需要图文数据对,如果没有文本数据,可以用BLIP去生成图像描述的文本,但是文本描述不一定可靠
**

 # Get the datasets: you can either provide your own training and evaluation files (see below)# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).# In distributed training, the load_dataset function guarantees that only one local process can concurrently# download the dataset.if args.dataset_name is not None:# Downloading and loading a dataset from the hub.dataset = load_dataset(args.dataset_name,args.dataset_config_name,cache_dir=args.cache_dir,data_dir=args.train_data_dir,)else:data_files = {}if args.train_data_dir is not None:data_files["train"] = os.path.join(args.train_data_dir, "**")dataset = load_dataset("imagefolder",data_files=data_files,cache_dir=args.cache_dir,)# See more about loading custom images at# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder# Preprocessing the datasets.# We need to tokenize inputs and targets.column_names = dataset["train"].column_names# 6. Get the column names for input/target.dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)if args.image_column is None:image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]else:image_column = args.image_columnif image_column not in column_names:raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")if args.caption_column is None:caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]else:caption_column = args.caption_columnif caption_column not in column_names:raise ValueError(f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}")# Preprocessing the datasets.# We need to tokenize input captions and transform the images.def tokenize_captions(examples, is_train=True):captions = []for caption in examples[caption_column]:if isinstance(caption, str):captions.append(caption)elif isinstance(caption, (list, np.ndarray)):# take a random caption if there are multiplecaptions.append(random.choice(caption) if is_train else caption[0])else:raise ValueError(f"Caption column `{caption_column}` should contain either strings or lists of strings.")inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt")return inputs.input_ids# Preprocessing the datasets.train_transforms = transforms.Compose([transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])def preprocess_train(examples):images = [image.convert("RGB") for image in examples[image_column]]examples["pixel_values"] = [train_transforms(image) for image in images]examples["input_ids"] = tokenize_captions(examples)# images text pixel_values input_ids 4种keyreturn exampleswith accelerator.main_process_first():if args.max_train_samples is not None:dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))# Set the training transformstrain_dataset = dataset["train"].with_transform(preprocess_train)def collate_fn(examples):pixel_values = torch.stack([example["pixel_values"] for example in examples])pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()input_ids = torch.stack([example["input_ids"] for example in examples])return {"pixel_values": pixel_values, "input_ids": input_ids}# DataLoaders creation:train_dataloader = torch.utils.data.DataLoader(train_dataset,shuffle=True,collate_fn=collate_fn,batch_size=args.train_batch_size,num_workers=args.dataloader_num_workers,)

三、训ControlNet

Dataloader的搭建的代码如下:


1、新增conditioning_pixel_values图像数据,用于做可控的生成
2、输入中依旧没有attention-mask,待思考


def make_train_dataset(args, tokenizer, accelerator):# Get the datasets: you can either provide your own training and evaluation files (see below)# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).# In distributed training, the load_dataset function guarantees that only one local process can concurrently# download the dataset.if args.dataset_name is not None:# Downloading and loading a dataset from the hub.dataset = load_dataset(args.dataset_name,args.dataset_config_name,cache_dir=args.cache_dir,)else:if args.train_data_dir is not None:dataset = load_dataset(args.train_data_dir,cache_dir=args.cache_dir,)# See more about loading custom images at# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script# Preprocessing the datasets.# We need to tokenize inputs and targets.column_names = dataset["train"].column_names# 6. Get the column names for input/target.if args.image_column is None:image_column = column_names[0]logger.info(f"image column defaulting to {image_column}")else:image_column = args.image_columnif image_column not in column_names:raise ValueError(f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}")if args.caption_column is None:caption_column = column_names[1]logger.info(f"caption column defaulting to {caption_column}")else:caption_column = args.caption_columnif caption_column not in column_names:raise ValueError(f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}")if args.conditioning_image_column is None:conditioning_image_column = column_names[2]logger.info(f"conditioning image column defaulting to {conditioning_image_column}")else:conditioning_image_column = args.conditioning_image_columnif conditioning_image_column not in column_names:raise ValueError(f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}")def tokenize_captions(examples, is_train=True):captions = []for caption in examples[caption_column]:if random.random() < args.proportion_empty_prompts:captions.append("")elif isinstance(caption, str):captions.append(caption)elif isinstance(caption, (list, np.ndarray)):# take a random caption if there are multiplecaptions.append(random.choice(caption) if is_train else caption[0])else:raise ValueError(f"Caption column `{caption_column}` should contain either strings or lists of strings.")inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt")return inputs.input_idsimage_transforms = transforms.Compose([transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),transforms.CenterCrop(args.resolution),transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),])conditioning_image_transforms = transforms.Compose([transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),transforms.CenterCrop(args.resolution),transforms.ToTensor(),])def preprocess_train(examples):images = [image.convert("RGB") for image in examples[image_column]]images = [image_transforms(image) for image in images]conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]examples["pixel_values"] = imagesexamples["conditioning_pixel_values"] = conditioning_imagesexamples["input_ids"] = tokenize_captions(examples)return exampleswith accelerator.main_process_first():if args.max_train_samples is not None:dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))# Set the training transformstrain_dataset = dataset["train"].with_transform(preprocess_train)return train_datasetdef collate_fn(examples):pixel_values = torch.stack([example["pixel_values"] for example in examples])pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()input_ids = torch.stack([example["input_ids"] for example in examples])return {"pixel_values": pixel_values,"conditioning_pixel_values": conditioning_pixel_values,"input_ids": input_ids,}

相关文章:

diffusers 源码待理解之处

一、训练DreamBooth时&#xff0c;相关代码的细节小计 ** class_labels timesteps 时&#xff0c;模型的前向传播怎么走&#xff1f;待深入去看 ** 利用class_prompt去生成数据&#xff0c;而不是instance_prompt class DreamBoothDataset(Dataset):"""A dat…...

正则表达式 详解,10分钟学会

大家好&#xff0c;欢迎来到停止重构的频道。 本期我们讨论正则表达式。 正则表达式是一种用于匹配和操作文本的工具&#xff0c;常用于文本查找、文本替换、校验文本格式等场景。 正则表达式不仅是写代码时才会使用&#xff0c;在平常使用的很多文本编辑软件&#xff0c;都…...

【排序算法】归并排序与快速排序:深入解析与比较

文章目录 1. 引言2. 归并排序&#xff08;Merge Sort&#xff09;3. 快速排序&#xff08;Quick Sort&#xff09;4. 归并排序与快速排序的比较5. 结论 1. 引言 排序算法是计算机科学中最基本且至关重要的概念之一。它们不仅是理解更复杂算法和数据结构的基石&#xff0c;而且…...

万字长文谈自动驾驶bev感知(一)

文章目录 prologuepaper listcamera bev :1. Lift, Splat, Shoot: Encoding Images from Arbitrary Camera Rigs by Implicitly Unprojecting to 3D2. M2BEV: Multi-Camera Joint 3D Detection and Segmentation with Unified Birds-Eye View Representation3. BEVDet: High-Pe…...

cfa一级考生复习经验分享系列(十七)

考场经验&#xff1a; 1.本人在Prometric广州考试中心&#xff0c;提前一天在附近住下&#xff0c;地方比较好找&#xff0c;到了百汇广场北门&#xff0c;进去就可以看见电梯直达10楼。进去之后需要现场检查行程卡和健康码&#xff0c;然后会问最近你有没有发烧咳嗽等问题&…...

机器人活动区域 - 华为OD统一考试

OD统一考试 题解: Java / Python / C++ 题目描述 现有一个机器人,可放置于 M x N 的网格中任意位置,每个网格包含一个非负整数编号,当相邻网格的数字编号差值的绝对值小于等于 1 时机器人可以在网格间移动。 问题: 求机器人可活动的最大范围对应的网格点数目。 说明: 网格…...

三、HTML元素

一、HTML元素 HTML 文档由 HTML 元素定义。 *开始标签常被称为起始标签&#xff08;opening tag&#xff09;&#xff0c;结束标签常称为闭合标签&#xff08;closing tag&#xff09;。 二、HTML 元素语法 HTML 元素以开始标签起始。HTML 元素以结束标签终止。元素的内容是…...

置顶> 个人学习记录一览

个人学习记录一览表 写个说明   知识学的好&#xff0c;不如笔记记得好&#xff0c;知识点的遗忘在所难免&#xff0c;这里记录我个人的学习过程&#xff0c;以备后面二次学习使用。 Linux 操作系统 Linux 操作系统 001-介绍 Linux 操作系统 002-VMware Workstation的相关操…...

c++重载操作符

支持重载操作符是c的一个特性&#xff0c;先不管好不好用&#xff0c;这起码能让它看起来比其他语言NB很多&#xff0c;但真正了解重载操作符后&#xff0c;就会发现这个特性...就这&#xff1f;本文分两个部分 重载操作符简介和使用——适用新手重载操作符的原理和sao操作——…...

C# 如何读取Excel文件

当处理Excel文件时&#xff0c;从中读取数据是一个常见的需求。通过读取Excel数据&#xff0c;可以获取电子表格中包含的信息&#xff0c;并在其他应用程序或编程环境中使用这些数据进行进一步的处理和分析。本文将分享一个使用免费库来实现C#中读取Excel数据的方法。具体如下&…...

Vue2面试题:说一下对vuex的理解?

五种状态&#xff1a; state: 存储公共数据 this.$store.state mutations&#xff1a;同步操作&#xff0c;改变store的数据 this.$store.commit() actions: 异步操作&#xff0c;让mutations中的方法能在异步操作中起作用 this.$store.dispatch() getters: 计算属性 th…...

elasticsearch系列五:集群的备份与恢复

概述 前几篇咱们讲了es的语法、存储的优化、常规运维等等&#xff0c;今天咱们看下如何备份数据和恢复数据。 在传统的关系型数据库中我们有多种备份方式&#xff0c;常见有热备、冷备、全量定时增量备份、通过开发程序备份等等&#xff0c;其实在es中是一样的。 官方建议采用s…...

【Elasticsearch源码】 分片恢复分析

带着疑问学源码&#xff0c;第七篇&#xff1a;Elasticsearch 分片恢复分析 代码分析基于&#xff1a;https://github.com/jiankunking/elasticsearch Elasticsearch 8.0.0-SNAPSHOT 目的 在看源码之前先梳理一下&#xff0c;自己对于分片恢复的疑问点&#xff1a; 网上对于E…...

elasticsearch如何操作索引库里面的文档

上节介绍了索引库的CRUD&#xff0c;接下来操作索引库里面的文档 目录 一、添加文档 二、查询文档 三、删除文档 四、修改文档 一、添加文档 新增文档的DSL语法如下 POST /索引库名/_doc/文档id(不加id,es会自动生成) { "字段1":"值1", "字段2&q…...

opencv期末练习题(2)附带解析

图像插值与缩放 %matplotlib inline import cv2 import matplotlib.pyplot as plt def imshow(img,grayFalse,bgr_modeFalse):if gray:img cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)plt.imshow(img,cmap"gray")else:if not bgr_mode:img cv2.cvtColor(img,cv2.COLOR_B…...

【Mybatis】深入学习MyBatis:高级特性与Spring整合

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; Mybatis ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 高级特性 1 一级缓存和二级缓存 一级缓存 二级缓存 2 延迟加载 5 整合Spring 1 MyBatis-Spring模块 2 事务管理 结…...

C语言与人生函数的对比,使用,参数详解

各位少年&#xff0c;大家好&#xff0c;我是博主那一脸阳光。&#xff0c;今天给大家分享函数的定义&#xff0c;和数学的函数的区别和使用 前言&#xff1a;C语言中的函数和数学中的函数在概念上有相似之处&#xff0c;但也存在显著的区别。下面对比它们的主要特点&#xff…...

机器人动力学一些笔记

动力学方程中&#xff0c;Q和q的关系(Q是sita) Q其实是一个向量&#xff0c;q(Q1&#xff0c;Q2&#xff0c;Q3&#xff0c;Q4&#xff0c;Q5&#xff0c;Q6)&#xff08;假如6个关节&#xff09; https://zhuanlan.zhihu.com/p/25789930 举个浅显易懂的例子&#xff0c;你在房…...

Plantuml之甘特图语法介绍(二十八)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…...

Docker support for NVIDIA GPU Accelerated Computing on WSL 2

Docker support for NVIDIA GPU Accelerated Computing on WSL 2 0. 背景1. 安装 Docker Desktop2. 配置 Docker Desktop3. WLS Ubuntu 配置4. 安装 Docker-ce5. 安装 NVIDIA Container Toolkit6. 配置 Docker7. 运行一个 Sample Workload 0. 背景 今天尝试一下 NVIDIA GPU 在…...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

ES6从入门到精通:前言

ES6简介 ES6&#xff08;ECMAScript 2015&#xff09;是JavaScript语言的重大更新&#xff0c;引入了许多新特性&#xff0c;包括语法糖、新数据类型、模块化支持等&#xff0c;显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var&#xf…...

模型参数、模型存储精度、参数与显存

模型参数量衡量单位 M&#xff1a;百万&#xff08;Million&#xff09; B&#xff1a;十亿&#xff08;Billion&#xff09; 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的&#xff0c;但是一个参数所表示多少字节不一定&#xff0c;需要看这个参数以什么…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问&#xff1a;构建自定义开发环境 引言 临时运维一个古董项目&#xff0c;无文档&#xff0c;无环境&#xff0c;无交接人&#xff0c;俗称三无。 运行设备的环境老&#xff0c;本地环境版本高&#xff0c;ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...

C++使用 new 来创建动态数组

问题&#xff1a; 不能使用变量定义数组大小 原因&#xff1a; 这是因为数组在内存中是连续存储的&#xff0c;编译器需要在编译阶段就确定数组的大小&#xff0c;以便正确地分配内存空间。如果允许使用变量来定义数组的大小&#xff0c;那么编译器就无法在编译时确定数组的大…...

打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用

一、方案背景​ 在现代生产与生活场景中&#xff0c;如工厂高危作业区、医院手术室、公共场景等&#xff0c;人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式&#xff0c;存在效率低、覆盖面不足、判断主观性强等问题&#xff0c;难以满足对人员打手机行为精…...

【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案

目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后&#xff0c;迭代器会失效&#xff0c;因为顺序迭代器在内存中是连续存储的&#xff0c;元素删除后&#xff0c;后续元素会前移。 但一些场景中&#xff0c;我们又需要在执行删除操作…...