도와주세요! 나는 deeplearning4j
을 사용하는 프로젝트에서 일하고 있습니다. MNIST 예제는 잘 작동하지만 데이터 세트에 오류가 발생합니다. 내 데이터 세트에는 2 개의 출력이 있습니다. 아닌 행렬 입력;MNIST를 사용한 코드 오류 deeplearning4j 예제
int height = 45;
int width = 800;
int channels = 1;
int rngseed = 123;
Random randNumGen = new Random(rngseed);
int batchSize = 128;
int outputNum = 2;
int numEpochs = 15;
File trainData = new File("C:/Users/JHP/Desktop/learningData/training");
File testData = new File("C:/Users/JHP/Desktop/learningData/testing");
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
ImageRecordReader recordReader2 = new ImageRecordReader(height, width, channels, labelMaker);
recordReader.initialize(train);
recordReader2.initialize(test);
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader2, batchSize, 1, outputNum);
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(dataIter);
dataIter.setPreProcessor(scaler);
System.out.println("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngseed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.learningRate(0.006)
.updater(Updater.NESTEROVS).momentum(0.9)
.regularization(true).l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder()
.nIn(height * width)
.nOut(1000)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.build()
)
.layer(1, newOutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(1000)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.build()
)
.pretrain(false).backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
System.out.println("Train model....");
for (int i = 0; i < numEpochs; i++) {
try {
model.fit(dataIter);
} catch (Exception e) {
System.out.println(e);
}
}
오류
org.deeplearning4j.exception.DL4JInvalidInputException은 예상 매트릭스 (순위 2), 모양 와 순위 4 배열 [128, 1, 45, 800]
DataSetIterator 함수를 다른 함수로 변경해야한다고 생각합니다. MNIST 예제의 경우 데이터를 함수로 가져 오는 것과 같습니다. ** DataSetIterator mnistTrain = new MnistDataSetIterator (batchSize, true, rngseed); ** 어떤 기능을 사용해야할지 모르겠습니다. – user7887249
@ TriV TriV 개선 할 부분을 알려 주신 데 대해 감사드립니다! 스택 오버플로를 처음 사용했기 때문에 몰랐습니다. 대단히 감사합니다! – user7887249