2017-10-04 4 views
1

두 개체를 분류하려고합니다. evaluate.py 스크립트에서 Accuracy와 Cross Entropy를 얻고 싶습니다.TensorFlow 메서드에 대한 잘못된 인수 전달

다음은 내가 시도하는 코드입니다. 나는 다음과 같은 오류가 예측 위의 스크립트를 실행할 때

evaluate.py (by tensorflow for poets) 
#!/usr/bin/python 
# 
# Copyright 2017 Google Inc. 
# 
# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 
# 
#  http://www.apache.org/licenses/LICENSE-2.0 
# 
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 

import os 

import sys 
import argparse 

import numpy as np 
import PIL.Image as Image 
import tensorflow as tf 

import scripts.retrain as retrain 
from scripts.count_ops import load_graph 


def evaluate_graph(graph_file_name): 
    with load_graph(graph_file_name).as_default() as graph: 
    ground_truth_input = tf.placeholder(
     tf.float32, [None, 5], name='GroundTruthInput') 

    image_buffer_input = graph.get_tensor_by_name('input:0') 
    final_tensor = graph.get_tensor_by_name('final_result:0') 
    accuracy, _ = retrain.add_evaluation_step(final_tensor, ground_truth_input) 

    logits = graph.get_tensor_by_name("final_training_ops/Wx_plus_b/add:0") 
    xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
     labels=ground_truth_input, 
     logits=logits)) 

#image_dir = 'tf_files/flower_photos' 
image_dir = 'tf_files/test_images' 
testing_percentage = 10 
validation_percentage = 10 
validation_batch_size = 100 
category = 'testing' 

image_lists = retrain.create_image_lists(
    image_dir, testing_percentage, 
    validation_percentage) 
class_count = len(image_lists.keys()) 

ground_truths = [] 
filenames = [] 

for label_index, label_name in enumerate(image_lists.keys()): 
    for image_index, image_name in enumerate(image_lists[label_name][category]): 
     image_name = retrain.get_image_path(
      image_lists, label_name, image_index, image_dir, category) 
     ground_truth = np.zeros([1, class_count], dtype=np.float32) 
     ground_truth[0, label_index] = 1.0 
     ground_truths.append(ground_truth) 
     filenames.append(image_name) 

accuracies = [] 
xents = [] 
with tf.Session(graph=graph) as sess: 
    for filename, ground_truth in zip(filenames, ground_truths): 
     image = Image.open(filename).resize((224, 224), Image.ANTIALIAS) 
     image = np.array(image, dtype=np.float32)[None, ...] 
     image = (image - 128)/128.0 

     feed_dict = { 
      image_buffer_input: image, 
      ground_truth_input: ground_truth} 

     eval_accuracy, eval_xent = sess.run([accuracy, xent], feed_dict) 

     accuracies.append(eval_accuracy) 
     xents.append(eval_xent) 

return np.mean(accuracies), np.mean(xents) 


if __name__ == "__main__": 
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
    accuracy, xent = evaluate_graph(*sys.argv[1:]) 
    print('Accuracy: %g' % accuracy) 
    print('Cross Entropy: %g' % xent) 

그러나 :

ValueError: Cannot feed value of shape (1, 224, 224) for Tensor u'input:0', which has shape '(1, 224, 224, 3)'

어떻게이 오류를 해결할 수 있습니까?

+0

그레이 스케일 이미지 (1 채 널)를 RGB (3 채 널)가 필요한 입력에 공급하려고합니다. – lejlot

답변

1

회색 음영 이미지를 입력 자리 표시 자로 보내는 것처럼 보입니다. 그레이 스케일 이미지는 이미지 actualy RGB를하는 경우, 당신이있을 수 있습니다 당신이하려고하는 pretrained 네트워크는 3 개 채널 RGB 이미지를 필요로하는 동안 따라서 모양 (224, 224,) (크기 1의 차원 생략) (224, 224, 3)

모양 만 1 개 채널을 가지고 여기에 오류 :

image = np.array(image, dtype=np.float32)[None, ...] 

이 색인 : [None, ...]이 필요하지 않는 것 같습니다. 이미지가 실제로 그레이 스케일이 있다면, 당신은 PIL.convert() 사용하여 RGB 형식으로 변환 할 수 있습니다

(하나 개의 채널이 반복됩니다 3 회) :

image = image.convert("RGB") 

비록 채널 중복 CNN은 비효율적 3 채널을 실행하는 (계산 같은 데이터에 대해 3 번 수행됨) 컬러 이미지보다 성능이 떨어지면 스크립트를 실행해야하며 빠르게 추적 할 수 있습니다.

+0

주셔서 감사합니다. 그레이 스케일 이미지를 사용했습니다. – Jun

+0

@Jun RGB 이미지를 "가짜"만드는 방법에 대한 정보를 주셨습니다. – Drop