上海做网站公,做h5场景的网站,宿州房产网,企业网站管理系统 软件著作权Diffusion Models专栏文章汇总#xff1a;入门与实战 前言#xff1a;自从SDXL提出了长宽桶技术之后#xff0c;彻底解决了不同长宽比的图像输入问题#xff0c;现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Rat… Diffusion Models专栏文章汇总入门与实战 前言自从SDXL提出了长宽桶技术之后彻底解决了不同长宽比的图像输入问题现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Ratio Bucketing)。 目录
原理解读-原有训练的问题
长宽桶技术(Aspect Ratio Bucketing)
完整代码 原理解读-原有训练的问题
纵横比分桶训练可以极大地提高输出质量现有图像生成模型的一个常见问题是它们非常容易生成带有非自然作物的图像。这是因为这些模型被训练成生成方形图像。然而大多数照片和艺术品都不是方形的。然而该模型只能同时在相同大小的图像上工作并且在训练过程中通常的做法是同时在多个训练样本上操作以优化所使用gpu的效率。作为妥协选择正方形图像在训练过程中只裁剪出每个图像的中心然后作为训练样例显示给图像生成模型。
例如人类通常是没有脚或头的剑只有一个刀刃剑柄和剑尖在框架外。因为我们正在创建一个图像生成模型来配合我们的故事叙述体验所以我们的模型能够产生适当的未裁剪的角色是很重要的并且生成的骑士不应该持有延伸到无限的金属状直线。
对裁剪图像进行训练的另一个问题是它可能导致文本和图像之间的不匹配。例如带有王冠标签的图像通常在中央裁剪后不再包含王冠因此君主已经被斩首。我们发现使用随机作物代替中心作物只能略微改善这些问题。使用具有可变图像大小的稳定扩散是可能的尽管可以注意到远远超过512x512的原生分辨率往往会引入重复的图像元素并且非常低的分辨率会产生无法识别的图像。
尽管如此这向我们表明在可变大小的图像上训练模型应该是可能的。在单个、可变大小的样本上进行训练是微不足道的但也非常缓慢而且由于使用小批量提供的缺乏正则化更容易产生训练不稳定性。
长宽桶技术(Aspect Ratio Bucketing)
由于这个问题似乎没有现有的解决方案我们已经为我们的数据集实现了自定义批生成代码允许创建批处理其中批处理中的每个项目具有相同的大小但批处理的图像大小可能不同。
我们通过一种叫做宽高比桶的方法来做到这一点。另一种方法是使用固定的图像大小缩放每个图像以适应这个固定的大小并应用在训练期间被掩盖的填充。由于这会导致训练期间不必要的计算我们没有选择遵循这种替代方法。
在下面我们描述了我们自定义的宽高比桶的批量生成方案背后的原始想法。
首先我们必须定义要将数据集的图像排序到哪个存储桶中。为此我们定义的最大图像尺寸为512x768最大尺寸为1024。由于最大图像大小为512x768比512x512大需要更多的VRAM因此每个gpu的批处理大小必须降低这可以通过梯度积累来补偿。
我们通过应用以下算法生成桶:
Set the width to 256.
While the width is less than or equal to 1024:Find the largest height such that height is less than or equal to 1024 and that width multiplied by height is less than or equal to 512 * 768.Add the resolution given by height and width as a bucket.Increase the width by 64.
同样的重复宽度和高度互换。重复的桶将从列表中删除并添加一个大小为512x512的桶。
接下来我们将图像分配到相应的桶中。为此我们首先将桶分辨率存储在NumPy数组中并计算每个分辨率的长宽比。对于数据集中的每张图像我们检索其分辨率并计算长宽比。图像宽高比从桶宽高比数组中减去使我们能够根据宽高比差的绝对值有效地选择最接近的桶:
image_bucket argmin(abs(bucket_aspects — image_aspect))图像的桶号与其数据集中的项目ID相关联。如果图像的宽高比非常极端甚至与最适合的桶相差太大则从数据集中修剪图像。
由于我们在多个GPU上进行训练在每个epoch之前我们对数据集进行了分片以确保每个GPU在大小相等的不同子集上工作。为此我们首先复制数据集中的项目id列表并对它们进行洗牌。如果这个复制的列表不能被gpu数量乘以批大小整除则会对列表进行修剪并删除最后的项以使其可整除。
然后我们根据当前进程的全局排名选择1/world_size*bsz项id的不同子集。自定义批处理生成的其余部分将从这些过程中的任何一个过程中进行描述并对数据集项id的子集进行操作。
对于当前的分片每个bucket的列表是通过迭代打乱的数据集项目ID列表并将ID分配给分配给图像的bucket对应的列表来创建的。
处理完所有图像后我们遍历每个bucket的列表。如果它的长度不能被批大小整除则根据需要删除列表上的最后一个元素以使其可整除并将它们添加到单独的捕获所有桶中。由于保证整个分片大小包含许多可被批大小整除的元素因此保证生成一个长度可被批大小整除的所有bucket。
当请求批处理时我们从加权分布中随机抽取一个桶。桶的权重设置为桶的大小除以所有剩余桶的大小。这确保了即使有大小差异很大的桶自定义批生成在训练期间不会引入强烈的偏差根据图像大小显示图像。如果在没有加权的情况下选择桶那么小的桶将在训练过程中早期清空只有最大的桶将在训练结束时保留。按大小对桶进行加权可以避免这种情况。
最后从所选的桶中取出一批项。取走的项目从桶中移除。如果桶现在为空则在epoch的剩余时间内删除它。所选的项id和所选桶的分辨率现在被传递给图像加载函数。
完整代码
import numpy as np
import pickle
import timedef get_prng(seed):return np.random.RandomState(seed)class BucketManager:def __init__(self, bucket_file, valid_idsNone, max_size(768,512), divisible64, step_size8, min_dim256, base_res(512,512), bsz1, world_size1, global_rank0, max_ar_error4, seed42, dim_limit1024, debugFalse):with open(bucket_file, rb) as fh:self.res_map pickle.load(fh)if valid_ids is not None:new_res_map {}valid_ids set(valid_ids)for k, v in self.res_map.items():if k in valid_ids:new_res_map[k] vself.res_map new_res_mapself.max_size max_sizeself.f 8self.max_tokens (max_size[0]/self.f) * (max_size[1]/self.f)self.div divisibleself.min_dim min_dimself.dim_limit dim_limitself.base_res base_resself.bsz bszself.world_size world_sizeself.global_rank global_rankself.max_ar_error max_ar_errorself.prng get_prng(seed)epoch_seed self.prng.tomaxint() % (2**32-1)self.epoch_prng get_prng(epoch_seed) # separate prng for sharding use for increased thread resilienceself.epoch Noneself.left_over Noneself.batch_total Noneself.batch_delivered Noneself.debug debugself.gen_buckets()self.assign_buckets()self.start_epoch()def gen_buckets(self):if self.debug:timer time.perf_counter()resolutions []aspects []w self.min_dimwhile (w/self.f) * (self.min_dim/self.f) self.max_tokens and w self.dim_limit:h self.min_dimgot_base Falsewhile (w/self.f) * ((hself.div)/self.f) self.max_tokens and (hself.div) self.dim_limit:if w self.base_res[0] and h self.base_res[1]:got_base Trueh self.divif (w ! self.base_res[0] or h ! self.base_res[1]) and got_base:resolutions.append(self.base_res)aspects.append(1)resolutions.append((w, h))aspects.append(float(w)/float(h))w self.divh self.min_dimwhile (h/self.f) * (self.min_dim/self.f) self.max_tokens and h self.dim_limit:w self.min_dimgot_base Falsewhile (h/self.f) * ((wself.div)/self.f) self.max_tokens and (wself.div) self.dim_limit:if w self.base_res[0] and h self.base_res[1]:got_base Truew self.divresolutions.append((w, h))aspects.append(float(w)/float(h))h self.divres_map {}for i, res in enumerate(resolutions):res_map[res] aspects[i]self.resolutions sorted(res_map.keys(), keylambda x: x[0] * 4096 - x[1])self.aspects np.array(list(map(lambda x: res_map[x], self.resolutions)))self.resolutions np.array(self.resolutions)if self.debug:timer time.perf_counter() - timerprint(fresolutions:\n{self.resolutions})print(faspects:\n{self.aspects})print(fgen_buckets: {timer:.5f}s)def assign_buckets(self):if self.debug:timer time.perf_counter()self.buckets {}self.aspect_errors []skipped 0skip_list []for post_id in self.res_map.keys():w, h self.res_map[post_id]aspect float(w)/float(h)bucket_id np.abs(self.aspects - aspect).argmin()if bucket_id not in self.buckets:self.buckets[bucket_id] []error abs(self.aspects[bucket_id] - aspect)if error self.max_ar_error:self.buckets[bucket_id].append(post_id)if self.debug:self.aspect_errors.append(error)else:skipped 1skip_list.append(post_id)for post_id in skip_list:del self.res_map[post_id]if self.debug:timer time.perf_counter() - timerself.aspect_errors np.array(self.aspect_errors)print(fskipped images: {skipped})print(faspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()})for bucket_id in reversed(sorted(self.buckets.keys(), keylambda b: len(self.buckets[b]))):print(fbucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])})print(fassign_buckets: {timer:.5f}s)def start_epoch(self, world_sizeNone, global_rankNone):if self.debug:timer time.perf_counter()if world_size is not None:self.world_size world_sizeif global_rank is not None:self.global_rank global_rank# select ids for this epoch/rankindex np.array(sorted(list(self.res_map.keys())))index_len index.shape[0]index self.epoch_prng.permutation(index)index index[:index_len - (index_len % (self.bsz * self.world_size))]#print(perm, self.global_rank, index[0:16])index index[self.global_rank::self.world_size]self.batch_total index.shape[0] // self.bszassert(index.shape[0] % self.bsz 0)index set(index)self.epoch {}self.left_over []self.batch_delivered 0for bucket_id in sorted(self.buckets.keys()):if len(self.buckets[bucket_id]) 0:self.epoch[bucket_id] np.array([post_id for post_id in self.buckets[bucket_id] if post_id in index], dtypenp.int64)self.prng.shuffle(self.epoch[bucket_id])self.epoch[bucket_id] list(self.epoch[bucket_id])overhang len(self.epoch[bucket_id]) % self.bszif overhang ! 0:self.left_over.extend(self.epoch[bucket_id][:overhang])self.epoch[bucket_id] self.epoch[bucket_id][overhang:]if len(self.epoch[bucket_id]) 0:del self.epoch[bucket_id]if self.debug:timer time.perf_counter() - timercount 0for bucket_id in self.epoch.keys():count len(self.epoch[bucket_id])print(fcorrect item count: {count len(index)} ({count} of {len(index)}))print(fstart_epoch: {timer:.5f}s)def get_batch(self):if self.debug:timer time.perf_counter()# check if no data left or no epoch initializedif self.epoch is None or self.left_over is None or (len(self.left_over) 0 and not bool(self.epoch)) or self.batch_total self.batch_delivered:self.start_epoch()found_batch Falsebatch_data Noneresolution self.base_reswhile not found_batch:bucket_ids list(self.epoch.keys())if len(self.left_over) self.bsz:bucket_probs [len(self.left_over)] [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_ids [-1] bucket_idselse:bucket_probs [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_probs np.array(bucket_probs, dtypenp.float32)bucket_lens bucket_probsbucket_probs bucket_probs / bucket_probs.sum()bucket_ids np.array(bucket_ids, dtypenp.int64)if bool(self.epoch):chosen_id int(self.prng.choice(bucket_ids, 1, pbucket_probs)[0])else:chosen_id -1if chosen_id -1:# using leftover images that couldnt make it into a bucketed batch and returning them for use with basic square imageself.prng.shuffle(self.left_over)batch_data self.left_over[:self.bsz]self.left_over self.left_over[self.bsz:]found_batch Trueelse:if len(self.epoch[chosen_id]) self.bsz:# return bucket batch and resolutionbatch_data self.epoch[chosen_id][:self.bsz]self.epoch[chosen_id] self.epoch[chosen_id][self.bsz:]resolution tuple(self.resolutions[chosen_id])found_batch Trueif len(self.epoch[chosen_id]) 0:del self.epoch[chosen_id]else:# cant make a batch from this, not enough images. move them to leftovers and try againself.left_over.extend(self.epoch[chosen_id])del self.epoch[chosen_id]assert(found_batch or len(self.left_over) self.bsz or bool(self.epoch))if self.debug:timer time.perf_counter() - timerprint(fbucket probs: , .join(map(lambda x: f{x:.2f}, list(bucket_probs*100))))print(fchosen id: {chosen_id})print(fbatch data: {batch_data})print(fresolution: {resolution})print(fget_batch: {timer:.5f}s)self.batch_delivered 1return (batch_data, resolution)def generator(self):if self.batch_delivered self.batch_total:self.start_epoch()while self.batch_delivered self.batch_total:yield self.get_batch()if __name__ __main__:# prepare a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use thiswith open(resolutions.pkl, rb) as fh:ids list(pickle.load(fh).keys())counts np.zeros((len(ids),)).astype(np.int64)id_map {}for i, post_id in enumerate(ids):id_map[post_id] ibm BucketManager(resolutions.pkl, debugTrue, bsz8, world_size8, global_rank3)print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))print(got: str(bm.get_batch()))bm BucketManager(resolutions.pkl, bsz8, world_size1, global_rank0, valid_idsids[0:16])for _ in range(16):bm.get_batch()print(got from future epoch: str(bm.get_batch()))bms []for rank in range(16):bm BucketManager(resolutions.pkl, bsz8, world_size16, global_rankrank)bms.append(bm)for epoch in range(5):print(fepoch {epoch})for i, bm in enumerate(bms):print(fbm {i})first Truefor ids, res in bm.generator():if first and i 0:#print(ids)first Falsefor post_id in ids:counts[id_map[post_id]] 1print(np.bincount(counts))