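Creating a Slim classifier using a pretrained ResNet V2 model

I'm trying to build an image classifier that uses the pretrained Inception-ResNet-v2 model (checkpoint `inception_resnet_v2_2016_08_30.ckpt`) provided in the slim documentation. Here is my code so far: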
```python
import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
from inception_resnet_v2 import *
import numpy as np

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['carrot.jpg']

input_tensor = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0 / 255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)

variables_to_restore = slim.get_model_variables()
print(variables_to_restore)

init_fn = slim.assign_from_checkpoint_fn(
    checkpoint_file,
    slim.get_model_variables('InceptionResnetV2'))

sess = tf.Session()
init_fn(sess)

arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
    logits, end_points = inception_resnet_v2(scaled_input_tensor, is_training=False)

for image in sample_images:
    im = Image.open(image).resize((299, 299))
    im = np.array(im)
    im = im.reshape(-1, 299, 299, 3)
    predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
    print(np.max(predict_values), np.max(logit_values))
    print(np.argmax(predict_values), np.argmax(logit_values))
```
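The three graph ops after the placeholder map raw pixel values from [0, 255] to [-1, 1]. Separately, `im.reshape(-1, 299, 299, 3)` only works if the JPEG decodes to exactly three channels; a grayscale or RGBA file would fail or be silently mis-shaped. The sketch below is a NumPy-only illustration under those assumptions: `to_rgb_batch` and `scale_like_graph` are hypothetical helper names, not part of slim. `to_rgb_batch` fixes the channel count but leaves values in [0, 255], because the graph already does the scaling; `scale_like_graph` only mirrors the graph arithmetic so the range can be checked. In PIL, `Image.open(image).convert('RGB')` is a simpler way to get the same channel fix.

```python
import numpy as np


def to_rgb_batch(pixels):
    """Coerce an HxW or HxWxC image array into a (1, H, W, 3) float32 batch, still in [0, 255]."""
    arr = np.asarray(pixels, dtype=np.float32)
    if arr.ndim == 2:                        # grayscale: repeat the single channel three times
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.shape[-1] == 4:                 # RGBA: discard the alpha channel
        arr = arr[..., :3]
    if arr.shape[-1] != 3:
        raise ValueError("expected 1, 3 or 4 channels, got shape %r" % (arr.shape,))
    return arr[np.newaxis, ...]              # prepend the batch dimension


def scale_like_graph(batch):
    """NumPy mirror of the graph ops: scalar_mul(1/255), then subtract(0.5), then multiply(2.0)."""
    return (batch * (1.0 / 255) - 0.5) * 2.0


if __name__ == "__main__":
    gray = np.full((299, 299), 255, dtype=np.uint8)
    print(to_rgb_batch(gray).shape)          # (1, 299, 299, 3)
    print(scale_like_graph(np.array([0.0, 127.5, 255.0])))
```

The output of `to_rgb_batch` is what would be passed as `feed_dict={input_tensor: ...}`; feeding the output of `scale_like_graph` instead would scale the pixels twice.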
The problem is that I keep getting this error:

```
Traceback (most recent call last):
  File "./classify.py", line 21, in <module>
    slim.get_model_variables('InceptionResnetV2'))
  File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/variables.py", line 584, in assign_from_checkpoint_fn
    saver = tf_saver.Saver(var_list, reshape=reshape_variables)
  File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1040, in __init__
    self.build()
  File "/home/ubuntu/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1061, in build
    raise ValueError("No variables to save")
ValueError: No variables to save
```
So TF/slim can't find any variables. This becomes clear when I call:

```python
variables_to_restore = slim.get_model_variables()
print(variables_to_restore)
```

which prints an empty list.
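A plain-Python toy model of the mechanism (no TensorFlow required) may help show why that list is empty. It assumes, as in TF1, that model variables are added to a graph collection only when the model function actually runs, and that a Saver-backed restore helper captures its variable list at the moment it is called. `Graph`, `build_model`, `make_restore_fn` and the layer names are hypothetical stand-ins, not slim or TensorFlow APIs.

```python
MODEL_VARIABLES = "model_variables"


class Graph:
    """Toy stand-in for a TF1 graph: variables register themselves in a named collection when created."""

    def __init__(self):
        self.collections = {}

    def add_to_collection(self, name, value):
        self.collections.setdefault(name, []).append(value)

    def get_collection(self, name, scope=None):
        items = self.collections.get(name, [])
        if scope is None:
            return list(items)
        # Toy scope filter: keep variable names that start with the scope prefix.
        return [v for v in items if v.startswith(scope)]


def build_model(graph, scope="InceptionResnetV2"):
    """Hypothetical model builder: creating the layers is what registers their variables."""
    for layer in ("Conv2d_1a_3x3/weights", "Logits/Logits/weights"):
        graph.add_to_collection(MODEL_VARIABLES, scope + "/" + layer)


def make_restore_fn(graph, scope):
    """Mimics a Saver-backed restore helper: the variable list is fixed when this is called."""
    var_list = graph.get_collection(MODEL_VARIABLES, scope)
    if not var_list:
        raise ValueError("No variables to save")
    return lambda: "restored %d variables" % len(var_list)


if __name__ == "__main__":
    g = Graph()
    try:
        make_restore_fn(g, "InceptionResnetV2")   # model not built yet: collection is empty
    except ValueError as err:
        print("before build:", err)
    build_model(g)                                 # variables now exist in the collection
    print("after build:", make_restore_fn(g, "InceptionResnetV2")())
```

Read against the script above, this suggests the restore helper is simply being created too early: `slim.get_model_variables('InceptionResnetV2')` is evaluated before `inception_resnet_v2(...)` has created any variables. Moving the `slim.assign_from_checkpoint_fn(...)` call and `init_fn(sess)` below the `with slim.arg_scope(arg_scope):` block should give it a non-empty variable list.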
What do I need to do to use the pretrained model?