Contents
1. Overview of the SG Network
2. SG Network Code Implementation

1. Overview of the SG Network
The headline contribution of the DiAD paper is its SG (Semantic-Guided) network, which tackles the loss of semantic information in multi-class anomaly detection. So how does it manage to reconstruct anomalous regions while still preserving the semantic information of the original image?
1. Connection to the Stable Diffusion denoising network: the SG network is designed to plug into the Stable Diffusion (SD) denoising network. The SD denoiser has strong generative capability on its own, but in a multi-class anomaly detection task it may fail to keep the image's semantics consistent. The SG network adds a semantic-guidance mechanism so that, while anomalous regions are being reconstructed, the semantic context of the original image is referenced and preserved. The figure below shows how the SG network connects to the denoising network in the overall framework. (This is the final figure given in the paper; in my view the circled part is a mistake and should instead read "SG network encoder". A toy illustration of this connection follows the list.)
2. Semantic-consistency preservation: during reconstruction, the SG network processes the noise at several scales and fuses features with the Spatial-aware Feature Fusion (SFF) block, so semantic information survives the reconstruction. Even when an anomalous region is rewritten, the repaired area stays consistent with the semantic context of the original image.
3. Multi-scale feature fusion: the SFF block integrates high-scale semantic information into the lower scales. This preserves the information of the original normal sample while still allowing large anomalous regions to be reconstructed, maximizing reconstruction accuracy in areas that need extensive repair without breaking semantic consistency. As the figure below shows, the feature-fusion module is quite easy to follow (its implementation is the down11 ... down33 branches in the code of Section 2).
4. Combination with a pretrained feature extractor: the SG network is paired with a pretrained feature extractor in feature space. The extractor processes both the input image and the reconstructed image and extracts features at multiple scales; comparing these features yields anomaly maps that highlight likely anomalous regions and provide anomaly scores (confidences). This step further validates that the SG network preserves semantic information. (A hedged sketch of this step also follows the list.)
5. Avoiding category errors: compared with a traditional diffusion model such as DDPM, the SG network resolves the category confusion that can appear in multi-class anomaly detection. LDM does introduce conditioning constraints through cross-attention, but it can still lose semantic information while denoising from random Gaussian noise; the SG network's semantic-guidance mechanism effectively avoids this problem.
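To make item 1 concrete, here is a self-contained toy of the ControlNet-style connection, not DiAD's actual hook-up code: every SG feature passes through a zero-initialized convolution before being added to the matching feature of the frozen SD denoiser, so at initialization the SG branch is an exact no-op and only learns to inject guidance during training. All names here (ZeroConv, sd_feature, sg_feature) are illustrative.

import torch
import torch.nn as nn

class ZeroConv(nn.Module):
    # 1x1 conv whose weights and bias start at zero, so its output starts at zero.
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)

sd_feature = torch.randn(1, 320, 32, 32)   # a skip feature inside the frozen SD UNet
sg_feature = torch.randn(1, 320, 32, 32)   # matching-scale feature from the SG branch
fused = sd_feature + ZeroConv(320)(sg_feature)
assert torch.allclose(fused, sd_feature)   # identity at initialization

The same idea appears in the real code below as make_zero_conv / zero_module.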
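Item 4 is not part of the SG network class in Section 2, so here is a hedged sketch of that step: comparing multi-scale features of the input and its reconstruction with a frozen, ImageNet-pretrained backbone. The choice of ResNet-50, the specific stages, and the per-pixel cosine-distance metric are assumptions for illustration; DiAD's exact extractor and layer set live in its configs.

import torch
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

extractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).eval()

def multi_scale_features(x):
    # Stem, then the first three residual stages at decreasing resolution.
    x = extractor.maxpool(extractor.relu(extractor.bn1(extractor.conv1(x))))
    feats = []
    for stage in (extractor.layer1, extractor.layer2, extractor.layer3):
        x = stage(x)
        feats.append(x)
    return feats

@torch.no_grad()
def anomaly_map(img, recon, out_size=256):
    # Per-pixel cosine distance at each scale, upsampled and accumulated.
    amap = torch.zeros(img.shape[0], 1, out_size, out_size)
    for f_i, f_r in zip(multi_scale_features(img), multi_scale_features(recon)):
        d = 1 - F.cosine_similarity(f_i, f_r, dim=1)        # (B, H, W)
        amap += F.interpolate(d.unsqueeze(1), size=out_size,
                              mode='bilinear', align_corners=False)
    return amap  # higher value = more anomalous; the max over pixels gives an image-level score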
2. SG Network Code Implementation
This part of the code runs to roughly 300 lines.
import torch as th
import torch.nn as nn

# Imports below assume the ControlNet/LDM-style package layout that DiAD builds on.
from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import (
    TimestepEmbedSequential,
    ResBlock,
    Downsample,
    AttentionBlock,
)
from ldm.util import exists


class SemanticGuidedNetwork(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        # Maps the full-resolution hint (the input image) down to the latent
        # resolution: three stride-2 convs give an 8x spatial reduction.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

        # SFF Block: nine zero-initialized, stride-2 conv branches (down{i}{j}).
        # down1x consume the 640-channel feature (outs[6]); down2x/down3x consume
        # the two 1280-channel features (outs[7], outs[8]); together they carry
        # higher-resolution information into the three deepest encoder outputs.
        self.down11 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down12 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down13 = nn.Sequential(
            zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down21 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down22 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down23 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down31 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down32 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.down33 = nn.Sequential(
            zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),
            nn.InstanceNorm2d(1280),
            nn.SiLU(),
        )
        self.silu = nn.SiLU()

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                # The semantic hint is injected once, after the first encoder block.
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        # SFF Block Implementation: fuse the higher-scale encoder features
        # (outs[6..8]) into each of the three deepest outputs (outs[9..11]).
        outs[9] = self.silu(outs[9] + self.down11(outs[6]) + self.down21(outs[7]) + self.down31(outs[8]))
        outs[10] = self.silu(outs[10] + self.down12(outs[6]) + self.down22(outs[7]) + self.down32(outs[8]))
        outs[11] = self.silu(outs[11] + self.down13(outs[6]) + self.down23(outs[7]) + self.down33(outs[8]))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
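A hedged usage sketch: the hyperparameters below mirror an SD-1.5-style UNet at a 32x32 latent (256x256 input image). DiAD's exact values live in its config files, so treat these numbers as assumptions; they are, however, consistent with the SFF indices above (with channel_mult=[1, 2, 4, 4] and num_res_blocks=2, outs[6] has 640 channels and outs[7..11] have 1280).

sg_net = SemanticGuidedNetwork(
    image_size=32, in_channels=4, model_channels=320, hint_channels=3,
    num_res_blocks=2, attention_resolutions=[4, 2, 1],
    channel_mult=[1, 2, 4, 4], num_heads=8,
    use_spatial_transformer=True, transformer_depth=1, context_dim=768,
)
x = th.randn(1, 4, 32, 32)           # noisy latent from the diffusion process
hint = th.randn(1, 3, 256, 256)      # semantic condition: the original input image
t = th.tensor([500])                 # diffusion timestep
context = th.randn(1, 77, 768)       # cross-attention conditioning
outs = sg_net(x, hint, t, context)   # 13 maps: 12 encoder scales + middle block

In the ControlNet-style wiring the SG network follows, each element of outs is then added as a residual onto the corresponding skip connection (and middle feature) of the frozen SD denoising UNet, which is exactly the connection described in item 1 of Section 1.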