2016-12-17 1 views
2

저장된 파일을 구문 분석하여 훈련 된 모델의 매개 변수 (컨볼 루션 및 완전히 연결된 레이어의 가중치 및 바이어스)를 다른 프레임 워크 또는 iOS 및 Torch를 포함한 언어로 전달하고 싶습니다.TensorFlow : 훈련 된 모델 매개 변수를 다른 프레임 워크로 가져올 수있는 파일에 저장하는 방법은 무엇입니까?

나는 tf.train.write_graph(session.graph_def, '', 'graph.pb')을 시도했지만 가중치 및 바이어스가없는 그래프 아키텍처 만 포함 된 것으로 보입니다. 그렇다면 검사 점 파일 (saver.save(session, "model.ckpt"))을 만드는 것이 가장 좋습니다. ckpt 파일 유형을 Swift 또는 다른 언어로 쉽게 파싱 할 수 있습니까?

의견이 있으면 알려주십시오.

답변

1

.ckpt 파일을 구문 분석하는 대신 텐서 (경우에 따라 길쌈 레이어의 가중치)를 평가하고 수치가 배열로 표시되도록 할 수 있습니다. 여기에 빠른 장난감 예이다 (r0.10 테스트는 -이 수도 최신 버전에서 몇 가지 작은 API 변경) :

import tensorflow as tf 
import numpy as np 

x = tf.placeholder(np.float32, [2,1]) 
w = tf.Variable(tf.truncated_normal([2,2], stddev=0.1)) 
b = tf.Variable(tf.constant(1.0, shape=[2,1])) 
z = tf.matmul(w, x) + b 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    w_val, z_val = sess.run([w, z], feed_dict={x: np.arange(2).reshape(2,1)}) 
    print(w_val) 
    print(z_val) 

출력 :

[[-0.02913031 0.13549708] 
[ 0.13807134 0.03763327]] 
[[ 1.13549709] 
[ 1.0376333 ]] 

당신은 문제가 텐서에 대한 참조를 받고있는 경우 (상위 계층 "계층"작업에 중첩되어 있다고 가정) 이름으로 찾기를 시도하십시오. 여기에 더 많은 정보는 : 당신이 무게가 훈련 도중 변경하는 방법을 보려면 Tensorflow: How to get a tensor by name?

, 당신은 또한 당신이 tf.Summary 객체에 관심이있는 모든 값을 저장하려고 할 수 있으며 나중에 분석 : Parsing `summary_str` byte string evaluated on tensorflow summary object

+0

당신을 감사하십시오 내가 분명하게 해줘. 위의 결과를 얻은 후에 TensorFlow 사용자가 일반적으로 일부 파일 (.dat 또는 .pb)을 사용하여 다른 플랫폼을 통과하기 위해이 값을 저장하는 방법은 무엇입니까? – kangaroo

+0

다른 플랫폼이 파이썬 기반이라면, 가장 빠른 방법은 피클 파일 (또는 numpy 배열) 파일을 만드는 것입니다. 아마도 더 많은 크로스 플랫폼을 원한다면 아마도 HDF5와 같은 바이너리 형식이 더 적절할 것입니다 (저는 개인적으로 HDF5를 사용하여 일부 데이터 세트를 배포합니다). –

+0

참조 https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow – Cristi