2

시퀀스 분류를위한주의 메커니즘을 사용하여 양방향 RNN을 작성하려고합니다. 도우미 기능을 이해하는 데 몇 가지 문제가 있습니다. 훈련에 사용 된 것은 디코더 입력이 필요하다는 것을 알았지 만 전체 시퀀스의 단일 레이블을 원하기 때문에 정확히 입력해야 할 항목을 알지 못합니다. 이것은 지금까지 구축 한 구조이다시퀀스 분류를위한주의 메커니즘 (seq2seq tensorflow r1.1)

# Encoder LSTM cells 
lstm_fw_cell = rnn.BasicLSTMCell(n_hidden) 
lstm_bw_cell = rnn.BasicLSTMCell(n_hidden) 

# Bidirectional RNN 
outputs, states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, 
        lstm_bw_cell, inputs=x, 
        sequence_length=seq_len, dtype=tf.float32) 

# Concatenate forward and backward outputs 
encoder_outputs = tf.concat(outputs,2) 

# Decoder LSTM cell 
decoder_cell = rnn.BasicLSTMCell(n_hidden) 

# Attention mechanism 
attention_mechanism = tf.contrib.seq2seq.LuongAttention(n_hidden, encoder_outputs) 
attn_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, 
      attention_mechanism, attention_size=n_hidden) 
      name="attention_init") 

# Initial attention 
attn_zero = attn_cell.zero_state(batch_size=tf.shape(x)[0], dtype=tf.float32) 
init_state = attn_zero.clone(cell_state=states[0]) 

# Helper function 
helper = tf.contrib.seq2seq.TrainingHelper(inputs = ???) 

# Decoding 
my_decoder = tf.contrib.seq2seq.BasicDecoder(cell=attn_cell, 
      helper=helper, 
      initial_state=init_state) 

decoder_outputs, decoder_states = tf.contrib.seq2seq.dynamic_decode(my_decoder) 

내 입력 [BATCH_SIZE, sequence_length, n_features] 서열이다 내 출력은 N 가능한 클래스 [BATCH_SIZE, n_classes] 단일 벡터이다.

여기에 무엇이 누락되었거나 시퀀스 분류에 seq2seq를 사용할 수 있는지 알고 계십니까?

답변

1

Seq2Seq 모델은 정의상 이와 같은 작업에는 적합하지 않습니다. 이름에서 알 수 있듯이 일련의 입력 (문장의 단어)을 일련의 레이블 (단어의 부분)으로 변환합니다. 귀하의 경우, 일련의 샘플이 아닌 샘플 당 하나의 라벨을 찾고 있습니다.

인코더의 출력 또는 상태 (RNN) 만 필요하기 때문에 다행히도 이미이 기능을 사용할 수 있습니다.

이것을 사용하여 분류자를 만드는 가장 간단한 방법은 RNN의 최종 상태를 사용하는 것입니다. 모양 [n_hidden, n_classes]을 사용하여 위에 연결된 레이어를 추가하십시오. 이것에서 최종 카테고리를 예측하는 softmax 레이어와 손실을 트레이닝 할 수 있습니다.

원칙적으로 이것은주의 메커니즘을 포함하지 않습니다. 그러나 하나를 포함시키려는 경우 RNN의 각 출력을 학습 벡터로 계량 한 다음 합계를 계산하여 수행 할 수 있습니다. 그러나 이것이 결과를 향상시키는 것은 아닙니다. 더 많은 참조를 위해, 내가 실수하지 않는다면 https://arxiv.org/pdf/1606.02601.pdf이주의 메커니즘을 구현합니다.