2017-12-01 46 views
1

keras에서 나는 깊이 감시 된 길쌈 네트워크를 설계했습니다. 정확히 9 개의 출력 레이어가 있습니다.Keras : model.fit_generator가있는 다중 출력 모델 용 생성기 사용

yield(X, {'conv10': y, 'seg_1': y, 'seg_2': y, 'seg_3': y, 'seg_4': y, 'seg_5': y, 'seg_6': y, 'seg_7': y, 'seg_8': y}) 

가 나는의 권고 다음이 sintax을 준 :

  1. Keras: How to use fit_generator with multiple outputs of different type
  2. https://keras.io/getting-started/functional-api-guide/ 나는 산출 간단한 발전기를 개발했다.

그러나 나는이 오류가 계속 :

Traceback (most recent call last): 
    File "modeltrain.py", line 180, in <module> 
    model.fit_generator(next_batch(X_train_r, y_train_r, batch_size), steps_per_epoch=(X_train_r.shape[0]/batch_size), validation_data=(X_val_r, y_val_r), epochs=100, callbacks=[csv_logger, model_check]) 
    File "/home/m/.local/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper 
    return func(*args, **kwargs) 
    File "/home/m/.local/lib/python3.6/site-packages/keras/engine/training.py", line 1978, in fit_generator 
    val_x, val_y, val_sample_weight) 
    File "/home/m/.local/lib/python3.6/site-packages/keras/engine/training.py", line 1382, in _standardize_user_data 
    exception_prefix='target') 
    File "/home/m/.local/lib/python3.6/site-packages/keras/engine/training.py", line 111, in _standardize_input_data 
    'Found: array with shape ' + str(data.shape)) 
ValueError: The model expects 9 target arrays, but only received one array. Found: array with shape (70, 512, 512, 1) 

내가 할 그 밖의 무엇을 모른다!

# Importing the pre processed data in the text file. 

X_train= np.loadtxt("X_train.txt") 
X_test= np.loadtxt("X_test.txt") 
X_val= np.loadtxt("X_val.txt") 
y_train= np.loadtxt("y_train.txt") 
y_test= np.loadtxt("y_test.txt") 
y_val= np.loadtxt("y_val.txt")enter 

# Resize the input matrix so that it satisfies (batch, x, y, z) 

new_size=512 
X_train_r=X_train.reshape(X_train.shape[0],new_size,new_size) 
X_train_r=np.expand_dims(X_train_r, axis=3) 
y_train_r=y_train.reshape(y_train.shape[0],new_size,new_size) 
y_train_r=np.expand_dims(y_train_r, axis=3) 
X_val_r=X_val.reshape(X_val.shape[0],new_size,new_size) 
X_val_r=np.expand_dims(X_val_r, axis=3) 
y_val_r=y_val.reshape(y_val.shape[0],new_size,new_size) 
y_val_r=np.expand_dims(y_val_r, axis=3) 
X_test_r=X_test.reshape(X_test.shape[0],new_size,new_size) 
X_test_r=np.expand_dims(X_test_r, axis=3) 
y_test_r=y_test.reshape(y_test.shape[0],new_size,new_size) 
y_test_r=np.expand_dims(y_test_r, axis=3) 

def next_batch(Xs, ys, size): 
    while true: 
     perm=np.random.permutation(Xs.shape[0]) 
     for i in np.arange(0, Xs.shape[0], size): 
      X=Xs[perm[i:i+size]] 
      y=ys[perm[i:i+size]] 
      yield(X, {'conv10': y, 'seg_1': y, 'seg_2': y, 'seg_3': y, 'seg_4': y, 'seg_5': y, 'seg_6': y, 'seg_7': y,'seg_8': y }) 

# Model Training 
model= get_unet() 
batch_size=1 

#Compile the model 
adam=optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 
model.compile(loss={'conv10': dice_coef_loss, 'seg_8': loss_seg, 'seg_7': loss_seg , 'seg_6': loss_seg, 'seg_5': loss_seg , 'seg_4': loss_seg , 'seg_3': loss_seg, 'seg_2': loss_seg, 'seg_1': loss_seg}, optimizer=adam, metrics=['accuracy']) 

    #Fit the model 
    model.fit_generator(next_batch(X_train_r, y_train_r, batch_size), steps_per_epoch=(X_train_r.shape[0]/batch_size), validation_data=(X_val_r, y_val_r), epochs=100) 
+1

seg_8은 컴파일되지만 생성기에서는 seg_8의 경우에는 Y가 없습니다. 그들은 일치해야하지 않습니까? – Ajjo

+0

예, 일부 테스트를 실행하고 코드에서 해당 부분을 지웠습니다. 코드를 업데이트했습니다. –

답변

0

귀하의 코드는 훈련, 유효성 검사 중하지 실패 : 여기

는 코드입니다. validation_data 매개 변수가 생성기에서 전달되어야 할 때 배열에서 전달되는 것처럼 보입니다. 다음은 동일한 생성기를 유효성 검사 및 교육에 사용하는 간단한 예제입니다.

a = Input(shape=(10,)) 
o1 = Dense(5, name='output1')(a) 
o2 = Dense(7, name='output2')(a) 
model = Model(inputs=a, outputs=[o1,o2]) 
model.compile(optimizer='sgd', loss='mse') 

def generator(): 
    batch_size = 8 
    x = np.zeros((batch_size, 10)) 
    y1 = np.zeros((batch_size, 5)) 
    y2 = np.zeros((batch_size, 7)) 
    while True: 
     yield x, {'output1': y1, 'output2': y2} 

model.fit_generator(generator(), 1, 1, validation_data=generator(), validation_steps=1) 
+0

잡기에 너무 감사드립니다! –