0

저는 심도있는 학습을하고 알고리즘이 어떻게 작동하는지 이해하려고 노력하며 JavaScript로 작성합니다. 이제는 Tensorflow처럼 conv2d의 JavaScript 구현을 위해 노력하고 있습니다. 여러 개의 필터를 처리하는 방법을 오해하고, 하나의 출력 필터와 다중 출력에 성공했습니다. 그러나 여러 필터 입력으로 작업을 생성하는 방법을 혼동합니다. 시험에 대한Tensorflow, conv2d and filters

const outCount = 32 // count of inputs filters 
const inCount = 1 // count of output features 
const filterSize = 3 
const stride = 1 
const inShape = [1, 10, 10, outCount] 
const outShape = [ 
    1, 
    Math.ceil((inShape[1] - filterSize + 1)/stride), 
    Math.ceil((inShape[2] - filterSize + 1)/stride), 
    outCount 
]; 
const filters = ndarray([], [filterSize, filterSize, inCount, outCount]) 

const conv2d = (input) => { 
    const result = ndarray(outShape) 
    // for each output feature 

    for (let fo = 0; fo < outCount; fo += 1) { 
    for (let x = 0; x < outShape[1]; x += 1) { 
     for (let y = 0; y < outShape[2]; y += 1) { 
     const fragment = ndarray([], [filterSize, filterSize]); 
     const filter = ndarray([], [filterSize, filterSize]); 

     // agregate fragment of image and filter 
     for (let fx = 0; fx < filterSize; fx += 1) { 
     for (let fy = 0; fy < filterSize; fy += 1) { 
      const dx = (x * stride) + fx; 
      const dy = (y * stride) + fy; 

      fragment.data.push(input.get(0, dx, dy, 0)); 
      filter.data.push(filters.get(fx, fy, 0, fo)); 
     } 
     } 

     // calc dot product of filter and image fragment 
     result.set(0, x, y, fo, dot(filter, fragment)); 
     } 
    } 
    } 

    return result 
} 

내가 사실과 알고리즘의 소스로 Tenforflow을 사용하고 올바른 작동하지만 1 -> N으로 : - 32 여기> 64

ndarray 사용 코드의 예입니다. 하지만 내 질문에 N -> M 같은 입력 값에 여러 필터 지원을 추가하는 방법.

누군가가 Tensorflow와 더 호환되도록이 알고리즘을 수정하는 방법을 설명 할 수 있습니까 tf.nn.conv2d 감사합니다.

답변

0

for 루프를 추가해야합니다. 모든 입력 모양과 치수를 지정하지 않았으므로 정확하게 쓰는 것은 어렵지만 실제로는 이렇게 보일 것입니다.

// agregate fragment of image and filter 
    for (let fx = 0; fx < filterSize; fx += 1) { 
    for (let fy = 0; fy < filterSize; fy += 1) { 
     //addition 
     for (let ch = 0; ch < input.get_channels) { 
     const dx = (x * stride) + fx; 
     const dy = (y * stride) + fy; 

     fragment.data.push(input.get(0, dx, dy, ch)); 
     filter.data.push(filters.get(fx, fy, ch, fo)); 
     } 
    } 
    } 
+0

정말 고맙습니다. – XMANX