1

전체 상태를 드러내지 않는 관측으로 인해, 반복적 인 신경망으로 보강을해야 네트워크에 과거의 일종의 기억이 있습니다. 간단히하기 위해 LSTM을 사용한다고 가정 해 보겠습니다.PyTorch에서 LSTM을 사용하여 보강 학습을하는 방법은 무엇입니까?

내장 된 PyTorch LSTM은 모양 Time x MiniBatch x Input D의 입력을 피드에 입력해야하며 텐서 형태의 Time x MiniBatch x Output D을 출력합니다.

그러나 보강 학습에서는 시간이 t+1 인 입력을 알고 있으므로 환경에서 작업을 수행하고 있으므로 t 시간의 출력을 알아야합니다.

내장 된 PyTorch LSTM을 사용하여 강화 학습 설정에서 BPTT를 수행 할 수 있습니까? 그리고 만약 그렇다면 어떻게 할 수 있습니까?

답변

1

입력 시퀀스를 루프로 LSTM에 공급할 수도 있습니다. 이 같은 것을 :

h, c = Variable(torch.zeros()), Variable(torch.zeros()) 
for i in range(T): 
    input = Variable(...) 
    _, (h, c) = lstm(input, (h,c)) 

예를 들어 작업을 평가하기 위해 (h, c) 및 입력을 사용할 수 있습니다. 계산 그래프를 깨지 않는 한 Variables가 모든 기록을 유지하면서 백 프로 퍼 게이트 할 수 있습니다.