2017-11-19 19 views
4

신경망을 배우고 있으며 파이썬에서 함수 cross_entropy을 작성하고 싶습니다. 이크로스 엔트로피 함수 (파이썬)

cross entropyN 샘플의 개수

, k 같이 정의된다 클래스 수, log 자연 대수, 샘플 i 달리 클래스 j0 인 경우 t_i,j은 1이며 , p_i,j은 샘플 i이 클래스 j에있을 것으로 예상되는 확률입니다. 대수와 관련된 숫자 문제를 피하려면 예측을 [10^{−12}, 1 − 10^{−12}] 범위로 잘라냅니다.

위의 설명에 따르면, 나는 예측을 clippint로 예측하여 [epsilon, 1 − epsilon] 범위로 작성한 다음 위 공식에 따라 cross_entropy를 계산합니다.

def cross_entropy(predictions, targets, epsilon=1e-12): 
    """ 
    Computes cross entropy between targets (encoded as one-hot vectors) 
    and predictions. 
    Input: predictions (N, k) ndarray 
      targets (N, k) ndarray   
    Returns: scalar 
    """ 
    predictions = np.clip(predictions, epsilon, 1. - epsilon) 
    ce = - np.mean(np.log(predictions) * targets) 
    return ce 

다음 코드는 cross_entropy 기능이 올바른지 확인하는 데 사용됩니다.

predictions = np.array([[0.25,0.25,0.25,0.25], 
         [0.01,0.01,0.01,0.96]]) 
targets = np.array([[0,0,0,1], 
        [0,0,0,1]]) 
ans = 0.71355817782 #Correct answer 
x = cross_entropy(predictions, targets) 
print(np.isclose(x,ans)) 

위 코드의 출력은 그 cross_entropy가 정확하지 함수를 정의 내 코드를 말하고, False입니다. 그런 다음 cross_entropy(predictions, targets)의 결과를 인쇄합니다. 그것은 0.178389544455이고 올바른 결과는 ans = 0.71355817782이어야합니다. 아무도 내 코드의 문제점을 확인하는 데 도움을 줄 수 있습니까?

답변

5

당신은 그다지 멀지 않지만 N 합계의 평균값을 취합니다. 여기서 N = 2 (이 경우)입니다. 그래서 코드를 읽을 수 :

def cross_entropy(predictions, targets, epsilon=1e-12): 
    """ 
    Computes cross entropy between targets (encoded as one-hot vectors) 
    and predictions. 
    Input: predictions (N, k) ndarray 
      targets (N, k) ndarray   
    Returns: scalar 
    """ 
    predictions = np.clip(predictions, epsilon, 1. - epsilon) 
    N = predictions.shape[0] 
    ce = -np.sum(np.sum(targets*np.log(predictions+1e-9)))/N 
    return ce 

predictions = np.array([[0.25,0.25,0.25,0.25], 
         [0.01,0.01,0.01,0.96]]) 
targets = np.array([[0,0,0,1], 
        [0,0,0,1]]) 
ans = 0.71355817782 #Correct answer 
x = cross_entropy(predictions, targets) 
print(np.isclose(x,ans)) 

을 여기에, 나는 당신이() np.sum 고수 경우 좀 더 명확한 생각합니다. 또한 계산에 log (0)을 가질 가능성을 피하기 위해 1e-9를 np.log()에 추가했습니다. 희망이 도움이!