2017-12-12 27 views
0

Simple multi-task network can be done here.하지만이 같은 것을 원합니다. enter image description here. 지금은 아래와 같이 모델 구성 :torch7에서 다중 작업 학습을 수행하는 방법은 무엇입니까?

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 
prl1 = nn.ConcatTable() 
prl1:add(nn.Linear(5,1)) 
prl2 = nn.ConcatTable() 
prl2:add(nn.Linear(5,1)) 
prl2:add(nn.Linear(5,1)) 
prl1:add(prl2) 
model:add(prl1) 

을 그리고 내 출력은 다음과 같습니다

input = torch.rand(5,3) 
output = model:forward(input) 
output 
{ 
    1 : DoubleTensor - size: 5x1 
    2 : 
    { 
     1 : DoubleTensor - size: 5x1 
     2 : DoubleTensor - size: 5x1 
    } 
} 

어떻게 기준을 구축해야합니까?

가 이 출력을 상기 네트워크 대신 nn.ConcatTable의

1. 한 nn.Concat 될 수있는 N × M 개의 단순한 텐서 예컨대 : I는 두 단계를 알아낼

답변

0

nx.ConcatTable 대신 nn.Concat을 사용하는 동안 5x3 텐서가 위의 네트워크에 들어갑니다.

2.NxM 텐서를 얻은 후, 각 결과 Tensor를 포함하는 출력을 간단한 테이블로 만들기 위해 nn.ConcatTable, nn.Concat 및 nn.Select의 조합을 사용합니다.

th> output 
{ 
    1 : DoubleTensor - size: 5x3 
    2 : DoubleTensor - size: 5x2 
} 
:

model = nn.Sequential() 
model:add(nn.Linear(3,5)) 

prl = nn.ConcatTable() 

spl1 = nn.Concat(2) 

seq1 = nn.Sequential() 
seq1:add(nn.Select(2, 1)) 
seq1:add(nn.Reshape(1)) 

seq2 = nn.Sequential() 
seq2:add(nn.Select(2, 2)) 
seq2:add(nn.Reshape(1)) 

seq3 = nn.Sequential() 
seq3:add(nn.Select(2, 3)) 
seq3:add(nn.Reshape(1)) 

spl1:add(seq1) 
spl1:add(seq2) 
spl1:add(seq3) 
prl:add(spl1) 

spl2 = nn.Concat(2) 

seq4 = nn.Sequential() 
seq4:add(nn.Select(2, 4)) 
seq4:add(nn.Reshape(1)) 

seq5 = nn.Sequential() 
seq5:add(nn.Select(2, 5)) 
seq5:add(nn.Reshape(1)) 

spl2:add(seq4) 
spl2:add(seq5) 
prl:add(spl2) 

model:add(prl) 

input = torch.rand(5,3) 
output = model:forward(input) 

출력은 같을 것이다 : 여기서

2 단계의 간단한 예는