2017-12-11 12 views
0

견적을 평가하는 동안 tensorflow가 결정 계수 (R 제곱)를 계산하기를 원합니다. 그런 다음tf.estimator를 사용한 맞춤 메트릭

def r_squared(labels, predictions, weights=None, 
       metrics_collections=None, 
       updates_collections=None, 
       name=None): 

    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels))) 
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions)) 
    r_sq = 1 - tf.div(unexplained_error, total_error) 

    # update_rsq_op = ? 

    if metrics_collections: 
     ops.add_to_collections(metrics_collections, r_sq) 

    # if updates_collections: 
    #  ops.add_to_collections(updates_collections, update_rsq_op) 

    return r_sq #, update_rsq_op 

, 내가 EstimatorSpec에 메트릭으로이 기능을 사용 : 내가 loosly 공식 통계의 구현을 기반으로 다음과 같은 방법으로 그것을 구현하려고하지만

estim_specs = tf.estimator.EstimatorSpec(
    ... 
    eval_metric_ops={ 
     'r_squared': r_squared(labels, predictions), 
     ... 
    }) 

이 실패 R 제곱의 구현은 update_op을 반환하지 않기 때문에.

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: Tensor("sub_4:0", dtype=float64) for key: r_squared 

이제 update_op은 어떻게해야할까요? 실제로 update_op을 구현해야합니까? 아니면 어떻게 든 더미 update_op을 만들 수 있습니까? 필요한 경우 어떻게 구현합니까?

+0

의 사용 가능한 복제 [사용자 평가 \ _metric \ _ops Tensorflow에서 견적에 (https://stackoverflow.com/questions/45643809/custom-eval-metric-ops-in-estimator-in-tensorflow) – CvW

답변

1

좋아요, 그래서 알아낼 수있었습니다. 내 측정 항목을 평균 척도로 랩하고 update_op을 사용할 수 있습니다. 이것은 나를 위해 작동하는 것 같다.

def r_squared(labels, predictions, weights=None, 
       metrics_collections=None, 
       updates_collections=None, 
       name=None): 

    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels))) 
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions)) 
    r_sq = 1 - tf.div(unexplained_error, total_error) 

    m_r_sq, update_rsq_op = tf.metrics.mean(r_sq) 

    if metrics_collections: 
     ops.add_to_collections(metrics_collections, m_r_sq) 

    if updates_collections: 
     ops.add_to_collections(updates_collections, update_rsq_op) 

    return m_r_sq, update_rsq_op