% clc; clear;

% digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
%     'nndatasets','DigitDataset');

% Training images are read from ./TrainingMNIST, searching all subfolders;
% the name of the folder that contains each image becomes its class label.
train_path = "./TrainingMNIST";
datastoreArgs = {'IncludeSubfolders', true, 'LabelSource', 'foldernames'};
imds = imageDatastore(train_path, datastoreArgs{:});

% Show a random sample of up to 20 training images in a 4x5 grid as a quick
% visual sanity check of the dataset.
% Sample from the actual number of files in the datastore rather than a
% fixed count: randperm(n,k) errors when k > n, and a fixed n would never
% pick files beyond that index in a larger dataset.
figure;
numImages = numel(imds.Files);
perm = randperm(numImages, min(20, numImages));
for i = 1:numel(perm)
    subplot(4,5,i);
    imshow(imds.Files{perm(i)});
end

% Per-class image counts, kept in the workspace for inspection.
labelCount = countEachLabel(imds);

% Read one image and check that it matches the [28 28 1] size expected by
% the network's imageInputLayer, so a mis-shaped dataset fails here with a
% clear message instead of after setup inside trainNetwork.
img = readimage(imds,1);
assert(size(img,1) == 28 && size(img,2) == 28 && size(img,3) == 1, ...
    'Expected 28x28x1 images, got %s.', mat2str(size(img)));

% Take numTrainFiles randomly chosen images per class for training; the
% remaining images of each class form the validation set.
numTrainFiles = 500;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');

% Small CNN for 28x28 single-channel images with 10 output classes:
% three conv(3x3, 'same' padding) -> batch-norm -> ReLU stages with 8, 16
% and 32 filters, 2x2 stride-2 max pooling after the first two stages, then
% a fully connected layer, softmax and classification output.
convStage = @(numFilters) [convolution2dLayer(3,numFilters,'Padding','same'); ...
                           batchNormalizationLayer; reluLayer];
downsample = @() maxPooling2dLayer(2,'Stride',2);

layers = [imageInputLayer([28 28 1]); ...
          convStage(8); ...
          downsample(); ...
          convStage(16); ...
          downsample(); ...
          convStage(32); ...
          fullyConnectedLayer(10); ...
          softmaxLayer; ...
          classificationLayer];

% SGD with momentum for 4 epochs at learning rate 0.01, reshuffling the
% training data every epoch and evaluating on the validation split every
% 30 iterations. Console output is turned off; progress is shown in the
% training-progress plot instead.
optionPairs = { ...
    'InitialLearnRate',    0.01, ...
    'MaxEpochs',           4, ...
    'Shuffle',             'every-epoch', ...
    'ValidationData',      imdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose',             false, ...
    'Plots',               'training-progress'};
options = trainingOptions('sgdm', optionPairs{:});

% Train the network, then report classification accuracy on the training
% split and on the held-out validation split.
net = trainNetwork(imdsTrain,layers,options);

% Training-set accuracy. The true labels here are the training labels, so
% they are stored as YTrue rather than under a validation name.
YPred = classify(net,imdsTrain);
YTrue = imdsTrain.Labels;

accuracy = mean(YPred == YTrue);
fprintf("Training Accuracy : %.5f\n", accuracy);

% Validation-set accuracy.
YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;

accuracy = mean(YPred == YValidation);
fprintf("Validation Accuracy : %.5f\n", accuracy);

% Evaluate the trained network on each external test set. Each set is read
% from the folder ./<name>, searching subfolders and using folder names as
% class labels, the same way as the training data. Add a name to testSets
% to evaluate another folder.
testSets = ["CGF", "CHF", "FastICA", "JADE", "Kurtosis", "Meta"];
for k = 1:numel(testSets)
    setName = testSets(k);
    test_dataset = "./" + setName;
    test_datastore = imageDatastore(test_dataset, ...
                        'IncludeSubfolders',true,'LabelSource','foldernames');

    YPred = classify(net,test_datastore);
    YTrue = test_datastore.Labels;

    accuracy = mean(YPred == YTrue);
    fprintf("%s Accuracy : %.5f\n", setName, accuracy);
end