Taking the last state from BiLSTM (BiGRU) in PyTorch
After reading several articles, I am still confused about whether my implementation of getting the last hidden states from a BiLSTM is correct.
The approach from the last source (4) seems the cleanest to me, but I am still not sure I understood the thread correctly. Am I using the right final hidden states from the LSTM and the reversed LSTM? This is my implementation:

```python
import torch
from torch.nn.utils.rnn import pack_sequence

# pos contains indices of words in embedding matrix
# seqlengths contains info about sequence lengths
# so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and
# seqlengths contains [3,2], we have batch with samples
# of variable length [4,6,9] and [3,1]

all_in_embs = self.in_embeddings(pos)
in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
output, lasthidden = self.rnn(in_emb_seqs)
if not self.data_processor.use_gru:
    # an LSTM returns (h_n, c_n); keep only h_n
    lasthidden = lasthidden[0]
# u_emb_batch has shape batch_size x embedding_dimension
# sum last state from forward and backward direction
u_emb_batch = lasthidden[-1, :, :] + lasthidden[-2, :, :]
```

Is it correct?
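In the general case, if you want to create your own BiLSTM network, you need to create two regular LSTMs and feed one with the regular input sequence and the other with the reversed input sequence. After you finish feeding both sequences, you take the last states from both networks and tie them together in some way (sum or concatenate).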
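The manual construction described above can be sketched as follows. This is only an illustration under stated assumptions: the sizes and the names fwd_lstm and bwd_lstm are hypothetical, and all sequences are assumed to have the same length, so a plain flip along the time axis is enough to reverse them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, input_size, hidden_size = 2, 5, 8, 16

# Two ordinary unidirectional LSTMs, one per reading direction.
fwd_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
bwd_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)

# The forward net reads x[:, 0], x[:, 1], ..., x[:, T-1].
_, (h_fwd, _) = fwd_lstm(x)
# The backward net reads the same tokens in reverse order.
_, (h_bwd, _) = bwd_lstm(torch.flip(x, dims=[1]))

# Each h_* has shape (1, batch, hidden); drop the leading layer axis.
last_fwd, last_bwd = h_fwd[0], h_bwd[0]

# Two common ways to tie the final states together.
summed = last_fwd + last_bwd                           # (batch, hidden)
concatenated = torch.cat([last_fwd, last_bwd], dim=1)  # (batch, 2 * hidden)
```

With variable-length sequences, each sequence would have to be reversed over its own valid length before being fed to the backward network, which this sketch does not handle.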
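As I understand it, you are using the built-in BiLSTM in this example (setting bidirectional=True in the nn.LSTM constructor).
If that is the case and you want to sum the hidden states, you have to use

```python
u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])
```

assuming you have only one layer. If you have more layers, your variant seems better.
This is because the result is structured as follows (see the documentation):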
h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len
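To make that layout concrete, here is a small sketch for a stacked bidirectional LSTM. The sizes are hypothetical. It relies on h_n being ordered layer by layer, with the forward direction before the backward direction inside each layer, and checks the top layer's states against output, which only exposes the top layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
num_layers, batch_size, seq_len, input_size, hidden_size = 3, 2, 5, 8, 16

rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = rnn(x)  # output: (seq_len, batch, 2 * hidden)

# h_n: (num_layers * num_directions, batch, hidden)
#   -> (num_layers, num_directions, batch, hidden)
h_n_split = h_n.reshape(num_layers, 2, batch_size, hidden_size)
top_forward = h_n_split[-1, 0]   # same entry as h_n[-2]
top_backward = h_n_split[-1, 1]  # same entry as h_n[-1]

# For full-length (unpadded) input, the forward state is the forward half of
# output at the last step; the backward state is the backward half at step 0.
fwd_matches = torch.allclose(top_forward, output[-1, :, :hidden_size])
bwd_matches = torch.allclose(top_backward, output[0, :, hidden_size:])
print(fwd_matches, bwd_matches)
```

This is why indexing with -2 and -1 picks the last layer regardless of depth, while 0 and 1 pick the first layer.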
By the way,

```python
u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]
```

should give the same result.
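Here is a detailed explanation for those working with unpacked sequences:
Assuming batch_size=2 and hidden_size=256 in your example, you can separate the outputs of the forward and backward passes with:

```python
output = output.view(-1, 2, 2, 256)   # (seq_len, batch_size, num_directions, hidden_size)
output_forward = output[:, :, 0, :]   # (seq_len, batch_size, hidden_size)
output_backward = output[:, :, 1, :]  # (seq_len, batch_size, hidden_size)
```

(Note: -1 tells PyTorch to infer that dimension from the others.)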
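Equivalently, you can use torch.chunk on the original output of shape (seq_len, batch_size, num_directions * hidden_size):

```python
# Split in 2 tensors along dimension 2 (num_directions)
output_forward, output_backward = torch.chunk(output, 2, 2)
```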
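As a quick sanity check that the two ways of splitting agree, here is a minimal sketch. It uses a random tensor as a stand-in for the RNN output, and seq_len=4 is an arbitrary choice.

```python
import torch

torch.manual_seed(0)

# Random stand-in for an RNN output of shape
# (seq_len, batch_size, num_directions * hidden_size).
seq_len, batch_size, hidden_size = 4, 2, 256
output = torch.randn(seq_len, batch_size, 2 * hidden_size)

# Variant 1: expose num_directions as its own axis, then index it.
viewed = output.view(-1, batch_size, 2, hidden_size)
fwd_view, bwd_view = viewed[:, :, 0, :], viewed[:, :, 1, :]

# Variant 2: cut the last axis into two equal halves.
fwd_chunk, bwd_chunk = torch.chunk(output, 2, 2)

print(torch.equal(fwd_view, fwd_chunk), torch.equal(bwd_view, bwd_chunk))
```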
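Now you can use torch.gather with seqlengths (after reshaping it) to pick the last hidden state of the forward pass, and select the element at position 0 to get the last hidden state of the backward pass.

```python
# First we unsqueeze seqlengths two times so it has the same number
# of dimensions as output_forward
# (batch_size) -> (1, batch_size, 1)
lengths = seqlengths.unsqueeze(0).unsqueeze(2)

# Then we expand it accordingly
# (1, batch_size, 1) -> (1, batch_size, hidden_size)
lengths = lengths.expand((1, -1, output_forward.size(2)))

last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0)
last_backward = output_backward[0, :, :]
```

Note that I subtracted 1 from lengths because of the 0-based indexing.
At this point, both last_forward and last_backward have shape (batch_size, hidden_size).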
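The steps above can be put together into a self-contained sketch. The assumptions here are mine, not the original poster's: a single-layer bidirectional GRU, hypothetical sizes, and a padded batch that is packed before the RNN and unpacked afterwards, so that padding never enters the recurrence. Under those assumptions the gathered states should coincide with h_n.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical sizes; lengths are sorted in decreasing order for packing.
input_size, hidden_size = 8, 16
seqlengths = torch.tensor([5, 3])
batch_size, max_len = 2, 5

gru = nn.GRU(input_size, hidden_size, bidirectional=True)
padded = torch.randn(max_len, batch_size, input_size)

packed_out, h_n = gru(pack_padded_sequence(padded, seqlengths))
output, _ = pad_packed_sequence(packed_out)  # (max_len, batch, 2 * hidden)

output_forward, output_backward = torch.chunk(output, 2, 2)

# (batch_size) -> (1, batch_size, 1) -> (1, batch_size, hidden_size)
lengths = seqlengths.unsqueeze(0).unsqueeze(2)
lengths = lengths.expand((1, -1, output_forward.size(2)))

# Forward: the step at index length - 1 of each sequence. Backward: step 0.
last_forward = torch.gather(output_forward, 0, lengths - 1).squeeze(0)
last_backward = output_backward[0, :, :]

print(torch.allclose(last_forward, h_n[0]), torch.allclose(last_backward, h_n[1]))
```

If the padded batch were fed to the RNN without packing, the backward pass would start from the padding, and its state at position 0 would no longer match what a packed run produces.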
I tested the biLSTM output against h_n:

```python
# shape of x is size(batch_size, time_steps, input_size)
# shape of output (batch_size, time_steps, hidden_size * num_directions)
# shape of h_n is size(num_directions, batch_size, hidden_size)
output, (h_n, _c_n) = biLSTM(x)

print('step 0 of output from reverse == h_n from reverse?',
      output[:, 0, hidden_size:] == h_n[1])
print('step -1 of output from reverse == h_n from reverse?',
      output[:, -1, hidden_size:] == h_n[1])
```

Output:

```
step 0 of output from reverse == h_n from reverse? True
step -1 of output from reverse == h_n from reverse? False
```

This confirms that h_n for the reverse direction is the hidden state at the first time step.
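For anyone who wants to reproduce the check, here is a self-contained variant. It is a sketch with hypothetical sizes, it assumes batch_first=True to match the shapes in the comments above, and it uses torch.allclose to reduce each comparison to a single boolean.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, time_steps, input_size, hidden_size = 2, 5, 8, 16

biLSTM = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, time_steps, input_size)
output, (h_n, _c_n) = biLSTM(x)

# Forward direction: h_n[0] lines up with the last time step of output.
fwd_last_is_h_n = torch.allclose(output[:, -1, :hidden_size], h_n[0])
# Reverse direction: h_n[1] lines up with the first time step of output...
bwd_first_is_h_n = torch.allclose(output[:, 0, hidden_size:], h_n[1])
# ...and not with the last time step.
bwd_last_is_h_n = torch.allclose(output[:, -1, hidden_size:], h_n[1])

print(fwd_last_is_h_n, bwd_first_is_h_n, bwd_last_is_h_n)
```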
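So if you really need the hidden state of the last time step from both the forward and reverse directions, you should use:

```python
sum_lasthidden = output[:, -1, :hidden_size] + output[:, -1, hidden_size:]
```

not

```python
h_n[0, :, :] + h_n[1, :, :]
```

because h_n[1, :, :] is the reverse direction's hidden state at the first time step.
So @igrinis's answer,

```python
u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])
```

is not correct.
In theory, however, the reverse direction's hidden state at the last time step only contains information from the last time step of the sequence.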
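That last point can be checked directly. In the sketch below (hypothetical sizes, batch_first=True, a default zero initial state), the reverse half of output at the last time step is compared with what the same network produces when it is given only the final token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch_size, time_steps, input_size, hidden_size = 2, 5, 8, 16

biLSTM = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, time_steps, input_size)

with torch.no_grad():
    full_out, _ = biLSTM(x)
    # Feed only the final token of every sequence.
    last_only_out, _ = biLSTM(x[:, -1:, :].contiguous())

# The reverse direction starts at the last token, so at that step it has
# seen nothing else; both runs should give the same reverse state there.
bwd_at_last_step = full_out[:, -1, hidden_size:]
bwd_last_token_only = last_only_out[:, 0, hidden_size:]

print(torch.allclose(bwd_at_last_step, bwd_last_token_only, atol=1e-6))
```

So summing output at index -1 across both directions combines a forward state that has read the whole sequence with a reverse state that has read a single token, whereas h_n pairs the two states that have each read the whole sequence.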