package com.example.minwoo_k.neural_network;
import android.os.AsyncTask;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.reflections.vfs.CommonsVfs2UrlType;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import static android.R.id.input;
import static org.reflections.Reflections.log;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
AsyncTask.execute(new Runnable() {
@Override
public void run() {
try {
createAndUseNetwork();
} catch (IOException e) {
e.printStackTrace();
}
}
});
}
private void createAndUseNetwork() throws IOException {
DenseLayer inputLayer = new DenseLayer.Builder() // Input Layer
.nIn(784)
.nOut(200)
.name("Input")
.activation(Activation.SIGMOID) // Sigmoid Activation function
.build();
DenseLayer hiddenLayer = new DenseLayer.Builder() // Hidden Layer
.nIn(200)
.nOut(10)
.name("Hidden")
.activation(Activation.SIGMOID) // Sigmoid Activation function
.build();
OutputLayer outputLayer = new OutputLayer.Builder() // Output Layer
.nIn(10)
.nOut(10)
.name("Output")
.activation(Activation.SOFTMAX) // Softmax Activation function
.build();
NeuralNetConfiguration.Builder nncBuilder = new NeuralNetConfiguration.Builder();
nncBuilder.iterations(5);
nncBuilder.learningRate(0.05); // Learning Rate
nncBuilder.weightInit(WeightInit.XAVIER);
nncBuilder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); // use SGD
NeuralNetConfiguration.ListBuilder listBuilder = nncBuilder.list();
listBuilder.layer(0, inputLayer);
listBuilder.layer(1, hiddenLayer);
listBuilder.layer(2, outputLayer);
listBuilder.backprop(true); // backpropagation
Log.d("ANN","****************Create ANN********************");
MultiLayerNetwork myNetwork = new MultiLayerNetwork(listBuilder.build());
myNetwork.init();
myNetwork.setListeners(new ScoreIterationListener(1));
Log.d("ANN","****************Get Data********************");
DataSetIterator mnistTrain = new MnistDataSetIterator(500, 10000, true);
DataSetIterator mnistTest = new MnistDataSetIterator(500, 100, true);
Log.d("ANN","****************Train ANN********************");
myNetwork.fit(mnistTrain);
Log.d("ANN","****************Evaluate ANN********************");
Evaluation eval = new Evaluation(10); //create an evaluation object with 10 possible classes
while(mnistTest.hasNext()){
DataSet next = mnistTest.next();
INDArray output = myNetwork.output(next.getFeatureMatrix()); //get the networks prediction
eval.eval(next.getLabels(), output); //check the prediction against the true class
}
log.info(eval.stats());
log.info("****************Example finished********************");
}
}
이것은 내 프로그램의 전체 소스 코드이며 mnist 데이터를 읽을 수 없습니다. Mnist 데이터 세트를 얻으려면 어떻게해야합니까? 26 :안드로이드의 DataSetlterator, DL4J에서 mnist 데이터를 얻으려면 어떻게해야합니까?
12-15 12 06.526 3910-3930/com.example.minwoo_k.neural_network W/System.err에가 : 때 java.io.IOException : 만들어라 (mkdir) 없습니다/MNIST 12-15 12시 26분 : 06.526 3910-3930/com.example.minwoo_k.neural_network 승/System.err : org.deeplearning4j.base.MnistFetcher.downloadAndUntar (MnistFetcher.java:66) 12-15 12 : 26 : 06.529 3910-3930 /com.example.minwoo_k.neural_network 승/System.err : org.deeplearning4j.datasets.fetchers.MnistDataFetcher (MnistDataFetcher.java:65) 12-15 12 : 26 : 06.529 3910-3930/com.example .minwoo_k.neural_network 승/System.err : at,org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator (MnistDataSetIterator.java:65) 12-15 12. 26 : 3910-3930 06.529/com.example.minwoo_k.neural_network W/System.err에 : 에서 조직도 .deeplearning4j.datasets.iterator.impl.MnistDataSetIterator. (MnistDataSetIterator.java:43) 12-15 12 : 26 : 06.529 3910-3930/com.example.minwoo_k.neural_network 승/System.err : com.example .minwoo_k.neural_network.MainActivity.createAndUseNetwork (MainActivity.java:93) 12-15 12 : 26 : 06.529 3910-3930/com.example.minwoo_k.neural_network 승/System.err : com.example.minwoo_k. neural_network.MainActivity.access $ 000 (MainActivity.java:33) 12-15 12 : 26 : 06.531 3910-3930/com.example.m inwoo_k.neural_network W/System.err : com.example.minwoo_k.neural_network.MainActivity $ 1.run (MainActivity.java:44) 12-15 12 : 26 : 06.531 3910-3930/com.example.minwoo_k. neural_network W/System.err : android.os.AsyncTask $ SerialExecutor $ 1.run (AsyncTask.java:245) 12-15 12 : 26 : 06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err : 에서 java.util.concurrent.ThreadPoolExecutor.runWorker (ThreadPoolExecutor.java:1162) 12-15 26 : 06.532 3910-3930/com.example.minwoo_k.neural_network W/System.err : at java.util.concurrent.ThreadPoolExecutor $ Worker.run (ThreadPoolExecutor.java:636) 12-15 12 : 26 : 06.532 3910-393 0/com.example.minwoo_k.neural_network W/System.err에가 :
java.lang.Thread.run (Thread.java:764)에서이 내 로그 캣 기록이다. 어떻게이 문제를 해결할 수 있습니까?