2017-12-04 18 views
0

여기에서 두 문장의 의미 론적 유사성을 비교하는 siaseme LSTM의 결과를 재현하려고 : - https://github.com/dhwajraj/deep-siamese-text-similarityTypeError : Fetch 인수 배열의 형식이 잘못되었습니다. numpy.ndarray, 문자열 또는 Tensor 여야합니다. (텐서 또는 조작에 ndarray를 변환 할 수 없습니다.)

을 내가 tensorflow 1.4 & 파이썬을 사용하고 있습니다 2.7

train.py가 제대로 작동하고 있습니다. 모델을 평가하기 위해 거기에서 사용할 수있는 "train_snli.txt"의 하위 집합 인 match_valid.tsv 파일을 만들었습니다. input_helpers.py 파일에있는 getTsvTestData 함수를 수정했습니다.

def getTsvTestData(self, filepath): 
     print("Loading testing/labelled data from "+filepath+"\n") 
     x1=[] 
     x2=[] 
     y=[] 
     # positive samples from file 
     for line in open(filepath): 
      l=line.strip().split("\t") 
      if len(l)<3: 
       continue 
      x1.append(l[1].lower()) # text 
      x2.append(l[0].lower()) # text 
      y.append(int(l[2])) # similarity score 0 or 1 
     return np.asarray(x1),np.asarray(x2),np.asarray(y) 

나는이 오류를 얻고있다 eval.py

for db in batches: 
      x1_dev_b,x2_dev_b,y_dev_b = zip(*db) 
      #x1_dev_b = tf.convert_to_tensor(x1_dev_b,) 
      print("type x1_dev_b {}".format(type(x1_dev_b))) # tuple 
      print("type x2_dev_b {}".format(type(x2_dev_b))) # tuple 
      print("type y_dev_b {}\n".format(type(y_dev_b))) # tuple 

      feed = {input_x1: x1_dev_b, 
        input_x2: x2_dev_b, 
        input_y:y_dev_b, 
        dropout_keep_prob: 1.0} 

      batch_predictions, batch_acc, sim = sess.run([predictions,accuracy,sim], feed_dict=feed) 

      print("type batch_predictions {}".format(type(batch_predictions))) # numpy.ndarray 
      print("type batch_acc {}".format(type(batch_acc))) # numpy.float32 
      print("type sim {}".format(type(sim))) # numpy.ndarray 

      all_predictions = np.concatenate([all_predictions, batch_predictions]) 

      print("\n printing batch predictions {} \n".format(batch_predictions)) 

      all_d = np.concatenate([all_d, sim]) 

      print("DEV acc {} \n".format(batch_acc)) 

코드의이 부분에서 오류가 발생하고있다. 형식을 찾으려면 sess.run()에서 print 문을 사용하려고했지만 작동하지 않습니다.

Traceback (most recent call last): 
    File "eval.py", line 92, in <module> 
    batch_predictions, batch_acc, sim = sess.run([predictions,accuracy,sim], feed_dict=feed) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 889, in run 
    run_metadata_ptr) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1105, in _run 
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 414, in __init__ 
    self._fetch_mapper = _FetchMapper.for_fetch(fetches) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 234, in for_fetch 
    return _ListFetchMapper(fetch) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 341, in __init__ 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 242, in for_fetch 
    return _ElementFetchMapper(fetches, contraction_fn) 
    File "/home/joe/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 275, in __init__ 
    % (fetch, type(fetch), str(e))) 
TypeError: Fetch argument array([ 1., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 
     0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 
     0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 
     0., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 
     1., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 
     0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 
     0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 
     0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 
     0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 
     0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 
     0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 
     1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 
     1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0., 
     0., 1., 0., 0., 1., 0., 0., 1., 1., 0., 0., 1., 0., 
     0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
     0., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 
     0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 
     1., 0., 0., 1., 0., 1., 0., 0., 0.], dtype=float32) has invalid type <type 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.) 

사실, 난 내 신체의 모든 문서 벡터에 대한 질의 벡터를 비교, 질의 유사성을하고 유사성 점수에 따라 문장의 순위를하려합니다. 나는 현재 LSTM이 두 문장을 서로 비교하고 유사도를 0 또는 1로 출력한다는 것을 알고 있습니다. 어떻게 할 수 있습니까?

+0

'sim' : 당신은 이런 식으로 뭔가를 시도 할 수 있습니다

? 적어도 하나는'numpy' 배열이고 텐서/연산은 아닙니다. 데이터를로드 할 때 우연히 그 중 하나를 재정의 할 수 있습니까? – GPhilo

+1

예,이 문제를 일으키는 sim을 재정의했습니다. 이제 해결되었습니다. – joel

답변

2

문제는 처음에 TensorFlow 텐서 또는 연산에 대한 참조를 NumPy 배열로 평가 한 결과 인 sim의 값을 바꾸는 것이므로 두 번째 반복 sim이 TensorFlow 텐서 또는 동작이 아니기 때문에 실패합니다.

는`predictions`,`accuracy`과의 정의가 무엇
for db in batches: 
      x1_dev_b,x2_dev_b,y_dev_b = zip(*db) 
      #x1_dev_b = tf.convert_to_tensor(x1_dev_b,) 
      print("type x1_dev_b {}".format(type(x1_dev_b))) # tuple 
      print("type x2_dev_b {}".format(type(x2_dev_b))) # tuple 
      print("type y_dev_b {}\n".format(type(y_dev_b))) # tuple 

      feed = {input_x1: x1_dev_b, 
        input_x2: x2_dev_b, 
        input_y:y_dev_b, 
        dropout_keep_prob: 1.0} 

      batch_predictions, batch_acc, batch_sim = sess.run([predictions,accuracy,sim], feed_dict=feed) 

      print("type batch_predictions {}".format(type(batch_predictions))) # numpy.ndarray 
      print("type batch_acc {}".format(type(batch_acc))) # numpy.float32 
      print("type batch_sim {}".format(type(batch_sim))) # numpy.ndarray 

      all_predictions = np.concatenate([all_predictions, batch_predictions]) 

      print("\n printing batch predictions {} \n".format(batch_predictions)) 

      all_d = np.concatenate([all_d, batch_sim]) 

      print("DEV acc {} \n".format(batch_acc)) 
+0

안녕하세요 @jdehesa 그것을 가리키는 주셔서 감사합니다. 지금 일하고있어. – joel