2017-12-23 7 views
1

저는 TensorFlow (TF)를 배우고 있으며, 하루 만에 끝났습니다. 그래서 물어 보는 것이 너무 근본적이라면 사전에 사과드립니다. 공식 TF 웹 사이트에서 linear classification example을 공부하고있었습니다.TensorFlow의 반복기 사용 예제 코드

저자는 input_fun이라는 함수를 정의하여 데이터를 읽습니다. 기능은 다음과 같습니다 :

def input_fn(data_file, num_epochs, shuffle, batch_size): 
    """Generate an input function for the Estimator.""" 
    assert tf.gfile.Exists(data_file), (
     '%s not found. Please make sure you have either run data_download.py or ' 
     'set both arguments --train_data and --test_data.' % data_file) 

    def parse_csv(value): 
    print('Parsing', data_file) 
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS) 
    features = dict(zip(_CSV_COLUMNS, columns)) 
    labels = features.pop('income_bracket') 
    return features, tf.equal(labels, '>50K') 

    # Extract lines from input files using the Dataset API. 
    dataset = tf.data.TextLineDataset(data_file) 

    if shuffle: 
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train']) 

    dataset = dataset.map(parse_csv, num_parallel_calls=5) 

    # We call repeat after shuffling, rather than before, to prevent separate 
    # epochs from blending together. 
    dataset = dataset.repeat(num_epochs) 
    dataset = dataset.batch(batch_size) 

    iterator = dataset.make_one_shot_iterator() 
    features, labels = iterator.get_next() 
    return features, labels 

두 번째 마지막 줄을 이해할 수 없습니다. 원샷 반복기는 get_next()을 한 번만 호출하지만 행을 추출하기 위해 데이터를 여러 번 반복하지 않아야합니다 (예 : 행 시간 수). this example here?

답변

2

여기에서 get_next()는 기본적으로 dequeue op입니다. get_next()에 의해 호출 된 요소를 소비 (사용/호출)하면 큐에있는 데이터가 대기열에서 제거되고 다음 이미지/레이블이 다음에 호출 할 때 대기열에서 제외 된 해당 위치로 이동됩니다 .

현재이 함수는 엘리먼트를 dequeing하기위한 tensorflow 연산 만 리턴합니다. 트레이닝 루프에서 소모 할 수 있습니다.