Pytorch中LSTM相关问题

个人认为在工程方面,主要的工作量在于两点,1.预处理数据,2.计算各个中间阶段输出tensor的shape。只要明白每个阶段输出的shape,就可以知道具体值的意思,内部处理的方式,让代码组织结构清晰起来。所以弄清楚每个操作的输出的shape很重要。

这里记录一下pytorch中LSTM的输出shape,代码如:

1
2
3
4
5
6
7
8
9
import torch
from torch import nn

input = torch.randn(16, 80, 100)
lstm = nn.LSTM(100, 256, 1, batch_first=True, bidirectional=False)
out, (h_t, c_t) = lstm(input)
print(out.shape)
print(c_t.shape)
print(h_t.shape)

输出是:

1
2
3
torch.Size([16, 80, 256])
torch.Size([1, 16, 256])
torch.Size([1, 16, 256])

把几个值的相互关系放在这里便于理解:

单层单向的情况下:( num_layer=0, bidirection=False )

out的输出为 batch x seq_len x hidden 保存了一句话从头到尾每一个token的hidden

h的输出为 1 x batch x hidden 保存了网络输出通道的每句话的最后一个hidden

out[:, -1, :] 和h的值相等

多层单向的情况下:(num_layer=n, bidirection=False)

out的输出维度为 batch x seq_len x hidden

h的输出维度为 n x batch x hidden

此时out只保留最上层网络的每一个h

h的值为每一层网络的值

单层双向:( num_layer=0, bidirection=True )

out的维度为 batch x seq_len x hidden*2

h的维度为 2 x batch x hidden

此时out保存的是每一个句子中每一个token在前向和后向的两个h,并做了串接

h保存的是前向网络和后向网络两个网络的最后一步的h,且都以行进的方向作为下标

也就是说 对于一个句子,h[0] h[1]两个网络的输出中 h[1] 逆向后与h[0]串接才能组成out的最后一步

多层多向:( num_layer=n, bidirection=True )

out的维度为 batch x seq_len x hidden*2

h的 维度为 n*2 x batch x hidden

总结:out保存的是最上层的双向单元串接后(如果是双向的话)的每一个step(每一个token)的h

h保存的是每一个网络层、反向的最后一步的h,双向lstm中的h反向层的h是前进顺序的,相对来说是逆序的。

具体的描述可以参考这篇:学会区分 RNN 的 output 和 state 。不过里面把(h,c)写成了(c,h),需要自己分辨一下。