pytorch multi head attention(pytorch torch.load)

作者:电脑培训网 2024-04-15 12:35:54 997

Pytorch文档解读|torch.nn.MultiheadAttention的用法及参数分析官方文档链接:MultiheadAttention—PyTorch1.12文档

目录

pytorch multi head attention(pytorch torch.load)

更注重头的原理

多关注pytorch

官方参数解释解读:

多注意pytorchheader的使用

完整使用代码

多注意头原理

MultiheadAttention,翻译成中文为多注意力头,由多个单注意力头组成。

他们看着像是:

单头注意力图如下:

单注意力头

整体称为单注意力头,因为运算后,每个输入仅产生一个输出结果。一般在网络中,输出可以称为网络提取的特征,那么我们肯定要提取多个特征,[比如我的输入是狗的向量序列,我肯定希望网络能够提取特征例如形状、颜色、纹理等,所以单一的注意力肯定是不够的]

所以最简单的想法和最优雅的方式就是将多个头水平拼接在一起。我在每个操作中同时提到多个功能,所以长头看起来像这样:

多注意力头紫色矩形块就是上图中的单注意力头,内部结构未展示。如果将h个单注意力头拼接在一起,则放置位置将如图所示。

因为是拼接的,所以每个单独的attentionhead实际上输出的是自己的输出,所以会得到h个features。当h个特征拼接在一起时,就成为multi-attention的输出特征。

pytorch的多注意头

首先我们可以看到调用的时候我们只需要写torch.nn.MultiheadAttention即可,例如

importtorchimporttorch.nnasn#首先确定参数dims=256*10#所有头需要的输入维度总和head=10#单注意力头总数dropout_pro=0.0#单注意力头#传入参数获取我们所需的多注意力头层=torch.nn.MultiheadAttention(embed_dim=dims,num_heads=Heads,dropout=dropout_pro)

解读官方给的参数解释:

embed_dim-模型总维度模型总维度

那么这里应该输入的是每个头输入的维度头的数量

num_heads-并行注意力头的数量。请注意,embed_dim将分为num_heads。

num_heads是注意力头总数

注意括号里的句子。除num_heads外,每个头的尺寸均为embed_dim。

也就是说,如果我的词向量的维度是n,,并且我打算用m个头来提取序列的特征,那么这里的embed_dim的值应该是为nm,num_heads的值为m。

[更新]这里实际上有点扭曲。虽然官方文档说每个头的维度需要除以头的数量,但是当你自己写网络定义时,如果输入多注意力头并且特征为256,你仍然可以这里定义的时候写成256!如果使用4个头,源代码中每个头的特征确实会变成64维,最后重新拼接成64乘以4=256并输出,但是我们不需要担心这个内部过程我们自己。

还有其他参数可以手动设置:

dropout——attn_output_weights上的dropout概率。默认:0.0。

偏差-如果指定,则向输入/输出投影层添加偏差。默认:True。

add_bias_kv如果指定,则在dim=0时向键和值序列添加偏差。默认:False。

add_zero_attn如果指定,则在dim=1处向键和值序列添加一批新的零。默认:False。

kdim按键功能总数。默认:无。

vdim——值的特征总数。默认:无。

batch_first如果为True,则输入和输出张量提供为。默认:False。

多注意头的pytorch使用

如果你看一下定义,你应该发现torch.nn.MultiheadAttention是一个类

我们刚刚输入了多注意力头的参数,并且只是“实例化”了我们想要指定的多注意力头。

所以如果我们想在训练时使用它,就需要给它喂数据,即调用forward函数来完成前向传播动作。

转发函数定义如下:

转发

以下是传入参数的解释

前三个参数是注意力的三个基本向量元素Q、K、V。

query非批处理输入的形状查询嵌入,当batch_first=False或当batch_first=True时,其中是目标序列长度,是批量大小,是查询嵌入维度embed_dim。查询与键值对进行比较以产生输出。有关更多详细信息,请参阅“您所需要的就是注意力”。

翻译一下,如果输入不是batch的形式,query的shape就是目标序列的长度,也就是queryembedding的维度,即输入词向量后q的维度转化为q。这个注释说的是embed_dim,说明输入的词向量与q维度一致;

如果以batch形式输入且batch_first=False,则查询的形状为。如果batch_first=True,则形状为。[batch_first可以在实例化时设置,默认为False]

key当batch_first=False或batch_first=True时,非批处理输入的形状的关键嵌入,其中S是源序列长度,是批量大小,是关键嵌入维度kdim。有关更多详细信息,请参阅“您所需要的就是注意力”。

key为K。同样,查询是batch的形式,且batch_first=False,则key的形状为。是keyembedding的维度,默认同,是原始序列的长度

value非批处理输入的形状的值嵌入,当batch_first=False或当batch_first=True时,其中是源序列长度,是批量大小,是值嵌入维度vdim。有关更多详细信息,请参阅“您所需要的就是注意力”。

value为V,与key相同

其他参数我们先不赘述。

key_padding_mask如果指定,则为形状(N,S)(N,S)的掩码,指示为了引起注意而忽略键内的哪些元素。对于非批量查询,形状应该是(S)(S)。支持二进制和字节掩码。对于二进制掩码,aTruevalue表示为了引起注意,相应的键值将被忽略。对于字节掩码,非零值表示相应的键值将被忽略。

need_weights如果指定,除了attn_outputs之外,还返回attn_output_weights。默认:True。

attn_mask如果指定,则为2D或3D掩码,以防止注意某些位置。形状必须为(L,S)(L,S)或(N\cdot\text{num\_heads},L,S)(Nnum_heads,L,S),其中N是批量大小,LL是目标序列length,S为源序列长度。2D掩码将在批次中广播,而3D掩码允许批次中的每个条目使用不同的掩码。支持二进制、字节和浮点掩码。对于二元掩码,aTrue值表示对应的位置不允许出现。对于字节掩码,非零值表示相应的位置不允许出现。对于浮动掩码,掩码值将添加到注意力权重中。

average_attn_weights如果为true,则表示返回的attn_weights应在各个头之间进行平均。否则,attn_weights是按头单独提供的。请注意,此标志仅在need_weights=True时有效。默认:True

图层输出格式:

attn_output-当输入未批处理时,当batch_first=False或当batch_first=True时,shape的注意力输出,其中是目标序列长度,是批量大小,是嵌入维度embed_dim。

输入batch,batch_first=False。Attention输出的shape为,即目标序列的长度、batch的大小、embed_dim

attn_output_weights-仅当need_weights=True时返回。如果average_attn_weights=True,则当输入未批处理时,返回形状头的平均注意力权重,或者其中N是批处理大小,是目标序列长度,S是源序列长度。如果average_weights=False,则当输入未批处理时,返回每个形状头的注意力权重。

仅当need_weights的值为True时才返回该参数。

完整的使用代码

multihead_attn=nn.MultiheadAttentionattn_output,attn_output_weights=multihead_attn

相关推荐

  • crypthelper.exe(crypto进程)

    crypthelper.exe(crypto进程)

    很多网友问crypserv.exe是什么进程?cryptoserv.exe是病毒吗?下面小编就来介绍一下crypserv.exe的相关内容。不懂的请过来了解一下…

    crypthelper.exe(crypto进程) 2024-05-03 07:40:42
  • linuxweb服务器搭建教程(linux架设web服务器)

    linuxweb服务器搭建教程(linux架设web服务器)

    Linux下Web服务器搭建Web服务器:专门处理HTTP请求的服务器,常被称为Web服务器。这个我有时间仔细研究一下。另外,你可以参考mac/linux上安装…

    linuxweb服务器搭建教程(linux架设web服务器) 2024-05-03 01:43:52
  • rgb融合技术(rgb to hex)

    rgb融合技术(rgb to hex)

    RGB和Depth融合方法总结1:在MMFNet中,作者提出了几种传统的融合方法。(a)首先将RGB和Depth连接起来,然后进行卷积,最终生成特征图。(b)分…

    rgb融合技术(rgb to hex) 2024-05-01 07:00:45
  • kernelcache什么意思(kcleaner是什么文件夹)

    kernelcache什么意思(kcleaner是什么文件夹)

    每次玩游戏都会出现kprcycleaner.exe,占用大量内存,导致游戏无法加载。怎样设置才能让它不自动启动呢?在搜索中找不到它。当你打开任务管理并想将其关闭…

    kernelcache什么意思(kcleaner是什么文件夹) 2024-05-01 05:55:25
  • nginx跨域解决方案 8082(nginx跨域配置详解)

    nginx跨域解决方案 8082(nginx跨域配置详解)

    Nginx跨域解决方案前提条件:前端网站地址:服务器网址:网站8080访问服务器接口时会出现跨域问题。跨域设计主要包括4个响应头:Access-Control-…

    nginx跨域解决方案 8082(nginx跨域配置详解) 2024-04-27 16:00:34