2017-10-12 5 views
0

scikit-learn의 cross_val_score() 기능으로 Keras 신경 네트워크의 교차 검증을하고 싶습니다.scikit-learn에서 cross_val_score()의 각 폴드 후 함수를 실행하려면 어떻게해야합니까?

각 폴드 후 결과가 기억 될뿐만 아니라 전체 Keras 모델도 문제가됩니다. 그래서 각 접힌 후 K.clear_session()을 사용하여이 모델을 정리하고 싶습니다. 그러나 이것은 상황에 대한 세부 사항 일뿐입니다.

내 주요 질문 : scikit-learn에서 cross_val_score()를 사용하여 각 폴드 후 사용자 정의 기능을 어떻게 실행할 수 있습니까? 즉 : 각 폴드 후에 실행되어야하는 콜백을 실행할 수 있습니까? 아니면 다른 해결 방법이 있습니까?

답변

0

사용자 정의 콜백을 작성하고이 콜백의 on_train_end (self, logs = {}) 메소드를 다시 작성할 수 있습니다. 이 새로운 방법은 각 교육 단계가 끝나면 작업을 수행합니다. 뭐 그런 :

class CustomCall(Callback): 

    def __init__(self): 
     super(CustomCall, self).__init__() 

    def on_epoch_begin(self, epoch, logs={}): 
     return 

    def on_epoch_end(self, epoch, logs={}): 
     return 

    def on_batch_begin(self, batch, logs={}): 
     return 

    def on_train_end(self, logs={}): 
     # Stuff here 
     print('\n Delete previous trained model : ') 
     K.clear_session() 
     return 
+0

불행하게도, 문제는, K.clear_session()가되지 cross_val_score의 내부 훈련 후, 모델의 평가 후에 호출되어야합니다(). 그래서 Keras 훈련이 끝난 후에가 아니라, 크로스 볼 폴드가 끝날 때 K.clear_session()을 호출해야합니다. –