2017-03-22 5 views
1

많은 나무에 대해 xgboost.dump 텍스트 파일이 있습니다. 모든 경로를 찾아 각 경로의 값을 가져 오려고합니다. 다음은 나무 중 하나입니다. ,xgboost.dump에서 이진 트리의 모든 경로를 찾으십시오.

tree[0]: 
0:[a<0.966398] yes=1,no=2,missing=1 
    1:[b<0.323071] yes=3,no=4,missing=3 
     3:[c<0.461248] yes=7,no=8,missing=7 
      7:leaf=0.00972768 
      8:leaf=-0.0179376 
     4:[a<0.379082] yes=9,no=10,missing=9 
      9:leaf=0.0146003 
      10:leaf=0.0454369 
    2:[b<0.322352] yes=5,no=6,missing=5 
     5:[c<0.674868] yes=11,no=12,missing=11 
      11:leaf=0.0497964 
      12:leaf=0.00953781 
     6:[f<0.598267] yes=13,no=14,missing=13 
      13:leaf=0.0504545 
      14:leaf=0.0867654 

나는 이미

array([[ 0, 1, 3, 7], 
     [ 0, 1, 3, 8], 
     [ 0, 1, 4, 9], 
     [ 0, 1, 4, 10], 
     [ 0, 2, 5, 11], 
     [ 0, 2, 5, 12], 
     [ 0, 2, 6, 13], 
     [ 0, 2, 6, 14]]) 

등의 가능한 모든 경로를 나열하는 것을 시도했다

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00097268 
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376 
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003 
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369 
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964 
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781 
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545 
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0864654 

에 모든 경로를 변환 할 그러나 MAX_DEPTH 높은되면이 방법은 오류로 이어질 것 일부 지점은 성장을 멈추고 길은 잘못 될 것입니다. 그래서 실제, 올바른 경로를 생성하려면 텍스트 파일에서 yes, no를 구문 분석해야합니다. 제안 사항이 있으십니까? 감사합니다.

답변

0

다음은 R 구현을 사용하여이 문제에 접근 한 방법입니다. 다른 언어를 사용하는 사용자는 논리를 따르고 현물로 복제 할 수 있습니다.

먼저, xgb.model.dt.tree()가 생성 한 모델 덤프 파일로 시작했습니다.

그런 다음 임의 노드에서 덤프 된 모델의 개별 트리 내에서 최종 부모를 향한 유효한 경로를 구문 분석하는 함수를 작성했습니다.

나중에이 함수를 purrr :: by_row()를 사용하여 모델 덤프의 모든 터미널 노드 "리프"레코드에 적용하고 결과를 변환합니다.

이 함수는 두 개의 인수를 취합니다. 하나는 테스트 할 트리이고 다른 하나는 터미널 노드의 ID입니다.

  1. 트리 단위로 대상 (단말기) 노드부터 시작하여 c ("예" "아니오") 노드 중 유효한 노드로 대상 노드가있는 행을 찾으십시오. , "Missing") 결정이 나뉩니다.
  2. 이 유효한 부모 노드 ID를 대상 노드에서 최종 상위 노드까지의 경로의 각 단계를 추적하는 데 사용할 벡터에 연결합니다. 이 벡터는 함수가 완료 될 때 반환됩니다.
  3. 다음으로 경로가 궁극적 인 상위 노드 (이 노드 ID는 항상 "-0"로 끝남)에 도달 할 때까지 체인의 각 노드에 대해 "누가 내 부모인가"단계를 반복하면서 새로운 단계마다 경로 벡터를 업데이트합니다. 사슬.
  4. 함수가 터미널 노드에 도달하면 경로를 return()합니다.

필자는 purrr :: by_row() 및 .collating = "rows"를 사용하여 모델 덤프의 모든 "리프"노드에이 함수를 적용하여 경로를 출력에서 ​​추가 행으로 나타냅니다.

이것은 가능한 가장 빠른 방법이 아닙니다.

xgb.booster 모델의 nrounds 또는 max_depth가 증가하면이 프로세스의 실행 시간이 늘어납니다. 최종 모델에서 터미널 노드 전체 경로를 파싱하는 데 필요한 시간을 예상 할 수 있도록 트리 서브 세트 (xgb.model.dt.tree()의 인수 n_first_tree = N)를 사용하여 메소드를 개발할 수 있습니다. 필자의 경우 max_depth = 5에 500 나무가있는 모델은 30 분 이상 걸릴 수 있습니다.