2017-12-22 31 views
2

저는 Keras 신경망을 사용하고 있으며, 지금까지 본 모든 튜토리얼 에서처럼 하드 코딩되지 않은 입력 차원을 자동으로 설정하고 싶습니다. 내가 어떻게 이걸 이룰 수 있니?KerasRegressor에서 입력 인수를 지정하십시오.

내 코드 :

from keras.models import Sequential 
from keras.layers import Dense 
from keras.wrappers.scikit_learn import KerasRegressor 
seed = 1 

X = df_input 
Y = df_res 

def baseline_model(x): 
    # create model 
    model = Sequential()  
    model.add(Dense(20, input_dim=x, kernel_initializer='normal', activation=relu)) 
    model.add(Dense(1, kernel_initializer='normal')) 
    # Compile model 
    model.compile(loss='mean_absolute_error', optimizer='adam') 
    return model 

inpt = len(X.columns) 
estimator = KerasRegressor(build_fn = baseline_model(inpt ) , epochs=2, batch_size=1000, verbose=2) 
estimator.fit(X,Y) 

그리고 오류가 나는 얻을 :

Traceback (most recent call last):

File ipython-input-2-49d765e85d15, line 20, in estimator.fit(X,Y)

TypeError: call() missing 1 required positional argument: 'inputs'

+0

견적을 전송할 수있는 방법이 아니기 때문에이 오류가 발생합니다. 구체적으로 scikit-learn API가있는 객체입니다. 즉,'estimator.fit (X, Y)'로 추정기를 훈련시키고'estimator.predict (X, Y)'로 예측을 할 수 있습니다. – rvinas

+0

고맙습니다. 작동중인 솔루션을 염두에 두시겠습니까? –

+0

'estimator (X, Y)'를'estimator.fit (X, Y)'로 바꿉니다. – rvinas

답변

0

나는 당신이 다음과 baseline_model로 포장합니다 :

def baseline_model(x): 
    def baseline_model(): 
     # create model 
     model = Sequential() 
     model.add(Dense(20, input_dim=x, kernel_initializer='normal', activation='relu')) 
     model.add(Dense(1, kernel_initializer='normal')) 
     # Compile model 
     model.compile(loss='mean_absolute_error', optimizer='adam') 
     return model 
    return baseline_model 

을 그리고 정의하고 KerasRegressor로 적합 :

estimator = KerasRegressor(build_fn=baseline_model(inpt), epochs=2, batch_size=1000, verbose=2) 
estimator.fit(X, Y) 

이렇게하면 baseline_model에 입력 크기를 하드 코딩 할 필요가 없습니다.

+0

감사합니다. 나는이 솔루션에 대해 생각해 본 적이 없습니다. 나는 기능 계층을 추가하는 것이 왜 효과가 있는지 확신 할 수 없다. –

+0

도와 드릴 수있어서 기쁩니다. 여기에있는 것은'KerasRegressor'는 모델 자체가 아닌 모델을 만드는 호출 가능 함수를 기대한다는 것입니다. 이런 방식으로 함수를 래핑하면 지정된 'input_dim'을 사용하여 (호출하지 않고) 빌드 함수를 리턴 할 수 있습니다. – rvinas