import torch


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, \
            "Embedding dimension must be divisible by the number of heads."
        # Each head attends over a slice of the embedding dimension.
        self.head_dim = embed_dim // num_heads
        # Linear projections for queries, keys, and values, plus an output
        # projection applied after the per-head results are concatenated.
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)
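For reference, a minimal construction sketch; the parameter values below are illustrative only, and the forward pass is not part of the constructor shown above:

    # Illustrative values: 512-dimensional embeddings split across 8 heads.
    mha = MultiHeadAttention(embed_dim=512, num_heads=8)
    print(mha.head_dim)  # 512 // 8 = 64, the per-head dimension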