Table of Contents

Preface
1. class CVRPModel(nn.Module): __init__(self, **model_params) - Function purpose / Function code
2. class CVRPModel(nn.Module): pre_forward(self, reset_state) - Function purpose / Function code
3. class CVRPModel(nn.Module): forward(self, state) - Function purpose / Function code
4. def _get_encoding(encoded_nodes, node_index_to_pick) - Function purpose / Function code
5. class CVRP_Encoder(nn.Module)
6. class EncoderLayer(nn.Module)
7. class CVRP_Decoder(nn.Module)
8. def reshape_by_heads(qkv, head_num) - Function purpose / Function code
9. def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None) - Function purpose / Function code
10. class AddAndInstanceNormalization(nn.Module): __init__(self, **model_params) - Function purpose / What is Batch Normalization (BN)? / How Batch Normalization works / Advantages of Batch Normalization / Function code
11. class AddAndInstanceNormalization(nn.Module): forward(self, input1, input2) - Function purpose / Function code
12. class FeedForward(nn.Module): __init__(self, **model_params) - Function purpose / Function code
13. class FeedForward(nn.Module): forward(self, input1) - Function purpose / Function code
Appendix: Full code

Preface
Code studied:
- class CVRPModel(nn.Module)
- class CVRP_Encoder(nn.Module)
- class EncoderLayer(nn.Module)
- class CVRP_Decoder(nn.Module)
- class AddAndInstanceNormalization(nn.Module)
- class AddAndBatchNormalization(nn.Module)
- class FeedForward(nn.Module)

File: /home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py

1. class CVRPModel(nn.Module): __init__(self, **model_params)
Function purpose
__init__ is the constructor of the CVRPModel class and is responsible for initializing the model's components. Its main tasks are:
- Receive and store the model parameters model_params.
- Initialize the encoder and decoder sub-modules.
- Initialize the encoded_nodes attribute, which will later hold the encoded node representations.
Execution flowchart: link
Function code

def __init__(self, **model_params):
    super().__init__()
    self.model_params = model_params

    self.encoder = CVRP_Encoder(**model_params)
    self.decoder = CVRP_Decoder(**model_params)
    self.encoded_nodes = None
    # shape: (batch, problem+1, EMBEDDING_DIM)
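The constructor only stores model_params and builds the two sub-modules; the parameter dictionary itself comes from the caller. As a minimal sketch, instantiation could look like the following, where the keys are exactly the ones read elsewhere in CVRPModel.py but the concrete values are illustrative assumptions, not the repository's official configuration:

# hypothetical hyper-parameter dictionary; keys match those used in CVRPModel.py,
# the values below are assumed for illustration only
model_params = {
    'embedding_dim': 128,
    'sqrt_embedding_dim': 128 ** 0.5,
    'encoder_layer_num': 6,
    'head_num': 8,
    'qkv_dim': 16,
    'ff_hidden_dim': 512,
    'logit_clipping': 10,
    'eval_type': 'argmax',
}

model = CVRPModel(**model_params)  # builds CVRP_Encoder and CVRP_Decoder; encoded_nodes stays None until pre_forward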
2. class CVRPModel(nn.Module): pre_forward(self, reset_state)

Function purpose
pre_forward is a preparation step run before the forward pass of CVRPModel. Given the initial state (reset_state), it prepares and encodes the data so that the subsequent forward passes can run. Concretely, it:
- Extracts and processes the data contained in the initial state.
- Encodes the nodes with the encoder to obtain the encoded node representations.
- Builds the decoder's extra (regret) embedding and concatenates it with the encoded nodes.
- Sets the decoder's key/value (kv) tensors to prepare for decoding.
Execution flowchart: link
Function code

def pre_forward(self, reset_state):
    depot_xy = reset_state.depot_xy
    # shape: (batch, 1, 2)
    node_xy = reset_state.node_xy
    # shape: (batch, problem, 2)
    node_demand = reset_state.node_demand
    # shape: (batch, problem)

    node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)
    # shape: (batch, problem, 3)

    encoded_nodes = self.encoder(depot_xy, node_xy_demand)
    # shape: (batch, problem+1, embedding)
    _ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1,
                                                            self.decoder.regret_embedding.size(-1))
    # shape of _: (batch, 1, embedding)
    self.encoded_nodes = torch.cat((encoded_nodes, _), dim=1)
    # shape of self.encoded_nodes: (batch, problem+2, embedding)
    self.decoder.set_kv(self.encoded_nodes)
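The only non-obvious step is how the 1-D learnable regret embedding is broadcast across the batch and appended as one extra "node". A small shape-only sketch of that expand-and-concatenate pattern, using made-up sizes and plain random tensors in place of the encoder output and decoder parameter:

import torch

batch, problem, embedding = 4, 20, 128
encoded_nodes = torch.randn(batch, problem + 1, embedding)       # stands in for the encoder output (depot + nodes)
regret_embedding = torch.randn(embedding)                        # stands in for decoder.regret_embedding

extra = regret_embedding[None, None, :].expand(batch, 1, embedding)   # shape: (batch, 1, embedding)
all_nodes = torch.cat((encoded_nodes, extra), dim=1)                  # shape: (batch, problem+2, embedding)
print(all_nodes.shape)   # torch.Size([4, 22, 128])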
3. class CVRPModel(nn.Module): forward(self, state)

Function purpose
forward is the core forward-pass function of CVRPModel. Given the current state (state), it produces the model output: the selected nodes (selected) and the corresponding probabilities (prob). Based on the current state and the selection history, it decides which node to visit next and outputs the probability of that choice.
Execution flowchart: link
Function code

def forward(self, state):
    batch_size = state.BATCH_IDX.size(0)
    pomo_size = state.BATCH_IDX.size(1)

    if state.selected_count == 0:  # First Move, depot
        selected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)
        prob = torch.ones(size=(batch_size, pomo_size))

        # # Use Averaged encoded nodes for decoder input_1
        # encoded_nodes_mean = self.encoded_nodes.mean(dim=1, keepdim=True)
        # # shape: (batch, 1, embedding)
        # self.decoder.set_q1(encoded_nodes_mean)

        # Use encoded_depot for decoder input_2
        encoded_first_node = self.encoded_nodes[:, [0], :]
        # shape: (batch, 1, embedding)
        self.decoder.set_q2(encoded_first_node)

    elif state.selected_count == 1:  # Second Move, POMO
        selected = torch.arange(start=1, end=pomo_size + 1)[None, :].expand(batch_size, pomo_size)
        prob = torch.ones(size=(batch_size, pomo_size))

    else:
        encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
        # shape: (batch, pomo, embedding)
        probs = self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask)
        # shape: (batch, pomo, problem+1)

        if self.training or self.model_params['eval_type'] == 'softmax':
            while True:  # to fix pytorch.multinomial bug on selecting 0 probability elements
                with torch.no_grad():
                    selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                        .squeeze(dim=1).reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
                prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
                if (prob != 0).all():
                    break
        else:
            probs = probs[:, :, :-1]
            selected = probs.argmax(dim=2)
            # shape: (batch, pomo)
            prob = None  # value not needed. Can be anything.

    return selected, prob
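The while-loop in the sampling branch exists because torch.multinomial could, due to an old PyTorch bug, occasionally return an index whose probability is exactly zero; the code simply resamples until every rollout has a non-zero probability. A stripped-down illustration of that retry pattern, using a toy probability tensor rather than the real decoder output:

import torch

batch_size, pomo_size, problem = 2, 3, 5
probs = torch.rand(batch_size, pomo_size, problem)
probs[..., -1] = 0.0                                       # make one action impossible
probs = probs / probs.sum(dim=-1, keepdim=True)

BATCH_IDX = torch.arange(batch_size)[:, None].expand(batch_size, pomo_size)
POMO_IDX = torch.arange(pomo_size)[None, :].expand(batch_size, pomo_size)

while True:
    selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
        .squeeze(dim=1).reshape(batch_size, pomo_size)
    prob = probs[BATCH_IDX, POMO_IDX, selected]
    if (prob != 0).all():                                  # resample if a zero-probability index slipped through
        break

print(selected)
print(prob)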
4. def _get_encoding(encoded_nodes, node_index_to_pick)

Function purpose
_get_encoding selects, for every POMO rollout, the embedding indexed by node_index_to_pick from encoded_nodes and returns the picked encodings.
Execution flowchart: link
Function code
def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes
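A quick check of what the gather does here, with dummy data (the sizes are arbitrary): each (batch, pomo) entry of node_index_to_pick selects one embedding vector from encoded_nodes.

import torch

batch, problem, embedding, pomo = 2, 6, 4, 3
encoded_nodes = torch.randn(batch, problem, embedding)
node_index_to_pick = torch.randint(0, problem, (batch, pomo))

picked = _get_encoding(encoded_nodes, node_index_to_pick)
print(picked.shape)   # torch.Size([2, 3, 4])
# the picked row equals the row of encoded_nodes at the chosen index
print(torch.equal(picked[0, 1], encoded_nodes[0, node_index_to_pick[0, 1]]))   # True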
5. class CVRP_Encoder(nn.Module)

Notes: 20250226 - Code Notes 04 - class CVRP_Encoder AND class EncoderLayer
6. class EncoderLayer(nn.Module)

Notes: 20250226 - Code Notes 04 - class CVRP_Encoder AND class EncoderLayer
7. class CVRP_Decoder(nn.Module)

Notes: 20250226 - Code Notes 05 - class CVRP_Decoder
8. def reshape_by_heads(qkv, head_num)

Function purpose
reshape_by_heads converts an input tensor (a query q, key k, or value v) from the packed multi-head layout (batch, n, head_num * key_dim) into the layout required by multi-head attention, (batch, head_num, n, key_dim). It splits out the head dimension and rearranges the tensor so that each head can be computed independently. Execution flowchart: link
Function code
def reshape_by_heads(qkv, head_num):
    # q.shape: (batch, n, head_num*key_dim)   : n can be either 1 or PROBLEM_SIZE

    batch_s = qkv.size(0)
    n = qkv.size(1)

    q_reshaped = qkv.reshape(batch_s, n, head_num, -1)
    # shape: (batch, n, head_num, key_dim)

    q_transposed = q_reshaped.transpose(1, 2)
    # shape: (batch, head_num, n, key_dim)

    return q_transposed
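For example, with 8 heads and a 16-dimensional key per head, a (batch, n, 128) projection output is split as follows (sizes are illustrative):

import torch

batch, n, head_num, key_dim = 4, 21, 8, 16
qkv = torch.randn(batch, n, head_num * key_dim)   # packed projection output, e.g. Wq(input)

q = reshape_by_heads(qkv, head_num=head_num)
print(q.shape)                                     # torch.Size([4, 8, 21, 16])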
9. def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None)

Function purpose
multi_head_attention implements the multi-head attention mechanism. It takes the queries (Q), keys (K), and values (V), computes the similarity between queries and keys, uses the resulting weights to combine the values, and concatenates the outputs of all heads into the final attention representation. Execution flowchart: link
Function code
def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):
    # q shape: (batch, head_num, n, key_dim)   : n can be either 1 or PROBLEM_SIZE
    # k,v shape: (batch, head_num, problem, key_dim)
    # rank2_ninf_mask.shape: (batch, problem)
    # rank3_ninf_mask.shape: (batch, group, problem)

    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    key_dim = q.size(3)

    input_s = k.size(2)

    score = torch.matmul(q, k.transpose(2, 3))
    # shape: (batch, head_num, n, problem)

    score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))
    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    weights = nn.Softmax(dim=3)(score_scaled)
    # shape: (batch, head_num, n, problem)

    out = torch.matmul(weights, v)
    # shape: (batch, head_num, n, key_dim)

    out_transposed = out.transpose(1, 2)
    # shape: (batch, n, head_num, key_dim)

    out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)
    # shape: (batch, n, head_num*key_dim)

    return out_concat
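A small end-to-end sketch combining random Q/K/V tensors with an optional rank-2 mask; entries set to -inf are removed from the softmax, i.e. the corresponding nodes receive zero attention weight (all sizes are arbitrary):

import torch

batch, head_num, key_dim, n, problem = 4, 8, 16, 21, 21
q = torch.randn(batch, head_num, n, key_dim)
k = torch.randn(batch, head_num, problem, key_dim)
v = torch.randn(batch, head_num, problem, key_dim)

mask = torch.zeros(batch, problem)
mask[:, 0] = float('-inf')                     # e.g. forbid attending to the depot

out = multi_head_attention(q, k, v, rank2_ninf_mask=mask)
print(out.shape)                               # torch.Size([4, 21, 128]) == (batch, n, head_num*key_dim)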
10. class AddAndInstanceNormalization(nn.Module): __init__(self, **model_params)

Function purpose
Normalizes the input over the embedding dimension so that the model converges more easily and trains more stably.
What is Batch Normalization (BN)?
Batch Normalization (BN) is a technique commonly used when training deep neural networks. Its goal is to speed up training, improve stability, and help avoid vanishing or exploding gradients. The core idea is to standardize the inputs of each layer so that their mean is close to 0 and their variance close to 1. This prevents the activations from becoming too large or too small and makes the optimization process more stable.
How Batch Normalization works
1. Compute the mean and variance
For a batch of input samples, the mean and variance are computed along each feature dimension:

Mean: $\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$

Variance: $\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$
where $m$ is the number of samples in the batch and $x_i$ is the input value of sample $i$.
2. Standardize
Using the computed mean and variance, the inputs are standardized so that each feature has mean 0 and variance 1:

$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$
Here $\epsilon$ is a very small constant that prevents division by zero.
3. Scale and shift
Because standardization alone could limit the expressive power of the model, Batch Normalization also introduces two learnable parameters, $\gamma$ (scale) and $\beta$ (shift), which allow the model to re-adjust the standardized values:

$y_i = \gamma \hat{x}_i + \beta$
where $\gamma$ and $\beta$ are learned parameters, typically optimized by back-propagation.
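The three steps above map directly onto a few lines of tensor code. A minimal sketch using plain torch operations and comparing against nn.BatchNorm1d in training mode (sizes are arbitrary; gamma and beta are left at their initial values):

import torch

x = torch.randn(32, 128)                    # a batch of m = 32 samples with 128 features
eps = 1e-5

mu = x.mean(dim=0)                          # mu_B, per-feature mean
var = x.var(dim=0, unbiased=False)          # sigma_B^2, per-feature (biased) variance
x_hat = (x - mu) / torch.sqrt(var + eps)    # standardize
gamma = torch.ones(128)                     # learnable scale, initialized to 1
beta = torch.zeros(128)                     # learnable shift, initialized to 0
y = gamma * x_hat + beta

# same result as nn.BatchNorm1d in training mode (gamma=1, beta=0 at initialization)
bn = torch.nn.BatchNorm1d(128, eps=eps)
print(torch.allclose(y, bn(x), atol=1e-4))  # True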
Advantages of Batch Normalization
- Faster training: by reducing internal covariate shift, BN keeps each layer's input distribution more stable, which speeds up training.
- More stability: standardizing the inputs helps avoid exploding or vanishing gradients, so training is more stable.
- Less overfitting: in some cases BN also acts as a regularizer and reduces overfitting to the training data.
- Less sensitivity to initialization: BN reduces, to some extent, the sensitivity to weight initialization.
Function code

def __init__(self, **model_params):
    super().__init__()
    embedding_dim = model_params['embedding_dim']
    self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

Note that, despite the Batch Normalization background above, this class actually uses nn.InstanceNorm1d over the embedding dimension; the same file also contains an AddAndBatchNormalization class (see the appendix) that uses nn.BatchNorm1d in the way described.

11. class AddAndInstanceNormalization(nn.Module): forward(self, input1, input2)
Function purpose
The forward method performs the add-and-normalize operation. Its main steps are:
- Addition: add the two input tensors input1 and input2.
- Normalization: normalize the summed result over its feature (embedding) dimension.
- Shape restoration: after normalization, restore the tensor to its original shape.
Execution flow:
Get the input tensor's dimensions

batch_s = input1.size(0)
problem_s = input1.size(1)
embedding_dim = input1.size(2)

batch_s is the batch size, problem_s is the problem size (number of nodes), and embedding_dim is the embedding dimension. These come from input1; input2 is assumed to have the same shape.

Addition

added = input1 + input2

input1 and input2 are added element-wise. The resulting added tensor has the same shape as input1 and input2, namely (batch_s, problem_s, embedding_dim).

Batch normalization

normalized = self.norm_by_EMB(added.reshape(batch_s * problem_s, embedding_dim))

The added tensor is reshaped to (batch_s * problem_s, embedding_dim), merging the batch and problem dimensions so that batch normalization can be applied per feature dimension (embedding_dim). self.norm_by_EMB is a BatchNorm1d layer that standardizes each feature dimension to roughly zero mean and unit variance.

Restore the shape

back_trans = normalized.reshape(batch_s, problem_s, embedding_dim)

After normalization, the normalized tensor is reshaped back to (batch_s, problem_s, embedding_dim), i.e. the original input shape.

Return the result

return back_trans
Returns the normalized tensor back_trans, which has the same shape as the inputs and whose feature dimensions have been standardized. (The step-by-step walkthrough above actually corresponds to AddAndBatchNormalization.forward in this file; the AddAndInstanceNormalization.forward shown below instead transposes to (batch, embedding, problem) and applies nn.InstanceNorm1d to achieve the add-and-normalize.)

Function code

def forward(self, input1, input2):
    # input.shape: (batch, problem, embedding)

    added = input1 + input2
    # shape: (batch, problem, embedding)

    transposed = added.transpose(1, 2)
    # shape: (batch, embedding, problem)

    normalized = self.norm(transposed)
    # shape: (batch, embedding, problem)

    back_trans = normalized.transpose(1, 2)
    # shape: (batch, problem, embedding)

    return back_trans
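A small usage sketch, assuming the classes defined in CVRPModel.py are in scope (sizes are arbitrary):

import torch

# __init__ only reads model_params['embedding_dim'], so a single keyword argument suffices
layer = AddAndInstanceNormalization(embedding_dim=128)
x1 = torch.randn(4, 20, 128)   # (batch, problem, embedding)
x2 = torch.randn(4, 20, 128)
out = layer(x1, x2)
print(out.shape)               # torch.Size([4, 20, 128]), same shape as the inputs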
12. class FeedForward(nn.Module): __init__(self, **model_params)

Function purpose
FeedForward is a standard feed-forward (fully connected) neural network module implementing a simple two-layer network.
The __init__ method is the constructor; it initializes the layers and hyper-parameters. embedding_dim and ff_hidden_dim are hyper-parameters passed in through model_params and give the embedding dimension and the hidden dimension of the feed-forward network, respectively: embedding_dim is the input and output dimension, ff_hidden_dim is the dimension of the middle (hidden) layer. self.W1 and self.W2 are two fully connected layers (nn.Linear): self.W1 maps an embedding_dim-dimensional vector to ff_hidden_dim dimensions, and self.W2 maps it back from ff_hidden_dim to embedding_dim.
Function code

def __init__(self, **model_params):
    super().__init__()
    embedding_dim = model_params['embedding_dim']
    ff_hidden_dim = model_params['ff_hidden_dim']

    self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)
    self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)

13. class FeedForward(nn.Module): forward(self, input1)
Function purpose
The forward method defines how data flows through the network, i.e. the forward pass. The input input1 has shape (batch, problem, embedding): the batch size, the number of problem nodes, and the embedding dimension of each node. The steps are:

1. First linear layer (self.W1): the input passes through self.W1, which maps the embedding dimension to the hidden dimension ff_hidden_dim: $h = W_1 x + b_1$, where $x$ is the input, $W_1$ the weight matrix, and $b_1$ the bias.
2. ReLU activation: ReLU is applied to the output of self.W1, zeroing out negative values and keeping positive ones: $h' = \max(0, h)$.
3. Second linear layer (self.W2): self.W2 maps the hidden representation back to the original embedding dimension embedding_dim: $y = W_2 h' + b_2$.

The final output, after the two linear layers and the ReLU activation, still has shape (batch, problem, embedding).
Function code

def forward(self, input1):
    # input.shape: (batch, problem, embedding)
    return self.W2(F.relu(self.W1(input1)))
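A quick sanity check of the forward pass with random data; the hyper-parameter values are illustrative, and nn.Linear is applied over the last dimension, so the (batch, problem) dimensions pass through unchanged:

import torch

ff = FeedForward(embedding_dim=128, ff_hidden_dim=512)
x = torch.randn(4, 20, 128)   # (batch, problem, embedding)
y = ff(x)
print(y.shape)                # torch.Size([4, 20, 128]), shape is preserved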
Appendix

Full code

import torch
import torch.nn as nn
import torch.nn.functional as F


class CVRPModel(nn.Module):

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params

        self.encoder = CVRP_Encoder(**model_params)
        self.decoder = CVRP_Decoder(**model_params)
        self.encoded_nodes = None
        # shape: (batch, problem+1, EMBEDDING_DIM)

    def pre_forward(self, reset_state):
        depot_xy = reset_state.depot_xy
        # shape: (batch, 1, 2)
        node_xy = reset_state.node_xy
        # shape: (batch, problem, 2)
        node_demand = reset_state.node_demand
        # shape: (batch, problem)

        node_xy_demand = torch.cat((node_xy, node_demand[:, :, None]), dim=2)
        # shape: (batch, problem, 3)

        encoded_nodes = self.encoder(depot_xy, node_xy_demand)
        # shape: (batch, problem+1, embedding)
        _ = self.decoder.regret_embedding[None, None, :].expand(encoded_nodes.size(0), 1,
                                                                self.decoder.regret_embedding.size(-1))
        # shape of _: (batch, 1, embedding)
        self.encoded_nodes = torch.cat((encoded_nodes, _), dim=1)
        # shape of self.encoded_nodes: (batch, problem+2, embedding)
        self.decoder.set_kv(self.encoded_nodes)

    def forward(self, state):
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.selected_count == 0:  # First Move, depot
            selected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)
            prob = torch.ones(size=(batch_size, pomo_size))

            # # Use Averaged encoded nodes for decoder input_1
            # encoded_nodes_mean = self.encoded_nodes.mean(dim=1, keepdim=True)
            # # shape: (batch, 1, embedding)
            # self.decoder.set_q1(encoded_nodes_mean)

            # Use encoded_depot for decoder input_2
            encoded_first_node = self.encoded_nodes[:, [0], :]
            # shape: (batch, 1, embedding)
            self.decoder.set_q2(encoded_first_node)

        elif state.selected_count == 1:  # Second Move, POMO
            selected = torch.arange(start=1, end=pomo_size + 1)[None, :].expand(batch_size, pomo_size)
            prob = torch.ones(size=(batch_size, pomo_size))

        else:
            encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
            # shape: (batch, pomo, embedding)
            probs = self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask)
            # shape: (batch, pomo, problem+1)

            if self.training or self.model_params['eval_type'] == 'softmax':
                while True:  # to fix pytorch.multinomial bug on selecting 0 probability elements
                    with torch.no_grad():
                        selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                            .squeeze(dim=1).reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)
                    prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)
                    if (prob != 0).all():
                        break
            else:
                probs = probs[:, :, :-1]
                selected = probs.argmax(dim=2)
                # shape: (batch, pomo)
                prob = None  # value not needed. Can be anything.

        return selected, prob


def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


########################################
# ENCODER
########################################

class CVRP_Encoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        encoder_layer_num = self.model_params['encoder_layer_num']

        self.embedding_depot = nn.Linear(2, embedding_dim)
        self.embedding_node = nn.Linear(3, embedding_dim)
        self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

    def forward(self, depot_xy, node_xy_demand):
        # depot_xy.shape: (batch, 1, 2)
        # node_xy_demand.shape: (batch, problem, 3)

        embedded_depot = self.embedding_depot(depot_xy)
        # shape: (batch, 1, embedding)
        embedded_node = self.embedding_node(node_xy_demand)
        # shape: (batch, problem, embedding)

        out = torch.cat((embedded_depot, embedded_node), dim=1)
        # shape: (batch, problem+1, embedding)

        for layer in self.layers:
            out = layer(out)

        return out
        # shape: (batch, problem+1, embedding)


class EncoderLayer(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, input1):
        # input1.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
        k = reshape_by_heads(self.Wk(input1), head_num=head_num)
        v = reshape_by_heads(self.Wv(input1), head_num=head_num)
        # qkv shape: (batch, head_num, problem, qkv_dim)

        out_concat = multi_head_attention(q, k, v)
        # shape: (batch, problem, head_num*qkv_dim)

        multi_head_out = self.multi_head_combine(out_concat)
        # shape: (batch, problem, embedding)

        out1 = self.add_n_normalization_1(input1, multi_head_out)
        out2 = self.feed_forward(out1)
        out3 = self.add_n_normalization_2(out1, out2)

        return out3
        # shape: (batch, problem, embedding)


########################################
# DECODER
########################################

class CVRP_Decoder(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        # self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wq_last = nn.Linear(embedding_dim + 1, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))
        self.regret_embedding.data.uniform_(-1, 1)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head_attention
        self.single_head_key = None  # saved, for single-head attention
        # self.q1 = None  # saved q1, for multi-head attention
        self.q2 = None  # saved q2, for multi-head attention

    def set_kv(self, encoded_nodes):
        # encoded_nodes.shape: (batch, problem+1, embedding)
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)

    def set_q1(self, encoded_q1):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def set_q2(self, encoded_q2):
        # encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomo
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def forward(self, encoded_last_node, load, ninf_mask):
        # encoded_last_node.shape: (batch, pomo, embedding)
        # load.shape: (batch, pomo)
        # ninf_mask.shape: (batch, pomo, problem)

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        # shape: (batch, group, EMBEDDING_DIM+1)

        q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        # q = self.q1 + self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)
        # q = q_last
        # shape: (batch, head_num, pomo, qkv_dim)
        q = self.q2 + q_last
        # # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem)

        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked, dim=2)
        # shape: (batch, pomo, problem)

        return probs


########################################
# NN SUB CLASS / FUNCTIONS
########################################

def reshape_by_heads(qkv, head_num):
    # q.shape: (batch, n, head_num*key_dim)   : n can be either 1 or PROBLEM_SIZE

    batch_s = qkv.size(0)
    n = qkv.size(1)

    q_reshaped = qkv.reshape(batch_s, n, head_num, -1)
    # shape: (batch, n, head_num, key_dim)

    q_transposed = q_reshaped.transpose(1, 2)
    # shape: (batch, head_num, n, key_dim)

    return q_transposed


def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):
    # q shape: (batch, head_num, n, key_dim)   : n can be either 1 or PROBLEM_SIZE
    # k,v shape: (batch, head_num, problem, key_dim)
    # rank2_ninf_mask.shape: (batch, problem)
    # rank3_ninf_mask.shape: (batch, group, problem)

    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    key_dim = q.size(3)

    input_s = k.size(2)

    score = torch.matmul(q, k.transpose(2, 3))
    # shape: (batch, head_num, n, problem)

    score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))
    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    weights = nn.Softmax(dim=3)(score_scaled)
    # shape: (batch, head_num, n, problem)

    out = torch.matmul(weights, v)
    # shape: (batch, head_num, n, key_dim)

    out_transposed = out.transpose(1, 2)
    # shape: (batch, n, head_num, key_dim)

    out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)
    # shape: (batch, n, head_num*key_dim)

    return out_concat


class AddAndInstanceNormalization(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        embedding_dim = model_params['embedding_dim']
        self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        # input.shape: (batch, problem, embedding)

        added = input1 + input2
        # shape: (batch, problem, embedding)

        transposed = added.transpose(1, 2)
        # shape: (batch, embedding, problem)

        normalized = self.norm(transposed)
        # shape: (batch, embedding, problem)

        back_trans = normalized.transpose(1, 2)
        # shape: (batch, problem, embedding)

        return back_trans


class AddAndBatchNormalization(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        embedding_dim = model_params['embedding_dim']
        self.norm_by_EMB = nn.BatchNorm1d(embedding_dim, affine=True)
        # 'Funny' Batch_Norm, as it will normalized by EMB dim

    def forward(self, input1, input2):
        # input.shape: (batch, problem, embedding)

        batch_s = input1.size(0)
        problem_s = input1.size(1)
        embedding_dim = input1.size(2)

        added = input1 + input2
        normalized = self.norm_by_EMB(added.reshape(batch_s * problem_s, embedding_dim))
        back_trans = normalized.reshape(batch_s, problem_s, embedding_dim)

        return back_trans


class FeedForward(nn.Module):
    def __init__(self, **model_params):
        super().__init__()
        embedding_dim = model_params['embedding_dim']
        ff_hidden_dim = model_params['ff_hidden_dim']

        self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)
        self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)

    def forward(self, input1):
        # input.shape: (batch, problem, embedding)
        return self.W2(F.relu(self.W1(input1)))