2017-04-25 6 views
1

날씬한 문서에 제공된 사전 훈련 된 ResNet V2 모델을 사용하는 이미지 분류기를 만들려고합니다.Pretrained ResNet V2 모델을 사용하여 Slim 분류 자 ​​생성

import tensorflow as tf 
slim = tf.contrib.slim 
from PIL import Image 
from inception_resnet_v2 import * 
import numpy as np 

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt' 
sample_images = ['carrot.jpg'] 

input_tensor = tf.placeholder(tf.float32, shape=(None,299,299,3), name='input_image') 
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor) 
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 


variables_to_restore = slim.get_model_variables() 
print(variables_to_restore) 

init_fn = slim.assign_from_checkpoint_fn(
     checkpoint_file, 
     slim.get_model_variables('InceptionResnetV2')) 

sess = tf.Session() 
init_fn(sess) 
arg_scope = inception_resnet_v2_arg_scope() 
with slim.arg_scope(arg_scope): 
    logits, end_points = inception_resnet_v2(scaled_input_tensor, is_training=False) 

for image in sample_images: 
    im = Image.open(image).resize((299,299)) 
    im = np.array(im) 
    im = im.reshape(-1,299,299,3) 
    predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im}) 
    print (np.max(predict_values), np.max(logit_values)) 
    print (np.argmax(predict_values), np.argmax(logit_values)) 

문제는 내가이 오류가 계속입니다 : 여기

지금까지 코드

Traceback (most recent call last): 
    File "./classify.py", line 21, in <module> 
    slim.get_model_variables('InceptionResnetV2')) 
    File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 584, in assign_from_checkpoint_fn 
    saver = tf_saver.Saver(var_list, reshape=reshape_variables) 
    File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1040, in __init__ 
    self.build() 
    File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1061, in build 
    raise ValueError("No variables to save") 
ValueError: No variables to save 

그래서 TF는 것을/슬림 어떤 변수를 찾을 수 없습니다 이것은 만들어 전화 할 때 명확함 :

variables_to_restore = slim.get_model_variables() 
    print(variables_to_restore) 

빈 배열을 출력합니다.

사전 교육을받은 모델을 사용하려면 어떻게해야합니까?

답변

0

"InceptionResnetV2"라는 이름으로 시작하는 변수를 저장하지 않고 그래프에 모델을 아직 구성하지 않았기 때문에 이러한 현상이 발생합니다.

slim.get_variables_to_restore()을 사용하기 전에 모델 구성을해야한다고 생각합니다. 예를 들어

:

with slim.arg_scope(arg_scope): 
    logits, end_points = inception_resnet_v2(scaled_input_tensor, is_training=False) 

variables_to_restore = slim.get_model_variables() 

이 방법 텐서 변수를 구축 할 것입니다 그리고 당신은 variables_to_restore가 더 이상 비어 볼 수 없습니다.

0

수동으로 모델 변수를 추가해야합니다.

with slim.arg_scope(arg_scope): 
    logits, end_points = inception_resnet_v2(scaled_input_tensor, is_training=False) 

# Add model variables 
for var in tf.global_variables(scope='inception_resnet_v2'): 
    slim.add_model_variable(var) 
시도