2017-12-14 12 views
1

아래 코드를 사용하여 tfrecords 파일을 만들 수 있습니다. 나는 다음과 같은 오류를 얻고있다Tensorflow tfrecord 파일에서 읽을 수 없습니다.

def read_and_decode(filename_queue): 
    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 
    img_features = tf.parse_single_example(
     serialized_example, 
     features={ 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'depth': tf.FixedLenFeature([], tf.int64), 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'label': tf.FixedLenFeature([], tf.int64), 
     }) 

    image = tf.decode_raw(img_features['image_raw'], tf.float32) 
    label = tf.cast(img_features['label'], tf.int32) 
    height = tf.cast(img_features['height'], tf.int32) 
    width = tf.cast(img_features['width'], tf.int32) 
    depth = tf.cast(img_features['depth'], tf.int32) 
    image_shape = tf.stack([depth,height, width]) 
    image = tf.reshape(image, image_shape) 
    return image,label 

def inputs(batch_size, num_epochs): 
    filename = ['set1.tfrecords'] 
    # dir_path is a global variable 
    file_path = dir_path + 'set1.tfrecords' 
    filename_queue = tf.train.string_input_producer([file_path], num_epochs=1) 
    image,label = read_and_decode(filename_queue) 
    images, sparse_labels = tf.train.shuffle_batch(
     [image, label], batch_size=batch_size, num_threads=2, 
     capacity=1000 + 3 * batch_size, min_after_dequeue=1000) 
    return images, sparse_labels 

아래의 기능을 사용하여 tfrecord 파일에서 데이터를 읽는 동안

def _int64_feature(value): 
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 
def _bytes_feature(value): 
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 

def convert_to_tfrecord(images,labels,file_name): 
    # images is a numpy array of shape (num_images,channel,rows,column) 
    # labels is a numpy array of shape (num_images,) 
    num_labels = np.shape(labels) 
    (num_images,depth,rows,cols) = np.shape(images) 
    writer = tf.python_io.TFRecordWriter(file_name) 
    for index in range(num_images): 
     image_raw = images[index] 
     image_raw = image_raw.astype(np.float32) 
     image_raw = image_raw.tostring() 
     example = tf.train.Example(features=tf.train.Features(feature={ 
      'height': _int64_feature(rows), 
      'width': _int64_feature(cols), 
      'depth': _int64_feature(depth), 
      'label': _int64_feature(int(labels[index])), 
      'image_raw': _bytes_feature(image_raw)})) 

     writer.write(example.SerializeToString()) 
    writer.close() 

는하지만, 지속적으로

images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 1225, in shuffle_batch 
name=name) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/input.py", line 781, in _shuffle_batch 
dtypes=types, shapes=shapes, shared_name=shared_name) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 641, in __init__ 
shapes = _as_shape_list(shapes, dtypes) 

File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list 
raise ValueError("All shapes must be fully defined: %s" % shapes) 

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None)]), TensorShape([])] 

위의 오류에 대한 이유 무엇입니까 그리고 이것을 극복하는 방법? tf.python_io.tf_record_iterator(path=filename)을 사용하여 파일을 반복하여 tfrecords 파일을 읽을 수 있습니다.

+0

을 오류가 어떤 언급되지 않은 'read_and_decode'에없는'tf.train.shuffle_batch'와 관련이 있기 때문에 여러분이 올린 코드의 행을 찾으십시오. – GPhilo

+0

좋습니다. 그 부분을 포함하는 다른 기능을 추가했습니다. @GPhilo –

+0

이미지가 모두 같은 크기입니까, 아니면 크기가 다를 수 있습니까? – GPhilo

답변

2

tf.train.shuffle_batch은 일괄 처리 할 수 ​​있으려면 텐서스의 모양을 알아야하기 때문에 오류가 발생합니다 (일괄 처리의 항목은 모두 같은 모양이어야합니다). 그러나 원시 데이터의 크기가 다를 수 있으므로 tf.decode_raw은 텐서의 모양을 설정하지 않습니다. 코멘트에서

, 당신은 모든 이미지 모양 (192,81,2)을 언급, 그래서 당신은 단지 read_and_decode에서 반환하기 전에 이미지 텐서에서 그 모양을 설정해야합니다

def read_and_decode(filename_queue): 
    # rest of your code here 
    image_shape = [height, width, depth] 
    image = tf.reshape(image, image_shape) 
    image.set_shape(image_shape) #<<<<<<<<<<<<<<< 
    return image,label 
+0

고마워요. 그것은 일했다 :) –

+0

그것이 도움이 되었기 때문에 기쁘다! 답변을 해결 된 것으로 표시하십시오. 이렇게하면 영원히 열리지 않습니다 (사람들은 실제로 비슷한 답변을 여기에 가리키는 것과 같이 표시 할 수 있습니다) – GPhilo