2

기계 학습과 신경망을 이제 막 공부하기 시작해서 역전파(backpropagation)가 어떻게 동작하는지 이해하는 데 어려움을 겪고 있습니다. 간단한 행렬 기반 접근법으로 Java에서 간단한 신경망(NN)을 구현해 보았습니다. 학습 예제를 하나만 입력하면 네트워크가 완벽하게 동작하지만, 둘 이상을 사용하면 출력이 항상 목표 학습 출력들의 평균이 됩니다. http://neuralnetworksanddeeplearning.com/images/tikz21.png — Backpropagation 신경망이 작동하지 않습니다.

package neuralnetwork; 
/** 
* @author Paolo Pellizzoni 
*/ 

/**
 * Minimal fully connected network (input layer, one hidden layer, output layer)
 * with sigmoid activations, trained by batch gradient descent using
 * matrix-based backpropagation.
 *
 * <p>Data layout used throughout: every matrix of activations has one row per
 * unit and one column per training example, so {@code x} is
 * {@code in_l x examples} and {@code y} is {@code out_l x examples}.
 *
 * <p>All state is held in static fields; the class is not thread-safe.
 */
public class NeuralNetwork {

    /** Number of input units. */
    static final int in_l = 2;
    /** Number of hidden units. */
    static final int h_l = 5;
    /** Number of output units. */
    static final int out_l = 1;

    /** Number of gradient-descent steps performed by {@link #trainNN(double)}. */
    static final int DEFAULT_ITERATIONS = 500;

    /** Weights from the input layer to the hidden layer ({@code h_l x in_l}). */
    public static double[][] w2 = new double[h_l][in_l];
    /** Weights from the hidden layer to the output layer ({@code out_l x h_l}). */
    public static double[][] w3 = new double[out_l][h_l];
    /** Biases of the hidden layer; left at zero until training updates them. */
    public static double[] b2 = new double[h_l];
    /** Biases of the output layer; left at zero until training updates them. */
    public static double[] b3 = new double[out_l];

    /** Training inputs, one column per example: here (3, 2) and (4, 3). */
    public static double[][] x = {{3, 4}, {2, 3}};
    /** Training targets, one column per example. */
    public static double[][] y = {{0.3, 0.7}};
    /** Input evaluated by {@link #main(String[])} after training. */
    public static double[][] test = {{3}, {2}};
    // using x = {{3},{2}} and y = {{0.3}} it works

    /**
     * Trains the network on {@code x}/{@code y} with learning rate 0.2 and
     * prints the network output for {@code test}, one output unit per line.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        trainNN(0.2);
        double[][] m = a_3(test);

        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                System.out.print(m[i][j] + " ");
            }
            System.out.println();
        }
    }

    // ---------- FUNCTIONS ----------

    /**
     * Overwrites every entry of {@code m} with a fresh {@link Math#random()}
     * value, i.e. a uniform sample from [0, 1).
     *
     * @param m matrix to fill in place
     */
    static void inizialize_weights(double[][] m) {
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                m[i][j] = Math.random();
            }
        }
    }

    /**
     * Re-initializes the weights and runs {@link #DEFAULT_ITERATIONS} steps of
     * batch gradient descent on {@code x}/{@code y}.
     *
     * @param rate learning rate applied to every gradient
     */
    static void trainNN(double rate) {
        trainNN(rate, DEFAULT_ITERATIONS);
    }

    /**
     * Re-initializes the weights and runs the requested number of batch
     * gradient-descent steps on {@code x}/{@code y}.
     *
     * @param rate       learning rate applied to every gradient
     * @param iterations number of update steps; must not be negative
     * @throws IllegalArgumentException if {@code iterations} is negative
     */
    static void trainNN(double rate, int iterations) {
        if (iterations < 0) {
            throw new IllegalArgumentException("iterations must be >= 0, got " + iterations);
        }
        inizialize_weights(w2);
        inizialize_weights(w3);

        for (int c = 0; c < iterations; c++) {
            // All four gradients are evaluated at the current parameters
            // before any parameter is modified.
            double[][] dJ_w3 = dJ_w3(x, y);
            double[][] dJ_w2 = dJ_w2(x, y);
            double[] dJ_b3 = dJ_b3(x, y);
            double[] dJ_b2 = dJ_b2(x, y);
            w3 = matrix_sum(w3, dJ_w3, -rate);
            w2 = matrix_sum(w2, dJ_w2, -rate);
            b3 = vect_sum(b3, dJ_b3, -rate);
            b2 = vect_sum(b2, dJ_b2, -rate);
        }
    }

    /**
     * Output-layer activations for the given inputs.
     *
     * @param inputs {@code in_l x examples} matrix
     * @return {@code out_l x examples} matrix of sigmoid outputs
     */
    static double[][] a_3(double[][] inputs) {
        return sigmoid(z_3(inputs));
    }

    /**
     * Output-layer weighted inputs: {@code w3 * a_2(inputs) + b3}.
     *
     * @param inputs {@code in_l x examples} matrix
     * @return freshly allocated {@code out_l x examples} matrix
     */
    static double[][] z_3(double[][] inputs) {
        return matrix_sum_vect(matrix_product(w3, a_2(inputs)), b3, 1);
    }

    /**
     * Hidden-layer activations for the given inputs.
     *
     * @param inputs {@code in_l x examples} matrix
     * @return {@code h_l x examples} matrix of sigmoid outputs
     */
    static double[][] a_2(double[][] inputs) {
        return sigmoid(z_2(inputs));
    }

    /**
     * Hidden-layer weighted inputs: {@code w2 * inputs + b2}.
     *
     * @param inputs {@code in_l x examples} matrix
     * @return freshly allocated {@code h_l x examples} matrix
     */
    static double[][] z_2(double[][] inputs) {
        return matrix_sum_vect(matrix_product(w2, inputs), b2, 1);
    }

    /**
     * Output-layer error: {@code (a_3 - y)} multiplied element-wise by
     * {@code sigmoid'(z_3)}, which is the error term of a quadratic cost.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return {@code out_l x examples} matrix of errors
     */
    static double[][] delta3(double[][] inputs, double[][] y) {
        return matrix_hadamard(
                matrix_sum(a_3(inputs), y, -1),
                sigmoid_prime(z_3(inputs)));
    }

    /**
     * Hidden-layer error: {@code (w3^T * delta3)} multiplied element-wise by
     * {@code sigmoid'(z_2)}.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return {@code h_l x examples} matrix of errors
     */
    static double[][] delta2(double[][] inputs, double[][] y) {
        return matrix_hadamard(
                matrix_product(transpose_matrix(w3), delta3(inputs, y)),
                sigmoid_prime(z_2(inputs)));
    }

    /**
     * Gradient of the cost with respect to {@code w3}, averaged over all
     * examples: entry {@code [i][j]} is the mean over k of
     * {@code delta3[i][k] * a2[j][k]}.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return {@code out_l x h_l} gradient matrix
     */
    static double[][] dJ_w3(double[][] inputs, double[][] y) {
        return meanOuterProduct(delta3(inputs, y), a_2(inputs), new double[out_l][h_l]);
    }

    /**
     * Gradient of the cost with respect to {@code w2}, averaged over all
     * examples: entry {@code [i][j]} is the mean over k of
     * {@code delta2[i][k] * inputs[j][k]}.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return {@code h_l x in_l} gradient matrix
     */
    static double[][] dJ_w2(double[][] inputs, double[][] y) {
        // The error row must be selected by the hidden-unit index i, not by
        // the input index j; the shared helper guarantees that pairing.
        return meanOuterProduct(delta2(inputs, y), inputs, new double[h_l][in_l]);
    }

    /**
     * Gradient of the cost with respect to {@code b3}: the per-unit mean of
     * {@code delta3} over all examples.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return gradient vector of length {@code out_l}
     */
    static double[] dJ_b3(double[][] inputs, double[][] y) {
        return rowMeans(delta3(inputs, y), new double[out_l]);
    }

    /**
     * Gradient of the cost with respect to {@code b2}: the per-unit mean of
     * {@code delta2} over all examples.
     *
     * @param inputs {@code in_l x examples} matrix
     * @param y      {@code out_l x examples} matrix of targets
     * @return gradient vector of length {@code h_l}
     */
    static double[] dJ_b2(double[][] inputs, double[][] y) {
        return rowMeans(delta2(inputs, y), new double[h_l]);
    }

    /** Fills {@code out[i][j]} with the mean over columns k of {@code delta[i][k] * activations[j][k]}. */
    private static double[][] meanOuterProduct(double[][] delta, double[][] activations, double[][] out) {
        int examples = activations[0].length;
        for (int i = 0; i < delta.length; i++) {
            for (int j = 0; j < activations.length; j++) {
                double tmp = 0;
                for (int k = 0; k < examples; k++) {
                    tmp += activations[j][k] * delta[i][k];
                }
                out[i][j] = tmp / examples;
            }
        }
        return out;
    }

    /** Fills {@code out[i]} with the mean of row i of {@code m}. */
    private static double[] rowMeans(double[][] m, double[] out) {
        for (int i = 0; i < m.length; i++) {
            double tmp = 0;
            for (int k = 0; k < m[0].length; k++) {
                tmp += m[i][k];
            }
            out[i] = tmp / m[0].length;
        }
        return out;
    }

    // ----- Math -----

    /**
     * Matrix multiplication.
     *
     * @param a left operand
     * @param b right operand
     * @return new matrix {@code a * b}, or {@code null} when the column count
     *         of {@code a} differs from the row count of {@code b}
     */
    static double[][] matrix_product(double[][] a, double[][] b) {
        int m1ColLength = a[0].length;
        int m2RowLength = b.length;
        if (m1ColLength != m2RowLength) {
            return null;
        }
        int mRRowLength = a.length;
        int mRColLength = b[0].length;
        double[][] mResult = new double[mRRowLength][mRColLength];
        for (int i = 0; i < mRRowLength; i++) {
            for (int j = 0; j < mRColLength; j++) {
                for (int k = 0; k < m1ColLength; k++) {
                    mResult[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return mResult;
    }

    /**
     * Element-wise {@code a + is_sum * b}.
     *
     * @param a      first matrix
     * @param b      second matrix, same shape as {@code a}
     * @param is_sum factor applied to {@code b} (1 adds, -1 subtracts, any
     *               other value scales)
     * @return new matrix, or {@code null} when the shapes differ
     */
    static double[][] matrix_sum(double[][] a, double[][] b, double is_sum) {
        int m1ColLength = a[0].length;
        int m2RowLength = b.length;
        int m1RowLength = a.length;
        int m2ColLength = b[0].length;
        if (m1ColLength != m2ColLength || m1RowLength != m2RowLength) {
            return null;
        }
        double[][] mResult = new double[m1RowLength][m1ColLength];
        for (int i = 0; i < m1RowLength; i++) {
            for (int j = 0; j < m1ColLength; j++) {
                mResult[i][j] = a[i][j] + (b[i][j]) * is_sum;
            }
        }
        return mResult;
    }

    /**
     * Element-wise {@code a + is_sum * b} for vectors.
     *
     * @param a      first vector
     * @param b      second vector, same length as {@code a}
     * @param is_sum factor applied to {@code b}
     * @return new vector, or {@code null} when the lengths differ
     */
    static double[] vect_sum(double[] a, double[] b, double is_sum) {
        int m2RowLength = b.length;
        int m1RowLength = a.length;
        if (m1RowLength != m2RowLength) {
            return null;
        }
        double[] mResult = new double[m1RowLength];
        for (int i = 0; i < m1RowLength; i++) {
            mResult[i] = a[i] + (b[i]) * is_sum;
        }
        return mResult;
    }

    /**
     * Adds {@code is_sum * b} to every column of {@code a}.
     *
     * @param a      matrix
     * @param b      vector with one entry per row of {@code a}
     * @param is_sum factor applied to {@code b}
     * @return new matrix, or {@code null} when {@code b.length} differs from
     *         the row count of {@code a}
     */
    static double[][] matrix_sum_vect(double[][] a, double[] b, double is_sum) {
        int m1ColLength = a[0].length;
        int m2RowLength = b.length;
        int m1RowLength = a.length;
        if (m1RowLength != m2RowLength) {
            return null;
        }
        double[][] mResult = new double[m1RowLength][m1ColLength];
        for (int i = 0; i < m1RowLength; i++) {
            for (int j = 0; j < m1ColLength; j++) {
                mResult[i][j] = a[i][j] + (b[i]) * is_sum;
            }
        }
        return mResult;
    }

    /**
     * Hadamard (element-wise) product.
     *
     * @param a first matrix
     * @param b second matrix, same shape as {@code a}
     * @return new matrix, or {@code null} when the shapes differ
     */
    static double[][] matrix_hadamard(double[][] a, double[][] b) {
        int m1ColLength = a[0].length;
        int m2RowLength = b.length;
        int m1RowLength = a.length;
        int m2ColLength = b[0].length;
        if (m1ColLength != m2ColLength || m1RowLength != m2RowLength) {
            return null;
        }
        double[][] mResult = new double[m1RowLength][m1ColLength];
        for (int i = 0; i < m1RowLength; i++) {
            for (int j = 0; j < m1ColLength; j++) {
                mResult[i][j] = a[i][j] * b[i][j];
            }
        }
        return mResult;
    }

    /**
     * Multiplies every entry of a matrix by a scalar.
     *
     * @param a      matrix
     * @param scalar factor
     * @return new matrix {@code scalar * a}
     */
    static double[][] matrix_x_scalar(double[][] a, double scalar) {
        int m1ColLength = a[0].length;
        int m1RowLength = a.length;
        double[][] mResult = new double[m1RowLength][m1ColLength];
        for (int i = 0; i < m1RowLength; i++) {
            for (int j = 0; j < m1ColLength; j++) {
                mResult[i][j] = a[i][j] * scalar;
            }
        }
        return mResult;
    }

    /**
     * Matrix transpose.
     *
     * @param m matrix
     * @return new matrix with rows and columns swapped
     */
    static double[][] transpose_matrix(double[][] m) {
        double[][] mResult = new double[m[0].length][m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                mResult[j][i] = m[i][j];
            }
        }
        return mResult;
    }

    /**
     * Logistic function {@code 1 / (1 + e^-z)}.
     *
     * @param z input value
     * @return value in (0, 1) for finite inputs that do not saturate
     */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Applies {@link #sigmoid(double)} to every entry.
     *
     * <p>The argument is modified in place and returned; callers in this class
     * only pass freshly allocated matrices.
     *
     * @param z matrix to transform in place
     * @return the same array instance as {@code z}
     */
    static double[][] sigmoid(double[][] z) {
        for (int i = 0; i < z.length; i++) {
            for (int j = 0; j < z[0].length; j++) {
                z[i][j] = sigmoid(z[i][j]);
            }
        }
        return z;
    }

    /**
     * Derivative of the logistic function: {@code s * (1 - s)} with
     * {@code s = sigmoid(z)}.
     *
     * @param z input value
     * @return derivative at {@code z}
     */
    static double sigmoid_prime(double z) {
        double s = sigmoid(z);
        return s * (1 - s);
    }

    /**
     * Applies {@link #sigmoid_prime(double)} to every entry.
     *
     * <p>The argument is modified in place and returned; callers in this class
     * only pass freshly allocated matrices.
     *
     * @param z matrix to transform in place
     * @return the same array instance as {@code z}
     */
    static double[][] sigmoid_prime(double[][] z) {
        for (int i = 0; i < z.length; i++) {
            for (int j = 0; j < z[0].length; j++) {
                z[i][j] = sigmoid_prime(z[i][j]);
            }
        }
        return z;
    } // ----- end math -----
}

모든 그라디언트의 평균을 내고 있는데, 오류가 dJ_w3, dJ_w2 함수 안에, 아마도 k 루프에 숨어 있다고 거의 확신하지만 도무지 찾을 수가 없습니다. 도와주시겠습니까?

답변

0

문제를 찾았습니다. 학습 반복 횟수를 50000회로 늘려야 합니다.