2016-10-20 4 views
0

동일한 가중치와 바이어스로 두 개의 네트워크를 만들려고합니다. 비슷한 학습 곡선이 기대됩니다. 반복 2에서 두 네트워크의 모든 얼룩은 동일하지만 (데이터 & Diff), params (가중치 & 바이어스)는 다릅니다!Pycafe- 동일한 가중치와 바이어스가있는 두 개의 네트워크 만들기

내가 여기서 잘못하고있는 것은 무엇입니까?

참고 : 네트워크에는 데이터 세트 및 드롭 아웃 레이어에 대한 임의의 셔플이 없습니다.

감사

solver1 = caffe.SGDSolver('lenet_solver.prototxt') 
solver2 = caffe.SGDSolver('lenet_solver.prototxt') 
solver1.step(1) 
solver2.step(1) 
CopySolver(solver1,solver2) 
for i in range(10): 
    solver1.step(1)  
    solver2.step(1) 
    print solver1.net.params['ip2'][1].diff 
    print solver2.net.params['ip2'][1].diff 

def CopySolver(SolverA,SolverB): 
    params = SolverA.net.params.keys() 
    paramsA = {pr: (SolverA.net.params[pr][0].data,SolverA.net.params[pr][1].data) for pr in params}  
    paramsB = {pr: (SolverB.net.params[pr][0].data,SolverB.net.params[pr][1].data) for pr in params}     
    for pr in params: 
     paramsB[pr][1][...] = paramsA [pr][1] #bias 
     paramsB[pr][0][...] = paramsA [pr][0] #weights 

답변

0

당신은 좋은 일을하고 있습니다. 모든 열차는 개개인입니다. 당신은 2 개의 동일한 그물을 가질 수 있고 동일한 데이터 세트에서 그들을 훈련시킬 수 있습니다. 그러나 각각의 그물은 무작위 적으로 시작될 것입니다. 그래서 왜 그물마다 다른 params를 얻었습니까?

+0

후 반복 한 후 초기 점은 동일합니다. 나는 학습 과정에서 두 네트워크의 매개 변수간에 차이가 없다고 기대한다. –

1

당신은 솔버의 추진력을 고려하지 않았습니다. 한 솔버 객체에서 다른 솔버 객체로 넷 매개 변수를 복사 한 후 solver1과 solver2간에 솔버의 모멘텀 정보 (예 : SGD)가 여전히 다릅니다. "lenet_solver.prototxt"에 "momentum : 0"을 설정하면 예상되는 동작을 얻게됩니다.

그렇지 않으면 매개 변수를 저장하고 새 솔버 개체를 두 개 만들고 매개 변수를로드 한 다음 교육을 다시 시작할 수 있습니다. 이렇게하면 두 가지가 초기 추진력없이 시작되는 것을 보장 할 수 있습니다. 여기 처럼이 볼 수있는 방법 예 : 내가 네트워크 2에 대한 네트워크 1 (모든 편견 및 모든 무게를) 복사이 코드에서

solver1 = caffe.SGDSolver('lenet_solver.prototxt') 
solver2 = caffe.SGDSolver('lenet_solver.prototxt') 
solver1.step(1) 
solver2.step(1) 

solver1.net.save("tmp.caffemodel") 

solver1 = caffe.SGDSolver('lenet_solver.prototxt') 
solver2 = caffe.SGDSolver('lenet_solver.prototxt') 
solver1.net.copy_from("tmp.caffemodel") 
solver2.net.copy_from("tmp.caffemodel") 

for i in range(10): 
    solver1.step(1) 
    solver2.step(1) 
    print solver1.net.params['ip2'][1].diff 
    print solver2.net.params['ip2'][1].diff 
+0

특별히 Erik B.에게 감사합니다. 내 문제가 해결되었습니다. –