2017-12-12

Saving the best val_loss with a Keras callback

I put together a denoising autoencoder for the MNIST dataset, based on what is discussed here: https://blog.keras.io/building-autoencoders-in-keras.html

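I'm trying to see visually how the reconstruction of the input images changes over training. I've noticed that the DAE's loss sometimes spikes, in both training and validation (e.g., from about 0.12 to about 3.0). To avoid these "mistakes" during training, I'm trying to use a Keras callback to save the best weights (by val_loss) and load them after each "segment" of training (10 epochs in my case).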
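The save-the-best-and-reload idea can be sketched without Keras. This is a minimal simulation under stated assumptions: `BestTracker` is a hypothetical stand-in for `ModelCheckpoint(save_best_only=True)`, each "save" records a label instead of writing a weights file, and the loss values are made up to mimic the spike described above. It also illustrates a design choice: a tracker created once, before the segment loop, keeps its best loss across segments, so reloading after a spiked segment rolls back to earlier weights, while a tracker re-created inside the loop starts from +inf each time and ends up "saving" spiked weights. In the Keras 2.x versions I know of, `ModelCheckpoint` sets its starting best value in its constructor, so creating it inside the `for` loop (as the code below does) likely behaves like the second case; treat that as an assumption to check against your Keras version.

```python
class BestTracker(object):
    """Hypothetical stand-in for ModelCheckpoint(save_best_only=True)."""

    def __init__(self):
        self.best_loss = float("inf")  # assumed starting point for a 'min' monitor
        self.saved_weights = None      # plays the role of the weights file on disk

    def on_epoch_end(self, val_loss, weights):
        # "Save" only when the monitored loss beats the best seen so far.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.saved_weights = weights


def train_in_segments(segments, reuse_tracker):
    """Simulate one fit() call per segment, reloading the best weights after each.

    `segments` holds made-up validation losses, one list per segment.
    Returns the label of the weights reloaded after each segment.
    """
    tracker = BestTracker()
    reloaded = []
    for s, seg_losses in enumerate(segments):
        if not reuse_tracker:
            tracker = BestTracker()  # re-created: the best loss resets to +inf
        for e, loss in enumerate(seg_losses):
            weights = "s%de%d" % (s, e)  # label for segment s, epoch e
            tracker.on_epoch_end(loss, weights)
        # Stand-in for autoencoder.load_weights(filepath_for_w) after the segment.
        reloaded.append(tracker.saved_weights)
    return reloaded


# The middle segment spikes (roughly 0.12 -> 3.0), mimicking the question.
segments = [[0.20, 0.15, 0.12], [3.0, 2.5, 2.8], [0.14, 0.11, 0.10]]
print(train_in_segments(segments, reuse_tracker=True))   # ['s0e2', 's0e2', 's2e2']
print(train_in_segments(segments, reuse_tracker=False))  # ['s0e2', 's1e1', 's2e2']
```

With the tracker reused, the spiked middle segment never overwrites the saved weights, so the weights reloaded after it are still the ones from the first segment.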

  File "noise_e_mini.py", line 71, in <module>
    callbacks=([checkpointer]))
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1650, in fit
    validation_steps=validation_steps)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1145, in _fit_loop
    callbacks.set_model(callback_model)
  File "/usr/local/lib/python2.7/dist-packages/keras/callbacks.py", line 48, in set_model
    callback.set_model(model)
AttributeError: 'tuple' object has no attribute 'set_model'
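The last frame of the traceback shows Keras looping over the callbacks and calling `callback.set_model(model)` on each one. A minimal plain-Python reproduction, assuming that loop is essentially all that happens at this point (`FakeCheckpoint` and `set_model_on_all` are hypothetical stand-ins, not Keras code), shows how a trailing comma after the constructor call produces this exact error:

```python
class FakeCheckpoint(object):
    """Hypothetical stand-in for keras.callbacks.ModelCheckpoint."""

    def set_model(self, model):
        self.model = model


def set_model_on_all(callbacks, model):
    # Mirrors the loop in the last traceback frame: callback.set_model(model)
    for callback in callbacks:
        callback.set_model(model)


checkpointer = FakeCheckpoint(),  # trailing comma: this is a 1-tuple, not a callback
print(type(checkpointer).__name__)  # tuple

try:
    set_model_on_all([checkpointer], model="autoencoder")
except AttributeError as err:
    print(err)  # 'tuple' object has no attribute 'set_model'
```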

Here is my code. What am I doing wrong?

from keras.layers import Input, Dense 
from keras.models import Model 
from keras import regularizers 
from keras.callbacks import ModelCheckpoint 
input_img = Input(shape=(784,)) 

filepath_for_w='denoise_by_AE_weights_1.h5' 


def autoencoder_block(input_tensor, l1_score_encode, l1_score_decode): 

    # encoder: 784 -> 256 -> 128 -> 64 -> 32, with L1 activity regularization
    encoded = Dense(256, activation='relu', activity_regularizer=regularizers.l1(l1_score_encode))(input_tensor) 
    encoded = Dense(128, activation='relu', activity_regularizer=regularizers.l1(l1_score_encode))(encoded) 
    encoded = Dense(64, activation='relu', activity_regularizer=regularizers.l1(l1_score_encode))(encoded) 
    encoded = Dense(32, activation='relu', activity_regularizer=regularizers.l1(l1_score_encode))(encoded) 

    encoder = Model(inputs=input_tensor, outputs=encoded) 

    # decoder: 32 -> 64 -> 128 -> 256 -> 784 (sigmoid output for pixel values in [0, 1])
    connection_layer = Input(shape=(32,)) 
    decoded = Dense(64, activation='relu', activity_regularizer=regularizers.l1(l1_score_decode))(connection_layer) 
    decoded = Dense(128, activation='relu', activity_regularizer=regularizers.l1(l1_score_decode))(decoded) 
    decoded = Dense(256, activation='relu', activity_regularizer=regularizers.l1(l1_score_decode))(decoded) 
    decoded = Dense(784, activation='sigmoid', activity_regularizer=regularizers.l1(l1_score_decode))(decoded) 

    decoder = Model(inputs=connection_layer, outputs=decoded) 

    # full autoencoder: chain the encoder and decoder models
    crunched = encoder(input_tensor) 
    final = decoder(crunched) 

    autoencoder = Model(inputs=input_tensor, outputs=final) 
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy') 
    return autoencoder 



from keras.datasets import mnist 
import numpy as np 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 


x_train = x_train.astype('float32')/255. 
x_test = x_test.astype('float32')/255. 
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))) 
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))) 
print(x_train.shape) 
print(x_test.shape) 

noise_factor = 0.5 

x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) 
x_test_noisy = np.clip(x_test_noisy, 0., 1.) 



autoencoder=autoencoder_block(input_img,0,0) 

for i in range (10): 

    x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape) 
    x_train_noisy = np.clip(x_train_noisy, 0., 1.) 
    checkpointer=ModelCheckpoint(filepath_for_w, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1), 

    autoencoder.fit(x_train_noisy, x_train, 
       epochs=10, 
       batch_size=256,    
       shuffle=True, 
       validation_data=(x_test_noisy, x_test), 
       callbacks=([checkpointer])) 
    autoencoder.load_weights(filepath_for_w) # load weights from the best in the run 

    decoded_imgs = autoencoder.predict(x_test_noisy) # save results for this stage for presentation 
    np.save('decoded'+str(i)+'.npy',decoded_imgs) #### 

np.save('tested.npy',x_test_noisy) 
np.save ('true_catagories.npy',y_test) 
np.save('original.npy',x_test) 


autoencoder.save('denoise_by_AE_model_1.h5') 

However, I get the error message shown above.

callbacks=([checkpointer])) 

You need to remove the parentheses on this line; callbacks needs a list, not a tuple.

Many thanks :)

Answer


Your problem is probably the callbacks line; try:

callbacks=[checkpointer] 

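I also noticed that your checkpointer line ends with a comma. That trailing comma turns checkpointer into a one-element tuple, which is the object Keras then tries to call set_model on, so you should remove it: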

checkpointer=ModelCheckpoint(filepath_for_w, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1),
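In Python it is the comma, not the parentheses, that creates a tuple, so of the two suggested changes only removing the trailing comma actually matters. A quick plain-Python check (`make_checkpoint` is a hypothetical placeholder for the `ModelCheckpoint(...)` call):

```python
def make_checkpoint():
    # Hypothetical placeholder for ModelCheckpoint(filepath_for_w, ...)
    return object()


with_comma = make_checkpoint(),    # the comma makes a 1-tuple
without_comma = make_checkpoint()  # a single callback object

print(type(([without_comma])).__name__)  # list: the extra parentheses change nothing
print(type(with_comma).__name__)         # tuple
print(type([with_comma][0]).__name__)    # tuple: the list holds a tuple, not a callback
```

So callbacks=([checkpointer]) and callbacks=[checkpointer] pass the same list; dropping the comma at the end of the ModelCheckpoint(...) line is what fixes the AttributeError.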