关于python:从PyTorch中的BiLSTM(BiGRU)获取最后一个状态

Taking the last state from BiLSTM (BiGRU) in PyTorch

在阅读了几篇文章之后,我仍然对从BiLSTM获取最后的隐藏状态的实现的正确性感到困惑。

  • 了解PyTorch中的双向RNN(TowardsDataScience)
  • seq2seq模型的PackedSequence(PyTorch论坛)
  • hiddena€和有什么不一样?和一个输出在PyTorch LSTM中? (堆栈溢出)
  • 在一批序列中选择张量(Pytorch形式)

  • 最后一个来源(4)的方法对我来说似乎是最干净的方法,但是我仍然不确定我是否正确理解了线程。我是否在使用LSTM和反向LSTM中正确的最终隐藏状态?这是我的实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # pos contains indices of words in embedding matrix
    # seqlengths contains info about sequence lengths
    # so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and
    # seqlengths contains [3,2], we have batch with samples
    # of variable length [4,6,9] and [3,1]

    all_in_embs = self.in_embeddings(pos)
    in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
    output,lasthidden = self.rnn(in_emb_seqs)
    if not self.data_processor.use_gru:
        lasthidden = lasthidden[0]
    # u_emb_batch has shape batch_size x embedding_dimension
    # sum last state from forward and backward  direction
    u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]

    正确吗?


    在一般情况下,如果要创建自己的BiLSTM网络,则需要创建两个常规LSTM,并用常规输入序列馈送一个LSTM,而用反向输入序列馈送另一个。完成两个序列的输入后,您只需从两个网络中获取最后一个状态,然后以某种方式将它们绑定在一起(求和或并置)。

    据我了解,您在本示例中使用内置BiLSTM(在nn.LSTM构造函数中设置bidirectional=True)。然后,在喂完批次后,您将获得并置的输出,因为PyTorch会为您处理所有麻烦。

    如果是这种情况,并且您想对隐藏状态求和,则必须

    1
    u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

    假设您只有一层。如果您有更多的图层,您的变体似乎会更好。

    这是因为结果是结构化的(请参见文档):

    h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len

    顺便说一句

    1
    u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]

    应提供相同的结果。


    这里是对使用解压缩序列的人员的详细说明:

    output的形状为(seq_len, batch, num_directions * hidden_size)(请参阅文档)。这意味着GRU的前向和后向传递的输出沿第3维连接。

    假设您的示例中为batch=2hidden_size=256,则可以通过执行以下操作轻松地分离正向和反向传递的输出:

    1
    2
    3
    output = output.view(-1, 2, 2, 256)   # (seq_len, batch_size, num_directions, hidden_size)
    output_forward = output[:, :, 0, :]   # (seq_len, batch_size, hidden_size)
    output_backward = output[:, :, 1, :]  # (seq_len, batch_size, hidden_size)

    (注意:-1告诉pytorch从其他维度推断出该尺寸。请参见此问题。)

    等效地,您可以在形状为(seq_len, batch, num_directions * hidden_size)的原始output上使用torch.chunk函数:

    1
    2
    # Split in 2 tensors along dimension 2 (num_directions)
    output_forward, output_backward = torch.chunk(output, 2, 2)

    现在,您可以使用seqlengths(重塑形状后)torch.gather前向通过的最后一个隐藏状态,以及通过选择位置0

    的元素来向后通过的最后一个隐藏状态。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    # First we unsqueeze seqlengths two times so it has the same number of
    # of dimensions as output_forward
    # (batch_size) -> (1, batch_size, 1)
    lengths = seqlengths.unsqueeze(0).unsqueeze(2)

    # Then we expand it accordingly
    # (1, batch_size, 1) -> (1, batch_size, hidden_size)
    lengths = lengths.expand((1, -1, output_forward.size(2)))

    last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0)
    last_backward = output_backward[0, :, :]

    请注意,由于基于0的索引,我从lengths中减去了1

    在这一点上,last_forwardlast_backward的形状均为(batch_size, hidden_dim)


    我测试了biLSTM输出和h_n:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # shape of x is size(batch_size, time_steps, input_size)
    # shape of output (batch_size, time_steps, hidden_size * num_directions)
    # shape of h_n is size(num_directions, batch_size, hidden_size)
    output, (h_n, _c_n) = biLSTM(x)

    print('step 0 of output from reverse == h_n from reverse?',
        output[:, 0, hidden_size:] == h_n[1])
    print('step -1 of output from reverse == h_n from reverse?',
        output[:, -1, hidden_size:] == h_n[1])

    输出

    1
    2
    step 0 of output from reverse == h_n from reverse? True
    step -1 of output from reverse == h_n from reverse? False

    这证实了反向的h_n是第一步的隐藏状态。

    因此,如果您确实需要从正向和反向两个方向获取最后一个时间步的隐藏状态,则应使用:

    1
    sum_lasthidden = output[:, -1, :hidden_size] + output[:, -1, hidden_size:]

    不是

    1
    h_n[0,:,:] + h_n[1,:,:]

    因为h_n[1,:,:]是从相反方向开始的第一时间步的隐藏状态。

    所以@igrinis的答案是

    1
    u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

    不正确。

    但是从理论上讲,反向的最后一个时间步隐藏状态仅包含序列最后一个时间步的信息。