//*************************************************************************************
//
// File name: ClassCNNCompleteTrain.cpp
// Location: See Matrox Example Launcher in the MIL Control Center
// 
//
// Synopsis:  This program uses the classification module to train
//            a context able to classify 3 different types of fabrics.
//
//
// Copyright (C) Matrox Electronic Systems Ltd., 1992-2020.
// All Rights Reserved

#include <windows.h>
#include <mil.h>
#include <string>
#include <algorithm>
#include <random>
#include <numeric>
#include <map>
#include <math.h>

static const MIL_INT NUMBER_OF_CLASSES = 3;

// ===========================================================================
// Example description.
// ===========================================================================
// Prints the example name, its synopsis and the list of MIL modules used,
// then blocks until the user presses a key.
void PrintHeader()
   {
   // NUMBER_OF_CLASSES is a MIL_INT while "%d" consumes an int: convert
   // explicitly so that the argument matches the format specifier.
   MosPrintf(MIL_TEXT("[EXAMPLE NAME]\n")
      MIL_TEXT("ClassCNNCompleteTrain\n\n")

      MIL_TEXT("[SYNOPSIS]\n")
      MIL_TEXT("This example trains a CNN model to classify the %d shown fabrics.\n")
      MIL_TEXT("The first step prepares the datasets needed for the training.\n")
      MIL_TEXT("The second step trains a context and display the train evolution.\n")
      MIL_TEXT("The final step performs predictions on test data using the trained\n")
      MIL_TEXT("CNN model as a final check of the expected model performance.\n\n")

      MIL_TEXT("[MODULES USED]\n")
      MIL_TEXT("Modules used: application, system, display, buffer,\n")
      MIL_TEXT("graphic, classification.\n\n"), static_cast<int>(NUMBER_OF_CLASSES));

   MosPrintf(MIL_TEXT("Press <Enter> to continue.\n\n"));
   MosGetch();
   }

// Path definitions.
// EXAMPLE_IMAGE_ROOT_PATH and EXAMPLE_ORIGINAL_DATA_PATH are built on top of
// M_IMAGE_PATH; EXAMPLE_DATA_PATH is a relative working folder that receives
// a copy of the original images (see PrepareTheDatasets).
#define EXAMPLE_IMAGE_ROOT_PATH      M_IMAGE_PATH MIL_TEXT("Classification/Fabrics/")
#define EXAMPLE_ORIGINAL_DATA_PATH   M_IMAGE_PATH MIL_TEXT("Classification/Fabrics/OriginalData/")
#define EXAMPLE_DATA_PATH            MIL_TEXT("Data/")

// Name of each class. The same names are used as sub-folder names under the
// original and the example data folders.
MIL_STRING FABRICS_CLASS_NAME[NUMBER_OF_CLASSES] = { MIL_TEXT("Fabric1"),
                                                     MIL_TEXT("Fabric2"),
                                                     MIL_TEXT("Fabric3") };

// Number of images per class (read by CopyOriginalDataToExampleDataFolder).
MIL_INT FABRICS_CLASS_NB_IMAGES[NUMBER_OF_CLASSES] = { 200, 200, 200 };

// Icon image for each class.
MIL_STRING FABRICS_CLASS_ICON[NUMBER_OF_CLASSES] = { EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric1_Icon.mim"),
                                                     EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric2_Icon.mim"),
                                                     EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric3_Icon.mim") };

// Side length, in pixels, of the square images produced by CropDatasetImages.
static const MIL_INT TRAIN_IMAGE_SIZE          = 99;
// Number of augmented images generated for each image of the train dataset.
static const MIL_INT NB_AUGMENTATION_PER_IMAGE = 2;

// ---------------------------------------------------------------------------
// Forward declarations of the helper functions used by this example.
// ---------------------------------------------------------------------------

// Inquires whether a CNN train engine is installed.
MIL_INT CnnTrainEngineDLLInstalled(MIL_ID MilSystem);

// Returns the current working directory of the process.
MIL_STRING GetExampleCurrentDirectory();

// Returns the indices [0, NbEntries) shuffled with a generator seeded by Seed.
const std::vector<MIL_INT> CreateShuffledIndex(MIL_INT NbEntries, unsigned int Seed);

// Deletes every file of the list.
void DeleteFiles(const std::vector<MIL_STRING>& Files);

// Appends to FilesInFolder the path of every non-directory entry of FolderName.
void ListFilesInFolder(const MIL_STRING& FolderName, std::vector<MIL_STRING>& FilesInFolder);

// Copies the original images of each class into the example data folder.
void CopyOriginalDataToExampleDataFolder(const MIL_STRING* FabricsClassName,
                                         MIL_INT NumberOfClasses,
                                         const MIL_STRING& OriginalDataPath,
                                         const MIL_STRING& ExampleDataPath);

// Deletes every file found directly in FolderName.
void DeleteFilesInFolder(const MIL_STRING& FolderName);

// Adds one class definition (name and icon) per class to the dataset.
void AddClassDescription(MIL_ID MilSystem,
                         MIL_ID Dataset,
                         const MIL_STRING* FabricsClassName,
                         const MIL_STRING* FabricsClassIcon,
                         MIL_INT NumberOfClasses);

// Builds the train, dev and test datasets from the original images.
void PrepareTheDatasets(MIL_ID MilSystem,
                        const MIL_STRING* FabricsClassName,
                        const MIL_STRING* FabricsClassIcon,
                        MIL_INT NumberOfClasses,
                        const MIL_STRING& OriginalDataPath,
                        const MIL_STRING& ExampleDataPath,
                        MIL_ID TrainDataset,
                        MIL_ID DevDataset,
                        MIL_ID TestDataset);

// Creates the example data folder structure, or empties it if it already exists.
void PrepareExampleDataFolder(const MIL_STRING& ExampleDataPath, const MIL_STRING* FabricsClassName, MIL_INT NumberOfClasses);

// Adds every image file of one class folder to the dataset.
void AddClassToDataset(MIL_INT ClassIndex, const MIL_STRING& DataToTrainPath, const MIL_STRING& FabricName, MIL_ID Dataset);

// Adds NbAugmentPerImage augmented images to the dataset for each existing entry.
void AugmentDataset(MIL_ID System, MIL_ID Dataset, MIL_INT NbAugmentPerImage);

// Crops every image of the dataset to a centered FinalImageSize x FinalImageSize square.
void CropDatasetImages(MIL_ID MilSystem, MIL_ID Dataset, MIL_INT FinalImageSize);

// Builds one image showing the icons of all classes side by side.
MIL_UNIQUE_BUF_ID CreateImageOfAllClasses(MIL_ID MilSystem, const MIL_STRING* FabricClassIcon, MIL_INT NumberOfClasses);

// NOTE(review): defined outside the visible part of the file. MosMain treats a
// null returned context as "training has not completed properly".
MIL_UNIQUE_CLASS_ID TrainTheModel(MIL_ID MilSystem, MIL_ID TrainDataset, MIL_ID DevDataset, MIL_ID MilDisplay);

// NOTE(review): defined outside the visible part of the file.
void PredictUsingTrainedContext(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TrainedCtx, MIL_ID TestDataset);

//............................................................................
// Dashboard used to show the evolution of a training: it is fed with
// per-epoch data (AddEpochData) and per-mini-batch data (AddMiniBatchData)
// and exposes the buffer it draws into (GetDashboardBufId).
// NOTE(review): the member function definitions are not in the visible part
// of the file; the comments below are derived from the declarations only.
class CTrainEvolutionDashboard
   {
   public:
      CTrainEvolutionDashboard(MIL_ID MilSystem, MIL_INT MaxEpoch, MIL_INT MinibatchSize,
                               MIL_DOUBLE LearningRate,
                               MIL_INT TrainImageSizeX, MIL_INT TrainImageSizeY,
                               MIL_INT TrainDatasetSize, MIL_INT DevDatasetSize,
                               MIL_INT TrainEngineUsed, MIL_STRING& TrainEngineDescription);
      ~CTrainEvolutionDashboard();

      // Reports the results of one epoch to the dashboard.
      void AddEpochData(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate,
                        MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow,
                        MIL_DOUBLE EpochBenchMean);

      // Reports the results of one mini-batch to the dashboard.
      void AddMiniBatchData(MIL_DOUBLE LossError, MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);

      // Returns the identifier of the dashboard buffer. Ownership stays with
      // this object (the buffer is held in a MIL_UNIQUE_BUF_ID member).
      MIL_ID GetDashboardBufId()
         {
         return m_DashboardBufId;
         }

   protected:
      // Update helpers, one per dashboard region.
      void UpdateEpochInfo(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow);
      void UpdateLoss(MIL_DOUBLE LossError);
      void UpdateEpochGraph(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch);
      void UpdateLossGraph(MIL_DOUBLE LossError, MIL_INT MiniBatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);
      void UpdateProgression(MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);

      // Drawing / initialization helpers.
      void DrawSectionSeparators();
      void DrawBufferFrame(MIL_ID BufId, MIL_INT FrameThickness);
      void InitializeEpochGraph();
      void InitializeLossGraph();
      void WriteGeneralTrainInfo(MIL_INT     MinibatchSize,
                                 MIL_INT     TrainImageSizeX,
                                 MIL_INT     TrainImageSizeY,
                                 MIL_INT     TrainDatasetSize,
                                 MIL_INT     DevDatasetSize,
                                 MIL_DOUBLE  LearningRate,
                                 MIL_INT     TrainEngineUsed,
                                 MIL_STRING& TrainEngineDescription);

      // Dashboard buffer and the graphic context used with it.
      MIL_UNIQUE_BUF_ID m_DashboardBufId;
      MIL_UNIQUE_GRA_ID m_TheGraContext;

      // One buffer per dashboard region.
      // NOTE(review): presumably child buffers of m_DashboardBufId - confirm
      // in the constructor.
      MIL_UNIQUE_BUF_ID m_EpochInfoBufId;
      MIL_UNIQUE_BUF_ID m_EpochGraphBufId;
      MIL_UNIQUE_BUF_ID m_LossInfoBufId;
      MIL_UNIQUE_BUF_ID m_LossGraphBufId;
      MIL_UNIQUE_BUF_ID m_ProgressionInfoBufId;

      MIL_INT m_MaxEpoch;
      MIL_INT m_DashboardWidth;
      // Last positions recorded for the train, dev and mini-batch curves.
      MIL_INT m_LastTrainPosX;
      MIL_INT m_LastTrainPosY;
      MIL_INT m_LastDevPosX;
      MIL_INT m_LastDevPosY;
      MIL_INT m_LastTrainMinibatchPosX;
      MIL_INT m_LastTrainMinibatchPosY;

      MIL_INT m_YPositionForLossText;

      MIL_DOUBLE m_EpochBenchMean;

      // Constants useful for the graph.
      // NOTE(review): declared as regular (non-const) members; presumably
      // assigned once in the constructor - confirm.
      MIL_INT GRAPH_SIZE_X;
      MIL_INT GRAPH_SIZE_Y;
      MIL_INT GRAPH_TOP_MARGIN;
      MIL_INT MARGIN;
      MIL_INT EPOCH_AND_MINIBATCH_REGION_HEIGHT;
      MIL_INT PROGRESSION_INFO_REGION_HEIGHT;

      MIL_INT LOSS_EXPONENT_MAX;
      MIL_INT LOSS_EXPONENT_MIN;

      // Colors used for the different kinds of information.
      MIL_DOUBLE COLOR_GENERAL_INFO;
      MIL_DOUBLE COLOR_DEV_SET_INFO;
      MIL_DOUBLE COLOR_TRAIN_SET_INFO;
      MIL_DOUBLE COLOR_PROGRESS_BAR;
   };

//............................................................................
// User data received by HookFuncEpoch (its UserData pointer is cast to this
// type). The dashboard is referenced, not owned.
struct HookEpochData
   {
   CTrainEvolutionDashboard* TheDashboard;
   };

// User data holding a non-owning pointer to the dashboard.
// NOTE(review): presumably the UserData of HookFuncMiniBatch - its definition
// is not in the visible part of the file.
struct HookMiniBatchData
   {
   CTrainEvolutionDashboard* TheDashboard;
   };


// Hook function receiving epoch events; UserData is interpreted as a
// HookEpochData*.
MIL_INT MFTYPE HookFuncEpoch(MIL_INT HookType, MIL_ID EventId, void* UserData);

// Hook function for mini-batch events.
// NOTE(review): presumably expects a HookMiniBatchData* as UserData - confirm
// against its definition.
MIL_INT MFTYPE HookFuncMiniBatch(MIL_INT HookType, MIL_ID EventId, void* UserData);

//............................................................................
// Helper used to show prediction results: Update() receives the image that
// was predicted together with the best class index and its score.
// NOTE(review): the member function definitions are not in the visible part
// of the file; the comments below are derived from the declarations only.
class CPredictResultDisplay
   {
   public:
      CPredictResultDisplay(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TestDataset);
      ~CPredictResultDisplay();

      // Reports one prediction (image, best class index, best score).
      void Update(MIL_ID ImageToPredict, MIL_INT BestIndex, MIL_DOUBLE BestScore);

   protected:
      // Plain identifiers: not owned by this object.
      MIL_ID  m_MilSystem;
      MIL_ID  m_MilDisplay;
      MIL_INT m_MaxTrainImageSize;

      // Display buffers owned by this object.
      MIL_UNIQUE_BUF_ID m_MilDispImage;
      MIL_UNIQUE_BUF_ID m_MilDispChild;
      // Plain identifier (not a MIL_UNIQUE_* type), so not released by this
      // member's destruction.
      MIL_ID            m_MilOverlay;

      MIL_UNIQUE_GRA_ID m_GraContext;

      // const members without in-class initializers: they have to be set in
      // the constructor's member initializer list.
      const MIL_DOUBLE COLOR_PREDICT_INFO;
      const MIL_INT    MARGIN;
   };

// ****************************************************************************
//    Main.
// ****************************************************************************
// Example entry point.
//
// On a 64-bit build: prints the header, allocates the MIL application, system
// and display, shows an image of all classes, prepares the train/dev/test
// datasets, trains the model and, if a trained context was obtained, runs
// predictions on the test dataset.
//
// Returns 0 when the example ran, -1 when it cannot run (non 64-bit build or
// no train engine installed).
int MosMain()
   {
#if !M_MIL_USE_64BIT
   MosPrintf(MIL_TEXT("\n***** MclassTrain() is not available with a non 64-bit platform. *****\n"));
   MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
   MosGetch();
   return -1;
#else //M_MIL_USE_64BIT

   PrintHeader();

   MIL_UNIQUE_APP_ID MilApplication = MappAlloc(M_NULL, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_SYS_ID MilSystem      = MsysAlloc(M_DEFAULT, M_SYSTEM_HOST, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);

   //If no train engine installed then the train example cannot run
   if(CnnTrainEngineDLLInstalled(MilSystem)!=M_TRUE)
      {
      MosPrintf(MIL_TEXT("\n***** No train engine installed, MclassTrain() cannot run! *****\n"));
      MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
      MosGetch();
      return -1;
      }

   MIL_UNIQUE_DISP_ID MilDisplay = MdispAlloc(MilSystem, M_DEFAULT, MIL_TEXT("M_DEFAULT"), M_DEFAULT, M_UNIQUE_ID);

   //Display a representative image of all classes
   MIL_UNIQUE_BUF_ID AllClassesImage = CreateImageOfAllClasses(MilSystem, FABRICS_CLASS_ICON, NUMBER_OF_CLASSES);
   MdispSelect(MilDisplay, AllClassesImage);

   MIL_UNIQUE_CLASS_ID TrainDataset, DevDataset, TestDataset;
   TrainDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   DevDataset   = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   TestDataset  = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);

   // The three datasets share the same root path: query the current
   // directory once instead of once per dataset.
   const MIL_STRING DatasetRootPath = GetExampleCurrentDirectory();
   MclassControl(TrainDataset, M_CONTEXT, M_ROOT_PATH, DatasetRootPath);
   MclassControl(DevDataset  , M_CONTEXT, M_ROOT_PATH, DatasetRootPath);
   MclassControl(TestDataset , M_CONTEXT, M_ROOT_PATH, DatasetRootPath);

   MosPrintf(MIL_TEXT("\n*******************************************************\n"));
   MosPrintf(MIL_TEXT("PREPARING THE DATASETS... THIS WILL TAKE SOME TIME...\n"));
   MosPrintf(MIL_TEXT("*******************************************************\n"));
   PrepareTheDatasets(MilSystem, FABRICS_CLASS_NAME, FABRICS_CLASS_ICON, NUMBER_OF_CLASSES, EXAMPLE_ORIGINAL_DATA_PATH, EXAMPLE_DATA_PATH,
                      TrainDataset, DevDataset, TestDataset);

   // Useful to export entries from different sets if one wants to ensure that
   // data preparation has worked as expected. Uncomment if required...
   // MclassExport(MIL_TEXT("TrainDataset.csv"), M_FORMAT_CSV, TrainDataset, M_DEFAULT, M_ENTRIES, M_DEFAULT);
   // MclassExport(MIL_TEXT("DevDataset.csv"), M_FORMAT_CSV, DevDataset, M_DEFAULT, M_ENTRIES, M_DEFAULT);
   // MclassExport(MIL_TEXT("TestDataset.csv"), M_FORMAT_CSV, TestDataset, M_DEFAULT, M_ENTRIES, M_DEFAULT);

   MosPrintf(MIL_TEXT("\n*******************************************************\n"));
   MosPrintf(MIL_TEXT("TRAINING... THIS WILL TAKE SOME TIME...\n"));
   MosPrintf(MIL_TEXT("*******************************************************\n"));

   MosPrintf(MIL_TEXT("\nThe training has started.\n"));
   MosPrintf(MIL_TEXT("It can be paused at any time by pressing 'p'.\n"));
   MosPrintf(MIL_TEXT("It can then be stopped or continued.\n"));

   MosPrintf(MIL_TEXT("\nDuring training, you can observe the displayed error rate of the train\n"));
   MosPrintf(MIL_TEXT("and dev datasets together with the evolution of the loss value...\n"));

   MIL_UNIQUE_CLASS_ID TrainedCtx = TrainTheModel(MilSystem, TrainDataset, DevDataset, MilDisplay);

   // A null context means that the training did not produce a usable result.
   if(TrainedCtx)
      {
      MosPrintf(MIL_TEXT("\n*******************************************************\n"));
      MosPrintf(MIL_TEXT("PREDICTING USING THE TRAINED CONTEXT...\n"));
      MosPrintf(MIL_TEXT("*******************************************************\n"));

      PredictUsingTrainedContext(MilSystem, MilDisplay, TrainedCtx, TestDataset);
      }
   else
      {
      MosPrintf(MIL_TEXT("\nTraining has not completed properly !!!!!!!!!!!!!!\n"));
      MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
      MosGetch();
      }

   return 0;
#endif //!M_MIL_USE_64BIT
   }

// Inquires whether a CNN train engine is installed.
// A temporary CNN train context is allocated only for the inquire and is
// released when the function returns. The result defaults to M_FALSE.
MIL_INT CnnTrainEngineDLLInstalled(MIL_ID MilSystem)
   {
   MIL_UNIQUE_CLASS_ID ProbeCtx = MclassAlloc(MilSystem, M_TRAIN_CNN, M_DEFAULT, M_UNIQUE_ID);

   MIL_INT EngineInstalled = M_FALSE;
   MclassInquire(ProbeCtx, M_DEFAULT, M_CNN_TRAIN_ENGINE_IS_INSTALLED + M_TYPE_MIL_INT, &EngineInstalled);
   return EngineInstalled;
   }

// Returns the current working directory of the process as a MIL_STRING.
MIL_STRING GetExampleCurrentDirectory()
   {
   // First call only asks for the required length; one extra character is
   // added on top of the value reported by Windows.
   const DWORD BufferLength = GetCurrentDirectory(0, NULL) + 1;

   // Zero-filled buffer, so the content is always a terminated string.
   std::vector<MIL_TEXT_CHAR> Buffer(BufferLength, 0);
   GetCurrentDirectory(BufferLength, reinterpret_cast<LPTSTR>(Buffer.data()));

   return MIL_STRING(Buffer.data());
   }

// Builds a single color image showing the icon of every class side by side,
// each one outlined with a blue rectangle.
//
// FabricClassIcon  : array of NumberOfClasses icon file names.
// Returns the newly allocated image (width = sum of the icon widths,
// height = height of the tallest icon).
MIL_UNIQUE_BUF_ID CreateImageOfAllClasses(MIL_ID MilSystem, const MIL_STRING* FabricClassIcon, MIL_INT NumberOfClasses)
   {
   // Load all the icons and accumulate the size of the destination image.
   std::vector<MIL_UNIQUE_BUF_ID> Icons;
   MIL_INT TotalWidth    = 0;
   MIL_INT TallestHeight = MIL_INT_MIN;
   for(MIL_INT ClassIdx = 0; ClassIdx < NumberOfClasses; ++ClassIdx)
      {
      Icons.push_back(MbufRestore(FabricClassIcon[ClassIdx], MilSystem, M_UNIQUE_ID));

      const MIL_INT IconWidth  = MbufInquire(Icons.back(), M_SIZE_X, M_NULL);
      const MIL_INT IconHeight = MbufInquire(Icons.back(), M_SIZE_Y, M_NULL);

      if(IconHeight > TallestHeight)
         {
         TallestHeight = IconHeight;
         }
      TotalWidth += IconWidth;
      }

   MIL_UNIQUE_BUF_ID Mosaic = MbufAllocColor(MilSystem, 3, TotalWidth, TallestHeight, 8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(Mosaic, 0.0);

   MIL_UNIQUE_GRA_ID FrameGraContext = MgraAlloc(MilSystem, M_UNIQUE_ID);
   MgraColor(FrameGraContext, M_COLOR_BLUE);

   // Paste the icons from left to right and frame each of them.
   MIL_INT LeftEdge = 0;
   for(std::vector<MIL_UNIQUE_BUF_ID>::const_iterator It = Icons.begin(); It != Icons.end(); ++It)
      {
      const MIL_UNIQUE_BUF_ID& Icon = *It;
      const MIL_INT IconWidth  = MbufInquire(Icon, M_SIZE_X, M_NULL);
      const MIL_INT IconHeight = MbufInquire(Icon, M_SIZE_Y, M_NULL);

      MbufCopyColor2d(Icon, Mosaic, M_ALL_BANDS, 0, 0, M_ALL_BANDS, LeftEdge, 0, IconWidth, IconHeight);
      MgraRect(FrameGraContext, Mosaic, LeftEdge, 0, LeftEdge + IconWidth - 1, IconHeight - 1);
      LeftEdge += IconWidth;
      }

   return Mosaic;
   }

// Returns the indices 0 .. NbEntries-1 in a shuffled order.
// The order is produced by std::shuffle driven by a std::mt19937 generator
// seeded with Seed.
const std::vector<MIL_INT> CreateShuffledIndex(MIL_INT NbEntries, unsigned int Seed)
   {
   // Start from the identity permutation.
   std::vector<MIL_INT> Order(NbEntries);
   for(MIL_INT Position = 0; Position < NbEntries; ++Position)
      {
      Order[Position] = Position;
      }

   std::mt19937 Generator(Seed);
   std::shuffle(Order.begin(), Order.end(), Generator);

   return Order;
   }

void DeleteFiles(const std::vector<MIL_STRING>& Files)
   {
   for(const auto& FileName : Files)
      {
      MappFileOperation(M_DEFAULT, FileName, M_NULL, M_NULL, M_FILE_DELETE, M_DEFAULT, M_NULL);
      }
   }

void ListFilesInFolder(const MIL_STRING& FolderName, std::vector<MIL_STRING>& FilesInFolder)
   {
   MIL_STRING FileToSearch = FolderName;
   FileToSearch += MIL_TEXT("*.*");

   WIN32_FIND_DATA FindFileData;
   HANDLE hFind;
   hFind = FindFirstFile(FileToSearch.c_str(), &FindFileData);

   if(hFind == INVALID_HANDLE_VALUE)
      {
      MosPrintf(MIL_TEXT("FindFirstFile failed (%d)\n"), GetLastError());
      return;
      }

   do
      {
      if(!(FindFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY))
         {
         FilesInFolder.push_back(FolderName + FindFileData.cFileName);
         }
      } while(FindNextFile(hFind, &FindFileData) != 0);

      FindClose(hFind);
   }

// Adds NumberOfClasses class definitions to Dataset: for each class, the
// class is added with its name, then its icon is restored from file and set
// as the icon of the class at the same index.
void AddClassDescription(MIL_ID MilSystem,
                         MIL_ID Dataset,
                         const MIL_STRING* FabricsClassName,
                         const MIL_STRING* FabricsClassIcon,
                         MIL_INT NumberOfClasses)
   {
   for(MIL_INT ClassIdx = 0; ClassIdx < NumberOfClasses; ++ClassIdx)
      {
      const MIL_STRING& ClassName    = FabricsClassName[ClassIdx];
      const MIL_STRING& IconFileName = FabricsClassIcon[ClassIdx];

      MclassControl(Dataset, M_DEFAULT, M_CLASS_ADD, ClassName);

      // NOTE(review): the icon buffer is released at the end of the
      // iteration; presumably the dataset keeps its own copy - confirm.
      MIL_UNIQUE_BUF_ID Icon = MbufRestore(IconFileName, MilSystem, M_UNIQUE_ID);
      MclassControl(Dataset, M_CLASS_INDEX(ClassIdx), M_CLASS_ICON_ID, Icon);
      }
   }

// Prepares the train, dev and test datasets used by the example.
//
// Steps:
//   1. Create (or empty) the example data folder structure.
//   2. Copy the original images into it, so they can be modified freely.
//   3. Build a dataset with all the images, one class per fabric.
//   4. Split it: test dataset first, then train/dev from the remaining data.
//   5. Augment the train dataset.
//   6. Crop the images of the three datasets to TRAIN_IMAGE_SIZE.
//
// TrainDataset, DevDataset and TestDataset must be already allocated
// datasets; they receive the result of the split. The function waits for a
// key press before returning.
void PrepareTheDatasets(MIL_ID MilSystem,
                        const MIL_STRING* FabricsClassName,
                        const MIL_STRING* FabricsClassIcon,
                        MIL_INT NumberOfClasses,
                        const MIL_STRING& OriginalDataPath,
                        const MIL_STRING& ExampleDataPath,
                        MIL_ID TrainDataset,
                        MIL_ID DevDataset,
                        MIL_ID TestDataset)
   {
   // If not already existing we will create the appropriate
   // ExampleDataPath folders structure.
   // If the structure is already existing, then we will remove previous
   // data to ensure repeatability
   PrepareExampleDataFolder(ExampleDataPath, FabricsClassName, NumberOfClasses);

   // We copy the original data to the ExampleDataPath folder to ensure we can
   // modify/preprocess this data later without affecting original data
   CopyOriginalDataToExampleDataFolder(FabricsClassName, NumberOfClasses, OriginalDataPath, ExampleDataPath);

   // We create a dataset with all the data
   MosPrintf(MIL_TEXT("\nCreating the dataset containing all the data...\n"));
   MIL_UNIQUE_CLASS_ID FullDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassControl(FullDataset, M_CONTEXT, M_ROOT_PATH, GetExampleCurrentDirectory());
   AddClassDescription(MilSystem, FullDataset, FabricsClassName, FabricsClassIcon, NumberOfClasses);
   for (MIL_INT i = 0; i < NumberOfClasses; i++)
      {
      AddClassToDataset(i, ExampleDataPath, FabricsClassName[i], FullDataset);
      }

   MosPrintf(MIL_TEXT("\nSplitting the dataset to train/dev/test datasets...\n"));
   MIL_UNIQUE_CLASS_ID WorkingDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassControl(WorkingDataset, M_CONTEXT, M_ROOT_PATH, GetExampleCurrentDirectory());

   //We want to split that way: Train=70%, Dev=20% and test 10%
   const MIL_DOUBLE PERCENTAGE_IN_TRAIN_DATASET = 70.0;
   const MIL_DOUBLE PERCENTAGE_IN_DEV_DATASET   = 20.0;
   const MIL_DOUBLE PERCENTAGE_IN_TEST_DATASET  = 10.0;

   // We create the test dataset first
   MclassSplitDataset(M_SPLIT_CONTEXT_FIXED_SEED, FullDataset, WorkingDataset, TestDataset,
                     100.0-PERCENTAGE_IN_TEST_DATASET, M_NULL, M_DEFAULT);

   // Then create the train and dev dataset from the remaining data.
   // The percentage is relative to what is left after removing the test data.
   const MIL_DOUBLE PERCENTAGE_TO_USE = PERCENTAGE_IN_TRAIN_DATASET/(100.0-PERCENTAGE_IN_TEST_DATASET)*100.0;
   MclassSplitDataset(M_SPLIT_CONTEXT_FIXED_SEED, WorkingDataset, TrainDataset, DevDataset,
                      PERCENTAGE_TO_USE, M_NULL, M_DEFAULT);

   MosPrintf(MIL_TEXT("\nAugmenting the train dataset...\n"));
   // Perform data augmentation to the TrainDataset.
   AugmentDataset(MilSystem, TrainDataset, NB_AUGMENTATION_PER_IMAGE);

   // Crop the dataset images to ensure that they have the required
   // size for the application.
   MosPrintf(MIL_TEXT("\nCropping images from the train/dev/test datasets...\n"));
   MosPrintf(MIL_TEXT("Cropping images from the train dataset...\n"));
   CropDatasetImages(MilSystem, TrainDataset, TRAIN_IMAGE_SIZE);
   MosPrintf(MIL_TEXT("Cropping images from the dev dataset...\n"));
   CropDatasetImages(MilSystem, DevDataset, TRAIN_IMAGE_SIZE);
   MosPrintf(MIL_TEXT("Cropping images from the test dataset...\n"));
   CropDatasetImages(MilSystem, TestDataset , TRAIN_IMAGE_SIZE);

   // Save the datasets. Uncomment if required...
   // MclassSave(MIL_TEXT("TrainDataset.mclassd"), TrainDataset, M_DEFAULT);
   // MclassSave(MIL_TEXT("DevDataset.mclassd"), DevDataset, M_DEFAULT);
   // MclassSave(MIL_TEXT("TestDataset.mclassd"), TestDataset, M_DEFAULT);

   MosPrintf(MIL_TEXT("\nData preparation was successful.\n"));

   MosPrintf(MIL_TEXT("\nA train dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_TRAIN_DATASET);
   // NB_AUGMENTATION_PER_IMAGE is a MIL_INT while "%d" consumes an int.
   MosPrintf(MIL_TEXT("Augmented data (%d augmented images for each image in the dataset)\nwas also added to that set\n"), static_cast<int>(NB_AUGMENTATION_PER_IMAGE));
   MosPrintf(MIL_TEXT("to ensure enough data for training and to be tolerant to small changes\nin the later acquisition setup.\n"));

   MosPrintf(MIL_TEXT("\nA dev dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_DEV_DATASET);
   MosPrintf(MIL_TEXT("\nA test dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_TEST_DATASET);

   MosPrintf(MIL_TEXT("\nOriginal images are a little larger than those in the final application.\n"));
   MosPrintf(MIL_TEXT("This ensures augmented images do not contain overscan pixels.\n"));
   MosPrintf(MIL_TEXT("In the final datasets, the images were cropped to meet the final \n"));
   MosPrintf(MIL_TEXT("application's image size requirement.\n"));
   MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
   MosGetch();
   }


void CopyOriginalDataToExampleDataFolder(const MIL_STRING* FabricsClassName,
                                         MIL_INT NumberOfClasses,
                                         const MIL_STRING& OriginalDataPath,
                                         const MIL_STRING& ExampleDataPath)
   {
   MosPrintf(MIL_TEXT("\nCopying original images\nfrom %s folder\nto %s folder...\n"), OriginalDataPath.c_str(), ExampleDataPath.c_str());
   for(MIL_INT FabricIndex=0; FabricIndex<NumberOfClasses; FabricIndex++)
      {
      MIL_INT NbImages = FABRICS_CLASS_NB_IMAGES[FabricIndex];
      // Image names are 0001.mim, 0002.mim, ..., 000(NbImages).mim
      for(MIL_INT i = 1; i <= NbImages; i++)
         {
         MIL_TEXT_CHAR OriginalFileName[512];
         MosSprintf(OriginalFileName, 512, MIL_TEXT("%s%s/%04d.mim"), OriginalDataPath.c_str(), FabricsClassName[FabricIndex].c_str(), i);
         MIL_TEXT_CHAR DestFileName[512];
         MosSprintf(DestFileName, 512, MIL_TEXT("%s%s/%04d.mim"), ExampleDataPath.c_str(), FabricsClassName[FabricIndex].c_str(), i);
         MappFileOperation(M_DEFAULT, OriginalFileName, M_DEFAULT, DestFileName, M_FILE_COPY, M_DEFAULT, M_NULL);
         }
      }
   }

void DeleteFilesInFolder(const MIL_STRING& FolderName)
   {
   std::vector<MIL_STRING> FilesInFolder;
   ListFilesInFolder(FolderName, FilesInFolder);
   DeleteFiles(FilesInFolder);
   }


void PrepareExampleDataFolder(const MIL_STRING& ExampleDataPath, const MIL_STRING* FabricsClassName, MIL_INT NumberOfClasses)
   {
   MIL_INT FileExists;
   MappFileOperation(M_DEFAULT, ExampleDataPath, M_NULL, M_NULL, M_FILE_EXISTS, M_DEFAULT, &FileExists);

   if(FileExists != M_YES)
      {
      MosPrintf(MIL_TEXT("\nCreating the %s folder and a sub folder for each class...\n"), ExampleDataPath.c_str());

      // Create ExampleDataPath folder since it does not exist.
      MappFileOperation(M_DEFAULT, ExampleDataPath, M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
      for(MIL_INT i = 0; i < NumberOfClasses; i++)
         {
         // Create one folder for each class name.
         MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
         }
      }
   else
      {
      // If ExampleDataPath folder is existing, delete files already in there
      // Create the folder if not existing.
      MosPrintf(MIL_TEXT("\nDeleting files in the %s folder to ensure example repeatability...\n"), ExampleDataPath.c_str());

      for(MIL_INT i = 0; i < NumberOfClasses; i++)
         {
         MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_EXISTS, M_DEFAULT, &FileExists);
         if(FileExists)
            DeleteFilesInFolder(ExampleDataPath + FabricsClassName[i] + MIL_TEXT("/"));
         else
            MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
         }
      }
   }

// Adds to Dataset one entry per file found in <DataToTrainPath><FabricName>/.
// Each new entry is appended after the existing ones, receives ClassIndex as
// ground truth (region 0) and the file path of the image.
void AddClassToDataset(MIL_INT ClassIndex, const MIL_STRING& DataToTrainPath, const MIL_STRING& FabricName, MIL_ID Dataset)
   {
   // New entries are appended: the first one gets the current entry count as index.
   MIL_INT NextEntryIndex;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NextEntryIndex);

   const MIL_STRING ClassFolder = DataToTrainPath + FabricName + MIL_TEXT("/");
   std::vector<MIL_STRING> ImageFiles;
   ListFilesInFolder(ClassFolder, ImageFiles);

   for(std::vector<MIL_STRING>::const_iterator It = ImageFiles.begin(); It != ImageFiles.end(); ++It, ++NextEntryIndex)
      {
      const MIL_STRING& ImagePath = *It;

      MclassControl(Dataset, M_DEFAULT, M_ENTRY_ADD, M_DEFAULT);
      MclassControlEntry(Dataset, NextEntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH, ClassIndex, M_NULL, M_DEFAULT);
      MclassControlEntry(Dataset, NextEntryIndex, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, M_DEFAULT, ImagePath, M_DEFAULT);
      }
   }


// Adds NbAugmentPerImage augmented images to Dataset for each of its current
// entries.
//
// For every entry, the image is restored from its file path, augmented
// (translation, scale and rotation, with a fixed seed for repeatability) and
// saved next to the original with an "_Aug_<n>" suffix inserted before the
// file extension. A new entry is then appended to the dataset with the same
// ground truth, the new file path and the index of the source entry.
void AugmentDataset(MIL_ID System, MIL_ID Dataset, MIL_INT NbAugmentPerImage)
   {
   auto AugmentContext = MimAlloc(System, M_AUGMENTATION_CONTEXT, M_DEFAULT, M_UNIQUE_ID);
   auto AugmentResult  = MimAllocResult(System, M_DEFAULT, M_AUGMENTATION_RESULT, M_UNIQUE_ID);

   // Seed the augmentation to ensure repeatability.
   MimControl(AugmentContext, M_AUG_SEED_MODE, M_RNG_INIT_VALUE);
   MimControl(AugmentContext, M_AUG_RNG_INIT_VALUE, 42);

   MimControl(AugmentContext, M_AUG_TRANSLATION_X_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_TRANSLATION_X_OP_MAX, 2);
   MimControl(AugmentContext, M_AUG_TRANSLATION_Y_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_TRANSLATION_Y_OP_MAX, 2);

   MimControl(AugmentContext, M_AUG_SCALE_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_SCALE_OP_FACTOR_MIN, 0.97);
   MimControl(AugmentContext, M_AUG_SCALE_OP_FACTOR_MAX, 1.03);

   MimControl(AugmentContext, M_AUG_ROTATION_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_ROTATION_OP_ANGLE_DELTA, 5.0);

   // Initialized so that a failing inquire results in no work instead of a
   // loop bounded by an indeterminate value.
   MIL_INT NbEntries = 0;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NbEntries);

   // Augmented entries are appended after the existing ones.
   MIL_INT PosInAugmentDataset = NbEntries;

   for(MIL_INT i = 0; i < NbEntries; i++)
      {
      // "%d" consumes an int while the counters are MIL_INT.
      MosPrintf(MIL_TEXT("%d of %d completed\r"), static_cast<int>(i + 1), static_cast<int>(NbEntries));

      MIL_STRING FilePath;
      MclassInquireEntry(Dataset, i, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
      MIL_INT GroundTruthIndex;
      MclassInquireEntry(Dataset, i, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH + M_TYPE_MIL_INT, &GroundTruthIndex);

      // Locate where the suffix goes: right before the extension when the
      // file name has one, at the end of the path otherwise. A dot located
      // before the last path separator belongs to a folder name, not to an
      // extension.
      const std::size_t DotPos       = FilePath.rfind(MIL_TEXT("."));
      const std::size_t SeparatorPos = FilePath.find_last_of(MIL_TEXT("/\\"));
      const bool HasExtension = (DotPos != MIL_STRING::npos) &&
                                (SeparatorPos == MIL_STRING::npos || DotPos > SeparatorPos);
      const std::size_t SuffixPos = HasExtension ? DotPos : FilePath.size();

      // Add the augmentations.
      MIL_UNIQUE_BUF_ID OriginalImage = MbufRestore(FilePath, System, M_UNIQUE_ID);
      MIL_UNIQUE_BUF_ID AugmentedImage = MbufClone(OriginalImage, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);
      for(MIL_INT AugIndex = 0; AugIndex < NbAugmentPerImage; AugIndex++)
         {
         MbufClear(AugmentedImage, 0.0);
         MimAugment(AugmentContext, OriginalImage, AugmentedImage, M_DEFAULT, M_DEFAULT);

         MIL_TEXT_CHAR Suffix[128];
         MosSprintf(Suffix, 128, MIL_TEXT("_Aug_%d"), static_cast<int>(AugIndex));

         MIL_STRING AugFileName = FilePath;
         AugFileName.insert(SuffixPos, Suffix);
         MbufSave(AugFileName, AugmentedImage);

         // Add the augmented image.
         MclassControl(Dataset, M_DEFAULT, M_ENTRY_ADD, M_DEFAULT);
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH, GroundTruthIndex, M_NULL, M_DEFAULT);
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, M_DEFAULT, AugFileName, M_DEFAULT);
         // Identify the fact that this is augmented data in case we want to use this dataset later.
         // As an example a good practice is to not put augmented data in a dev or test dataset.
         // N.B. The MIL train operation will not use this information.
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_DEFAULT, M_AUGMENTATION_SOURCE, i, M_NULL, M_DEFAULT);

         PosInAugmentDataset++;
         }
      }
   MosPrintf(MIL_TEXT("\n"));
   }

// Crops every image referenced by the dataset to a centered square window of
// FinalImageSize x FinalImageSize pixels and saves the result over the
// original image file.
//   MilSystem      : system on which the image buffers are allocated.
//   Dataset        : dataset whose entries' image files are cropped.
//   FinalImageSize : width and height of the cropped images.
void CropDatasetImages(MIL_ID MilSystem, MIL_ID Dataset, MIL_INT FinalImageSize)
   {
   MIL_INT EntryCount;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &EntryCount);

   for(MIL_INT EntryIdx = 0; EntryIdx < EntryCount; ++EntryIdx)
      {
      // Progress feedback, rewritten on the same console line.
      MosPrintf(MIL_TEXT("%d of %d completed\r"), EntryIdx + 1, EntryCount);

      MIL_STRING ImagePath;
      MclassInquireEntry(Dataset, EntryIdx, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, ImagePath);

      MIL_UNIQUE_BUF_ID SourceImage = MbufRestore(ImagePath, MilSystem, M_UNIQUE_ID);

      const MIL_INT SourceSizeX = MbufInquire(SourceImage, M_SIZE_X, M_NULL);
      const MIL_INT SourceSizeY = MbufInquire(SourceImage, M_SIZE_Y, M_NULL);

      // Top-left corner of the crop window, chosen so the window is centered.
      const MIL_INT CropOffsetX = (SourceSizeX - FinalImageSize) / 2;
      const MIL_INT CropOffsetY = (SourceSizeY - FinalImageSize) / 2;

      // The destination keeps the source attributes; only its dimensions change.
      MIL_UNIQUE_BUF_ID CropTarget = MbufClone(SourceImage, M_DEFAULT, FinalImageSize, FinalImageSize,
                                               M_DEFAULT, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);

      MbufCopyColor2d(SourceImage, CropTarget, M_ALL_BANDS, CropOffsetX, CropOffsetY,
                      M_ALL_BANDS, 0, 0, FinalImageSize, FinalImageSize);

      // Overwrite the original file with the cropped version.
      MbufSave(ImagePath, CropTarget);
      }

   MosPrintf(MIL_TEXT("\n"));
   }

// Hook function for the M_EPOCH_TRAINED event. It reads the statistics of the
// epoch that just finished from the hook event and forwards them to the
// dashboard referenced by the user data.
//   EventId  : hook event identifier, queried through MclassGetHookInfo.
//   UserData : pointer to the HookEpochData supplied when the hook was registered.
// Always returns M_NULL.
MIL_INT MFTYPE HookFuncEpoch(MIL_INT /*HookType*/, MIL_ID EventId, void* UserData)
   {
   auto HookData = static_cast<HookEpochData*>(UserData);

   MIL_INT CurEpochIndex = 0;
   MclassGetHookInfo(EventId, M_EPOCH_INDEX + M_TYPE_MIL_INT, &CurEpochIndex);

   // The timer is reset by HookFuncMiniBatch on the very first mini-batch, so
   // the value read here is the time elapsed since then. Dividing by the number
   // of completed epochs gives the mean duration of one epoch.
   MIL_DOUBLE CurBench = 0.0;
   MappTimer(M_DEFAULT, M_TIMER_READ, &CurBench);
   MIL_DOUBLE EpochBenchMean = CurBench / (CurEpochIndex+1);

   MIL_DOUBLE TrainErrorRate = 0;
   MclassGetHookInfo(EventId, M_TRAIN_DATASET_ERROR_RATE, &TrainErrorRate);
   MIL_DOUBLE DevErrorRate = 0;
   MclassGetHookInfo(EventId, M_DEV_DATASET_ERROR_RATE, &DevErrorRate);

   MIL_INT AreTrainedCNNParametersUpdated = M_FALSE;
   MclassGetHookInfo(EventId,
                     M_TRAINED_CNN_PARAMETERS_UPDATED+M_TYPE_MIL_INT,
                     &AreTrainedCNNParametersUpdated);
   // By default trained parameters are updated when the dev error rate
   // is the best up to now.
   bool TheEpochIsTheBestUpToNow = (AreTrainedCNNParametersUpdated == M_TRUE);

   HookData->TheDashboard->AddEpochData(
      TrainErrorRate, DevErrorRate,
      CurEpochIndex, TheEpochIsTheBestUpToNow, EpochBenchMean);

   return M_NULL;
   }

// Hook function for the M_MINI_BATCH_TRAINED event. It forwards the loss and
// the progression of the mini-batch that just finished to the dashboard, and
// lets the user pause ('p') and optionally stop ('s') the training from the
// keyboard.
//   EventId  : hook event identifier, queried through MclassGetHookInfo.
//   UserData : pointer to the HookMiniBatchData supplied when the hook was registered.
// Always returns M_NULL.
MIL_INT MFTYPE HookFuncMiniBatch(MIL_INT /*HookType*/, MIL_ID EventId, void* UserData)
   {
   auto HookData = static_cast<HookMiniBatchData*>(UserData);

   MIL_DOUBLE LossError = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_LOSS, &LossError);

   MIL_INT MiniBatchIdx = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_INDEX + M_TYPE_MIL_INT, &MiniBatchIdx);

   MIL_INT EpochIdx = 0;
   MclassGetHookInfo(EventId, M_EPOCH_INDEX + M_TYPE_MIL_INT, &EpochIdx);

   MIL_INT NbMiniBatchPerEpoch = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_PER_EPOCH + M_TYPE_MIL_INT, &NbMiniBatchPerEpoch);

   // Start timing at the very first mini-batch; HookFuncEpoch reads this timer
   // to compute the mean epoch duration.
   if(EpochIdx == 0 && MiniBatchIdx == 0)
      {
      MappTimer(M_DEFAULT, M_TIMER_RESET, M_NULL);
      }

   HookData->TheDashboard->AddMiniBatchData(LossError, MiniBatchIdx, EpochIdx, NbMiniBatchPerEpoch);

   // Only look at the keyboard when a key is already waiting, so that the
   // training is never blocked unless the user asked for a pause.
   if(MosKbhit() != 0)
      {
      char KeyVal = (char)MosGetch();
      if(KeyVal == 'p')
         {
         MosPrintf(MIL_TEXT("\nPress 's' to stop the training or any other key to continue.\n"));

         // Wait for the user's answer. Reading the key unconditionally ensures
         // the pause always happens, even if another key was already buffered
         // when 'p' was processed.
         KeyVal = (char)MosGetch();
         if(KeyVal == 's')
            {
            // Request the stop through the train result associated with the event.
            MIL_ID HookInfoTrainResId = M_NULL;
            MclassGetHookInfo(EventId, M_RESULT_ID + M_TYPE_MIL_ID, &HookInfoTrainResId);
            MclassControl(HookInfoTrainResId, M_DEFAULT, M_STOP_TRAIN, M_DEFAULT);
            MosPrintf(MIL_TEXT("The training have been stopped.\n"));
            }
         else
            {
            MosPrintf(MIL_TEXT("The training will continue.\n"));
            }
         }
      }

   return(M_NULL);
   }

// Trains a predefined CNN classifier on the given datasets while displaying
// the training evolution on a dashboard.
//   MilSystem    : system on which the MIL objects are allocated.
//   TrainDataset : dataset used to train the classifier.
//   DevDataset   : dataset passed to MclassTrain as the development dataset.
//   MilDisplay   : display on which the dashboard buffer is selected.
// Returns the trained classifier context when the train result status is
// M_COMPLETE; otherwise the returned identifier is left default-constructed.
MIL_UNIQUE_CLASS_ID TrainTheModel(MIL_ID MilSystem, MIL_ID TrainDataset, MIL_ID DevDataset, MIL_ID MilDisplay)
   {
   // Dataset sizes are only used for the information shown on the dashboard.
   MIL_INT TrainDatasetNbImages = MclassInquire(TrainDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, M_NULL);
   MIL_INT DevDatasetNbImages = MclassInquire(DevDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, M_NULL);

   // Allocate a context and a result for the training.
   MIL_UNIQUE_CLASS_ID TrainCtx = MclassAlloc(MilSystem, M_TRAIN_CNN, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_CLASS_ID TrainRes = MclassAllocResult(MilSystem, M_TRAIN_CNN_RESULT, M_DEFAULT, M_UNIQUE_ID);

   // Use the proper parameters for the training context.
   const MIL_INT    MAX_NUMBER_OF_EPOCH = 10;
   const MIL_INT    MINI_BATCH_SIZE     = 64;
   const MIL_DOUBLE LEARNING_RATE       = 0.001;
   MclassControl(TrainCtx, M_DEFAULT, M_MAX_EPOCH            , MAX_NUMBER_OF_EPOCH);
   MclassControl(TrainCtx, M_DEFAULT, M_MINI_BATCH_SIZE      , MINI_BATCH_SIZE);
   MclassControl(TrainCtx, M_DEFAULT, M_INITIAL_LEARNING_RATE, LEARNING_RATE);

   // The train engine information below is inquired after the preprocess.
   MclassPreprocess(TrainCtx, M_DEFAULT);

   MIL_INT TrainEngineUsed;
   MclassInquire(TrainCtx, M_CONTEXT, M_CNN_TRAIN_ENGINE_USED + M_TYPE_MIL_INT, &TrainEngineUsed);

   MIL_INT GpuTrainEngineStatus;
   MclassInquire(TrainCtx, M_CONTEXT, M_GPU_TRAIN_ENGINE_LOAD_STATUS + M_TYPE_MIL_INT, &GpuTrainEngineStatus);

   // Warn the user when the GPU engine needs JIT compilation, or when the GPU
   // engine load status is anything other than success.
   if(TrainEngineUsed == M_GPU && GpuTrainEngineStatus == M_JIT_COMPILATION_REQUIRED)
      {
      MosPrintf(MIL_TEXT("\nWarning :: The training might not be optimal for the current system.\n"));
      MosPrintf(MIL_TEXT("Use the CNN Train Engine Test under Classification in MILConfig for more information.\n"));
      MosPrintf(MIL_TEXT("It may take some time before displaying the first results...\n"));
      }
   else if(GpuTrainEngineStatus != M_SUCCESS)
      {
      MosPrintf(MIL_TEXT("\nWarning :: The training is beeing done on the CPU.\n"));
      MosPrintf(MIL_TEXT("If a training on GPU was expected, use the CNN Train Engine Test under Classification in MILConfig for more information.\n"));
      }

   MIL_STRING TrainEngineDescription;
   MclassInquire(TrainCtx, M_CONTEXT, M_CNN_TRAIN_ENGINE_USED_DESCRIPTION, TrainEngineDescription);

   // Initialize the object responsible for displaying the train evolution.
   CTrainEvolutionDashboard TheTrainEvolutionDashboard(MilSystem, MAX_NUMBER_OF_EPOCH,
                                                       MINI_BATCH_SIZE, LEARNING_RATE,
                                                       TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE,
                                                       TrainDatasetNbImages, DevDatasetNbImages,
                                                       TrainEngineUsed, TrainEngineDescription);
   MdispSelect(MilDisplay, TheTrainEvolutionDashboard.GetDashboardBufId());

   // Initialize the hook associated to the epoch trained event.
   // NOTE(review): the hook data and the dashboard are locals passed by
   // address; this presumably relies on MclassTrain only invoking the hooks
   // before it returns -- confirm.
   HookEpochData TheHookEpochData;
   TheHookEpochData.TheDashboard = &TheTrainEvolutionDashboard;
   MclassHookFunction(TrainCtx, M_EPOCH_TRAINED, HookFuncEpoch, &TheHookEpochData);

   // Initialize the hook associated to the mini batch trained event.
   HookMiniBatchData TheHookMiniBatchData;
   TheHookMiniBatchData.TheDashboard = &TheTrainEvolutionDashboard;
   MclassHookFunction(TrainCtx, M_MINI_BATCH_TRAINED, HookFuncMiniBatch, &TheHookMiniBatchData);

   // Allocate the proper Matrox CNN model for the application.
   MIL_UNIQUE_CLASS_ID PretrainedCtx = MclassAlloc(MilSystem, M_CLASSIFIER_CNN_PREDEFINED, M_FCNET_M, M_UNIQUE_ID);
   // Start the training process.
   MclassTrain(TrainCtx, PretrainedCtx, TrainDataset, DevDataset, TrainRes, M_DEFAULT);

   // Stays empty unless the training completes (see the status check below).
   MIL_UNIQUE_CLASS_ID TrainedCtx;

   // Check the training status to ensure the training has completed properly.
   MIL_INT Status = -1;
   MclassGetResult(TrainRes, M_DEFAULT, M_STATUS + M_TYPE_MIL_INT, &Status);
   if(Status == M_COMPLETE)
      {
      MosPrintf(MIL_TEXT("\nTraining was successful.\n"));

      // Check if at some point there were missing train images.
      MIL_INT NbErrorImage = -1;
      MclassGetResult(TrainRes, M_DEFAULT, M_TRAIN_DATASET_ERROR_ENTRIES + M_NB_ELEMENTS + M_TYPE_MIL_INT, &NbErrorImage);
      if(NbErrorImage != 0)
         {
         MosPrintf(MIL_TEXT("Warning :: few images (%d) were missing at some part of the training.\n"), NbErrorImage);
         }

      // Copy the trained classifier out of the train result into a new context.
      TrainedCtx = MclassAlloc(MilSystem, M_CLASSIFIER_CNN_PREDEFINED, M_DEFAULT, M_UNIQUE_ID);
      MclassCopyResult(TrainRes, M_DEFAULT, TrainedCtx, M_DEFAULT, M_TRAINED_CLASSIFIER_CNN, M_DEFAULT);

      // Export the train report as text to "TrainReport.csv".
      MosPrintf(MIL_TEXT("A training report was saved: \"TrainReport.csv\".\n"));
      MclassExport(MIL_TEXT("TrainReport.csv"), M_FORMAT_TXT, TrainRes, M_DEFAULT, M_TRAIN_REPORT, M_DEFAULT);

      // Final error rates and the epoch at which the parameters were last updated.
      MIL_DOUBLE TrainErrorRate = 0;
      MclassGetResult(TrainRes, M_DEFAULT, M_TRAIN_DATASET_ERROR_RATE, &TrainErrorRate);
      MIL_DOUBLE DevErrorRate = 0;
      MclassGetResult(TrainRes, M_DEFAULT, M_DEV_DATASET_ERROR_RATE, &DevErrorRate);

      MIL_INT LastUpdatedEpochIndex;
      MclassGetResult(TrainRes, M_DEFAULT, M_LAST_EPOCH_UPDATED_PARAMETERS + M_TYPE_MIL_INT, &LastUpdatedEpochIndex);

      MosPrintf(MIL_TEXT("\nThe best epoch was epoch %d with an error rate on the dev dataset of %.8lf.\n"), LastUpdatedEpochIndex, DevErrorRate);
      MosPrintf(MIL_TEXT("The associated train error rate is %.8lf.\n"), TrainErrorRate);

      MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
      MosGetch();
      }

   return TrainedCtx;
   }

// Runs the trained classifier on the test dataset as a final check.
// The first predictions (in a shuffled order) are displayed one at a time,
// entries that could not be predicted are reported, the accuracy and average
// predicted score are printed, and the trained context is saved to
// "FabricsNet_Gray.mclass".
//   MilSystem   : system on which the MIL objects are allocated.
//   MilDisplay  : display used to show the predicted images.
//   TrainedCtx  : trained classifier context used for the prediction.
//   TestDataset : dataset holding the test images and their ground truth.
void PredictUsingTrainedContext(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TrainedCtx, MIL_ID TestDataset)
   {
   CPredictResultDisplay ThePredictResultDisplay(MilSystem, MilDisplay, TestDataset);

   MclassPreprocess(TrainedCtx, M_DEFAULT);

   // Allocate the dataset that receives the prediction results, then classify
   // the whole test dataset with the trained context.
   MIL_UNIQUE_CLASS_ID PredictedDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);

   MclassPredict(TrainedCtx, TestDataset, PredictedDataset, M_DEFAULT);

   MIL_INT NbEntries = 0;
   MIL_INT NbEntriesPredicted = 0;
   MIL_DOUBLE PredAvg = 0;
   MclassInquire(PredictedDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NbEntries);
   MclassInquire(PredictedDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES_PREDICTED + M_TYPE_MIL_INT, &NbEntriesPredicted);
   MclassInquire(PredictedDataset, M_DEFAULT, M_PREDICTED_SCORE_AVERAGE, &PredAvg);

   // Here we shuffle the index of the test dataset to ensure showing classification for all classes.
   const unsigned int ShuffledIndexSeed = 49;
   const std::vector<MIL_INT>& ShuffledIndex = CreateShuffledIndex(NbEntries, ShuffledIndexSeed);

   const MIL_INT NbPredictionToShow = std::min<MIL_INT>(10, NbEntriesPredicted);

   MosPrintf(MIL_TEXT("\nPredictions will be performed on the test dataset as a final check\nof the trained CNN model.\n"));
   MosPrintf(MIL_TEXT("The test dataset contains %d images.\n"), NbEntries);
   MosPrintf(MIL_TEXT("The prediction results will be shown for the first %d images.\n"), NbPredictionToShow);

   MIL_INT NbGoodPredictions = 0;
   // Counts the predictions actually displayed, so that entries which failed
   // to be predicted do not reduce the number of results shown to the user.
   MIL_INT NbPredictionShown = 0;
   for(MIL_INT i = 0; i < NbEntries; i++)
      {
      const MIL_INT EntryIndex = ShuffledIndex[i];

      // Check that entry was predicted
      MIL_INT EntryPredicted = 0;
      MclassInquireEntry(PredictedDataset, EntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_PREDICTED + M_NB_ELEMENTS + M_TYPE_MIL_INT, &EntryPredicted);

      if(EntryPredicted == 1)
         {
         MIL_INT GroundTruthIndex;
         MclassInquireEntry(TestDataset, EntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH + M_TYPE_MIL_INT, &GroundTruthIndex);

         MIL_INT PredIndex = 0;
         MclassInquireEntry(PredictedDataset, EntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_PREDICTED + M_TYPE_MIL_INT, &PredIndex);
         std::vector<MIL_DOUBLE> PredScores;
         MclassInquireEntry(PredictedDataset, EntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_PREDICTED_CLASS_SCORES, PredScores);

         if(PredIndex == GroundTruthIndex)
            NbGoodPredictions++;

         if(NbPredictionShown < NbPredictionToShow)
            {
            MIL_STRING FilePath;
            MclassInquireEntry(TestDataset, EntryIndex, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);

            MIL_UNIQUE_BUF_ID ImageToPredict = MbufRestore(FilePath, MilSystem, M_UNIQUE_ID);

            ThePredictResultDisplay.Update(ImageToPredict, PredIndex, PredScores[PredIndex]);
            MosPrintf(MIL_TEXT("The predicted index is %d and the predicted score is %.2lf%s (Ground truth=%d)\n"), PredIndex, PredScores[PredIndex], MIL_TEXT("%"), GroundTruthIndex);

            MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
            MosGetch();

            NbPredictionShown++;
            }
         }
      else
         {
         MIL_STRING FilePath;
         MclassInquireEntry(TestDataset, EntryIndex, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
         MosPrintf(MIL_TEXT("The image \"%s\" failed to be predicted.\n"), FilePath.c_str());
         }
      }

   MclassSave(MIL_TEXT("FabricsNet_Gray.mclass"), TrainedCtx, M_DEFAULT);

   // Guard against a division by zero when no entry could be predicted.
   const MIL_DOUBLE AccuracyPercent = (NbEntriesPredicted > 0)
      ? ((MIL_DOUBLE)NbGoodPredictions / (MIL_DOUBLE)NbEntriesPredicted)*100.0
      : 0.0;

   MosPrintf(MIL_TEXT("The accuracy on the test dataset using the trained context is %.2lf%s.\n"), AccuracyPercent, MIL_TEXT("%"));
   MosPrintf(MIL_TEXT("The average predicted score on the test dataset using the trained\ncontext is %.2lf%s.\n"), PredAvg, MIL_TEXT("%"));
   MosPrintf(MIL_TEXT("The trained context was saved: \"FabricsNet_Gray.mclass\".\n"));
   MosPrintf(MIL_TEXT("Press <enter> to end...\n"));

   MosGetch();
   }


// Builds the dashboard: allocates the full dashboard buffer and one child
// buffer per section, draws the static content (frames and graph axes) and
// writes the general training information.
// Layout, from top to bottom:
//   - epoch info (left) and loss info (right),
//   - epoch graph (left) and loss graph (right),
//   - progression info across the full width.
CTrainEvolutionDashboard::CTrainEvolutionDashboard(MIL_ID MilSystem, MIL_INT MaxEpoch, MIL_INT MinibatchSize,
                                                   MIL_DOUBLE LearningRate,
                                                   MIL_INT TrainImageSizeX, MIL_INT TrainImageSizeY,
                                                   MIL_INT TrainDatasetSize, MIL_INT DevDatasetSize,
                                                   MIL_INT TrainEngineUsed, MIL_STRING& TrainEngineDescription):
   m_DashboardBufId(M_NULL),
   m_TheGraContext(M_NULL),
   m_EpochInfoBufId(M_NULL),
   m_EpochGraphBufId(M_NULL),
   m_LossInfoBufId(M_NULL),
   m_LossGraphBufId(M_NULL),
   m_ProgressionInfoBufId(M_NULL),
   m_MaxEpoch(MaxEpoch),
   m_DashboardWidth(0),
   m_LastTrainPosX(0),
   m_LastTrainPosY(0),
   m_LastDevPosX(0),
   m_LastDevPosY(0),
   m_LastTrainMinibatchPosX(0),
   m_LastTrainMinibatchPosY(0),
   m_YPositionForLossText(0),
   m_EpochBenchMean(-1.0),
   GRAPH_SIZE_X(400),
   GRAPH_SIZE_Y(400),
   GRAPH_TOP_MARGIN(30),
   MARGIN(50),
   EPOCH_AND_MINIBATCH_REGION_HEIGHT(190),
   PROGRESSION_INFO_REGION_HEIGHT(100),
   LOSS_EXPONENT_MAX(0),
   LOSS_EXPONENT_MIN(-5),
   COLOR_GENERAL_INFO(M_RGB888(0, 176, 255)),
   COLOR_DEV_SET_INFO(M_COLOR_MAGENTA),
   COLOR_TRAIN_SET_INFO(M_COLOR_GREEN),
   COLOR_PROGRESS_BAR(M_COLOR_DARK_GREEN)
   {
   // Size of the box holding one graph, margins included.
   const MIL_INT BoxWidth  = 2 * MARGIN + GRAPH_SIZE_X;
   const MIL_INT BoxHeight = GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + MARGIN;

   // Two graph boxes are laid out side by side.
   m_DashboardWidth = 2 * BoxWidth;

   // Vertical position of each row of sections.
   const MIL_INT InfoRowHeight  = EPOCH_AND_MINIBATCH_REGION_HEIGHT;
   const MIL_INT GraphRowY      = InfoRowHeight;
   const MIL_INT ProgressRowY   = GraphRowY + BoxHeight;
   const MIL_INT TotalHeight    = ProgressRowY + PROGRESSION_INFO_REGION_HEIGHT;

   // Full dashboard buffer: 3 bands, 8-bit unsigned, displayable and processable.
   m_DashboardBufId = MbufAllocColor(MilSystem, 3, m_DashboardWidth, TotalHeight,
                                     8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(m_DashboardBufId, M_COLOR_BLACK);

   m_TheGraContext = MgraAlloc(MilSystem, M_UNIQUE_ID);

   // One child buffer per dashboard section.
   m_EpochInfoBufId       = MbufChild2d(m_DashboardBufId, 0, 0, BoxWidth, InfoRowHeight, M_UNIQUE_ID);
   m_LossInfoBufId        = MbufChild2d(m_DashboardBufId, BoxWidth, 0, BoxWidth, InfoRowHeight, M_UNIQUE_ID);
   m_EpochGraphBufId      = MbufChild2d(m_DashboardBufId, 0, GraphRowY, BoxWidth, BoxHeight, M_UNIQUE_ID);
   m_LossGraphBufId       = MbufChild2d(m_DashboardBufId, BoxWidth, GraphRowY, BoxWidth, BoxHeight, M_UNIQUE_ID);
   m_ProgressionInfoBufId = MbufChild2d(m_DashboardBufId, 0, ProgressRowY, m_DashboardWidth, PROGRESSION_INFO_REGION_HEIGHT, M_UNIQUE_ID);

   // Static content: section frames, then the axes of both graphs.
   DrawSectionSeparators();
   InitializeEpochGraph();
   InitializeLossGraph();

   // Text block describing the training parameters.
   WriteGeneralTrainInfo(MinibatchSize, TrainImageSizeX, TrainImageSizeY, TrainDatasetSize,
                         DevDatasetSize, LearningRate, TrainEngineUsed, TrainEngineDescription);
   }

// Resets every MIL identifier owned by the dashboard to M_NULL.
// The graphic context and the child buffers created from the dashboard buffer
// in the constructor are reset first; the parent dashboard buffer is reset last.
// NOTE(review): presumably the members are MIL_UNIQUE_* identifiers, so that
// assigning M_NULL releases the underlying MIL object -- confirm against the
// class declaration.
CTrainEvolutionDashboard::~CTrainEvolutionDashboard()
   {
   m_TheGraContext = M_NULL;
   m_EpochInfoBufId = M_NULL;
   m_LossInfoBufId = M_NULL;
   m_EpochGraphBufId = M_NULL;
   m_LossGraphBufId = M_NULL;
   m_ProgressionInfoBufId = M_NULL;
   // Parent buffer last, after all of its children.
   m_DashboardBufId = M_NULL;
   }

// Draws a filled rectangular frame along the four borders of a buffer, using
// the general information color.
//   BufId          : buffer in which the frame is drawn.
//   FrameThickness : thickness of the frame, in pixels.
void CTrainEvolutionDashboard::DrawBufferFrame(MIL_ID BufId, MIL_INT FrameThickness)
   {
   // Buffer dimensions are integer sizes, not MIL identifiers.
   const MIL_INT SizeX = MbufInquire(BufId, M_SIZE_X, M_NULL);
   const MIL_INT SizeY = MbufInquire(BufId, M_SIZE_Y, M_NULL);

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   // Top, right, bottom and left borders, in that order.
   MgraRectFill(m_TheGraContext, BufId, 0, 0, SizeX - 1, FrameThickness - 1);
   MgraRectFill(m_TheGraContext, BufId, SizeX - FrameThickness, 0, SizeX - 1, SizeY - 1);
   MgraRectFill(m_TheGraContext, BufId, 0, SizeY - FrameThickness, SizeX - 1, SizeY - 1);
   MgraRectFill(m_TheGraContext, BufId, 0, 0, FrameThickness - 1, SizeY - 1);
   }


void CTrainEvolutionDashboard::DrawSectionSeparators()
   {
   // Draw a frame for the whole dashboard.
   DrawBufferFrame(m_DashboardBufId, 4);
   // Draw a frame for each section.
   DrawBufferFrame(m_EpochInfoBufId, 2);
   DrawBufferFrame(m_EpochGraphBufId, 2);
   DrawBufferFrame(m_LossInfoBufId, 2);
   DrawBufferFrame(m_LossGraphBufId, 2);
   DrawBufferFrame(m_ProgressionInfoBufId, 2);
   }

void CTrainEvolutionDashboard::InitializeEpochGraph()
   {
   // Draw axis.
   MgraColor(m_TheGraContext, M_COLOR_WHITE);
   MgraRect(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN, MARGIN + GRAPH_SIZE_X, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_RIGHT);
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN, MIL_TEXT("100"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)), MIL_TEXT("75"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)), MIL_TEXT("50"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)), MIL_TEXT("25"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y, MIL_TEXT("0"));

   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)));
   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)));
   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)));

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   MIL_INT NbTick = std::min<MIL_INT>(m_MaxEpoch, 10);
   const MIL_INT EpochTickValue = m_MaxEpoch / NbTick;

   for(MIL_INT CurTick = 1; CurTick <= m_MaxEpoch; CurTick += EpochTickValue)
      {
      MIL_DOUBLE Percentage = (MIL_DOUBLE)CurTick / (MIL_DOUBLE)m_MaxEpoch;
      MIL_INT XOffset = (MIL_INT)(Percentage * GRAPH_SIZE_X);
      MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 5, M_TO_STRING(CurTick-1));
      MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y - 5, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);
      }
   }

void CTrainEvolutionDashboard::InitializeLossGraph()
   {
   // Draw axis.
   MgraColor(m_TheGraContext, M_COLOR_WHITE);
   MgraRect(m_TheGraContext, m_LossGraphBufId, MARGIN, GRAPH_TOP_MARGIN, MARGIN + GRAPH_SIZE_X, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_RIGHT);

   const MIL_INT NbLossValueTick = LOSS_EXPONENT_MAX - LOSS_EXPONENT_MIN;
   const MIL_DOUBLE TickRatio = 1.0/(MIL_DOUBLE)NbLossValueTick;

   MIL_DOUBLE TickNum = 0.0;
   for(MIL_INT i = LOSS_EXPONENT_MAX; i >= LOSS_EXPONENT_MIN; i--)
      {
      MIL_TEXT_CHAR CurTickText[128];
      MosSprintf(CurTickText, 128, MIL_TEXT("1e%d"), i);

      MIL_INT TickYPos = (MIL_INT)(TickNum*TickRatio*GRAPH_SIZE_Y);
      MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + TickYPos, CurTickText);
      if(i!=LOSS_EXPONENT_MAX && i!= LOSS_EXPONENT_MIN)
         {
         MgraLine(m_TheGraContext, m_LossGraphBufId, MARGIN, GRAPH_TOP_MARGIN + TickYPos, MARGIN + 5, GRAPH_TOP_MARGIN + TickYPos);
         }
      TickNum = TickNum + 1.0;
      }

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   const MIL_INT NbEpochTick    = std::min<MIL_INT>(m_MaxEpoch, 10);
   const MIL_INT EpochTickValue = m_MaxEpoch / NbEpochTick;

   for(MIL_INT CurTick = 1; CurTick <= m_MaxEpoch; CurTick += EpochTickValue)
      {
      MIL_DOUBLE Percentage = (MIL_DOUBLE)CurTick / (MIL_DOUBLE)m_MaxEpoch;
      MIL_INT XOffset = (MIL_INT)(Percentage * GRAPH_SIZE_X);
      MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 5, M_TO_STRING(CurTick-1));
      MgraLine(m_TheGraContext, m_LossGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y - 5, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);
      }
   }

// Writes the general training information (engine, image size, dataset sizes,
// number of epochs, mini-batch size and learning rate) in the loss information
// section, one line per item, and records the vertical position of the next
// free line so that the current loss value can be written there later.
void CTrainEvolutionDashboard::WriteGeneralTrainInfo(MIL_INT     MinibatchSize,
                                                     MIL_INT     TrainImageSizeX,
                                                     MIL_INT     TrainImageSizeY,
                                                     MIL_INT     TrainDatasetSize,
                                                     MIL_INT     DevDatasetSize,
                                                     MIL_DOUBLE  LearningRate,
                                                     MIL_INT     TrainEngineUsed,
                                                     MIL_STRING& TrainEngineDescription)
   {
   // Text is drawn opaque on black so that a rewritten line replaces the old one.
   MgraControl(m_TheGraContext, M_BACKGROUND_MODE, M_OPAQUE);
   MgraControl(m_TheGraContext, M_BACKCOLOR, M_COLOR_BLACK);

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   const MIL_INT FirstLineY = 15;
   const MIL_INT LineStep   = 20;
   const MIL_INT LeftX      = MARGIN - 10;

   MIL_INT NextLineY = FirstLineY;

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);

   MIL_TEXT_CHAR Line[512];

   // Draws the content of Line at the current position and moves to the next line.
   auto EmitLine = [&]()
      {
      MgraText(m_TheGraContext, m_LossInfoBufId, LeftX, NextLineY, Line);
      NextLineY += LineStep;
      };

   if(TrainEngineUsed==M_CPU)
      {
      MosSprintf(Line, 512, MIL_TEXT("Training is being performed on the CPU"));
      }
   else
      {
      MosSprintf(Line, 512, MIL_TEXT("Training is being performed on the GPU"));
      }
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Training engine: %s"), TrainEngineDescription.c_str());
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Train image size: %dx%d"), TrainImageSizeX, TrainImageSizeY);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Train and Dev dataset size: %d and %d images"), TrainDatasetSize, DevDatasetSize);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Max number of epochs: %d"), m_MaxEpoch);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Minibatch size: %d"), MinibatchSize);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Learning rate: %.2e"), LearningRate);
   EmitLine();

   // The loss value is drawn on the next line later on, so its position is retained.
   m_YPositionForLossText = NextLineY;
   }

// Records the statistics of an epoch that just completed: stores the mean
// epoch duration, then refreshes the epoch text section and the epoch graph.
void CTrainEvolutionDashboard::AddEpochData(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate,
                                            MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow,
                                            MIL_DOUBLE EpochBenchMean)
   {
   // Kept as a member so it is available outside of this call.
   this->m_EpochBenchMean = EpochBenchMean;

   // Text first, then the curves.
   this->UpdateEpochInfo(TrainErrorRate,
                         DevErrorRate,
                         CurEpoch,
                         TheEpochIsTheBestUpToNow);
   this->UpdateEpochGraph(TrainErrorRate,
                          DevErrorRate,
                          CurEpoch);
   }

// Records the result of a mini-batch that just completed: refreshes the loss
// text, the loss graph and the progression section, in that order.
void CTrainEvolutionDashboard::AddMiniBatchData(MIL_DOUBLE LossError, MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
   {
   // Current loss value, as text.
   this->UpdateLoss(LossError);

   // Loss curve.
   this->UpdateLossGraph(LossError,
                         MinibatchIdx,
                         EpochIdx,
                         NbBatchPerEpoch);

   // Remaining time and progression bar.
   this->UpdateProgression(MinibatchIdx,
                           EpochIdx,
                           NbBatchPerEpoch);
   }

// Refreshes the epoch text section with the current dev and train error
// rates. When the epoch is the best one so far, the two "best epoch" lines
// are rewritten as well; otherwise they are left untouched.
void CTrainEvolutionDashboard::UpdateEpochInfo(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow)
   {
   const MIL_INT LeftX      = MARGIN - 10;
   const MIL_INT FirstLineY = 15;
   const MIL_INT LineStep   = 20;

   // Single scratch buffer reused for every line.
   MIL_TEXT_CHAR Line[512];

   // Line 0: current dev error rate.
   MgraColor(m_TheGraContext, COLOR_DEV_SET_INFO);
   MosSprintf(Line, 512, MIL_TEXT("Current Dev error rate: %7.4lf %%"), DevErrorRate);
   MgraText(m_TheGraContext, m_EpochInfoBufId, LeftX, FirstLineY, Line);

   // Line 1: current train error rate.
   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   MosSprintf(Line, 512, MIL_TEXT("Current Train error rate: %7.4lf %%"), TrainErrorRate);
   MgraText(m_TheGraContext, m_EpochInfoBufId, LeftX, FirstLineY + LineStep, Line);

   if(!TheEpochIsTheBestUpToNow)
      {
      return;
      }

   // Line 2: dev error rate of the best epoch.
   MgraColor(m_TheGraContext, COLOR_DEV_SET_INFO);
   MosSprintf(Line, 512, MIL_TEXT("Best epoch Dev error rate: %7.4lf %%   (Epoch %d)"), DevErrorRate, CurEpoch);
   MgraText(m_TheGraContext, m_EpochInfoBufId, LeftX, FirstLineY + 2 * LineStep, Line);

   // Line 3: train error rate of the best epoch.
   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   MosSprintf(Line, 512, MIL_TEXT("Train error rate for the best epoch: %7.4lf %%"), TrainErrorRate);
   MgraText(m_TheGraContext, m_EpochInfoBufId, LeftX, FirstLineY + 3 * LineStep, Line);
   }

// Writes the current loss value in the loss information section, on the line
// whose vertical position is stored in m_YPositionForLossText.
void CTrainEvolutionDashboard::UpdateLoss(MIL_DOUBLE LossError)
   {
   MIL_TEXT_CHAR Line[512];
   MosSprintf(Line, 512, MIL_TEXT("Current loss value: %11.7lf"), LossError);

   const MIL_INT LeftX = MARGIN - 10;

   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   MgraText(m_TheGraContext, m_LossInfoBufId, LeftX, m_YPositionForLossText, Line);
   }

// Adds the train and dev error rates of an epoch to the epoch graph.
// The first epoch is drawn as a small filled disc; every following epoch is
// joined to the previous point by a line. The "Epoch N completed" caption
// under the graph is then refreshed.
void CTrainEvolutionDashboard::UpdateEpochGraph(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch)
   {
   // Both curves share the same abscissa: the number of completed epochs
   // scaled to the graph width.
   const MIL_INT CompletedEpochs = CurEpoch + 1;
   const MIL_INT PointX = MARGIN + (MIL_INT)((MIL_DOUBLE)(CompletedEpochs) / (MIL_DOUBLE)(m_MaxEpoch)*(MIL_DOUBLE)GRAPH_SIZE_X);

   // An error rate of 100 maps to the top of the graph, 0 to the bottom.
   const MIL_INT TrainPointY = GRAPH_TOP_MARGIN + (MIL_INT)((MIL_DOUBLE)GRAPH_SIZE_Y*(1.0 - TrainErrorRate * 0.01));
   const MIL_INT DevPointY   = GRAPH_TOP_MARGIN + (MIL_INT)((MIL_DOUBLE)GRAPH_SIZE_Y*(1.0 - DevErrorRate * 0.01));

   const bool IsFirstPoint = (CurEpoch == 0);

   // Train curve.
   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   if(IsFirstPoint)
      MgraArcFill(m_TheGraContext, m_EpochGraphBufId, PointX, TrainPointY, 2, 2, 0, 360);
   else
      MgraLine(m_TheGraContext, m_EpochGraphBufId, m_LastTrainPosX, m_LastTrainPosY, PointX, TrainPointY);

   // Dev curve.
   MgraColor(m_TheGraContext, COLOR_DEV_SET_INFO);
   if(IsFirstPoint)
      MgraArcFill(m_TheGraContext, m_EpochGraphBufId, PointX, DevPointY, 2, 2, 0, 360);
   else
      MgraLine(m_TheGraContext, m_EpochGraphBufId, m_LastDevPosX, m_LastDevPosY, PointX, DevPointY);

   // Remember the end points for the next segment.
   m_LastTrainPosX = PointX;
   m_LastTrainPosY = TrainPointY;
   m_LastDevPosX = PointX;
   m_LastDevPosY = DevPointY;

   // Caption under the graph.
   MIL_TEXT_CHAR Caption[128];
   MosSprintf(Caption, 128, MIL_TEXT("Epoch %d completed"), CurEpoch);
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 25, Caption);
   }


// Adds the loss of a mini-batch to the loss graph, on a log10 vertical scale
// spanning LOSS_EXPONENT_MIN to LOSS_EXPONENT_MAX. The very first mini-batch
// is drawn as a dot; every following one is joined to the previous point by a
// line. The "Epoch N :: Minibatch M" caption under the graph is then refreshed.
void CTrainEvolutionDashboard::UpdateLossGraph(MIL_DOUBLE LossError, MIL_INT MiniBatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
   {
   // Abscissa: position of this mini-batch among all mini-batches of the training.
   const MIL_INT TotalMiniBatches  = m_MaxEpoch * NbBatchPerEpoch;
   const MIL_INT GlobalMiniBatchIdx = EpochIdx * NbBatchPerEpoch + MiniBatchIdx;
   const MIL_DOUBLE HorizontalRatio = ((MIL_DOUBLE)GlobalMiniBatchIdx / (MIL_DOUBLE)(TotalMiniBatches));
   const MIL_INT PointX = MARGIN + (MIL_INT)(HorizontalRatio * (MIL_DOUBLE)GRAPH_SIZE_X);

   // Ordinate: the loss is saturated to the highest value of the graph, then
   // expressed in decades above the lowest exponent (never below zero).
   const MIL_DOUBLE HighestLoss = std::pow(10.0, LOSS_EXPONENT_MAX);
   const MIL_INT    DecadeCount = LOSS_EXPONENT_MAX - LOSS_EXPONENT_MIN;

   const MIL_DOUBLE SaturatedLoss = std::min<MIL_DOUBLE>(LossError, HighestLoss);
   const MIL_DOUBLE DecadesAboveMin = std::max<MIL_DOUBLE>(log10(SaturatedLoss) + (-LOSS_EXPONENT_MIN), 0.0);
   const MIL_DOUBLE VerticalRatio = DecadesAboveMin / (MIL_DOUBLE)DecadeCount;

   const MIL_INT PointY = GRAPH_TOP_MARGIN + (MIL_INT)((MIL_DOUBLE)GRAPH_SIZE_Y*(1.0 - VerticalRatio));

   // Draw the new point, or the segment from the previous point.
   const bool IsFirstPoint = (EpochIdx == 0) && (MiniBatchIdx == 0);
   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   if(IsFirstPoint)
      {
      MgraDot(m_TheGraContext, m_LossGraphBufId, PointX, PointY);
      }
   else
      {
      MgraLine(m_TheGraContext, m_LossGraphBufId, m_LastTrainMinibatchPosX, m_LastTrainMinibatchPosY, PointX, PointY);
      }

   // Remember the end point for the next segment.
   m_LastTrainMinibatchPosX = PointX;
   m_LastTrainMinibatchPosY = PointY;

   // Caption under the graph: blank the previous text before writing the new one.
   const MIL_INT CaptionY = GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 25;
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN, CaptionY, MIL_TEXT("                                                    "));

   MIL_TEXT_CHAR Caption[512];
   MosSprintf(Caption, 512, MIL_TEXT("Epoch %d :: Minibatch %d"), EpochIdx, MiniBatchIdx);
   MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN, CaptionY, Caption);
   }

// Refreshes the progression panel: the status line (estimated remaining time
// or completion message) and the progress bar drawn under it.
//
//   MinibatchIdx    : index of the mini-batch inside the current epoch.
//   EpochIdx        : index of the current epoch.
//   NbBatchPerEpoch : number of mini-batches per epoch.
void CTrainEvolutionDashboard::UpdateProgression(MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
   {
   const MIL_INT YMargin = 20;
   const MIL_INT TextHeight = 30;

   const MIL_INT NbMinibatch = m_MaxEpoch * NbBatchPerEpoch;
   // The +1 turns the index of the current mini-batch into a count of
   // mini-batches done, so the remaining count is a plain difference.
   const MIL_INT NbMinibatchDone = EpochIdx * NbBatchPerEpoch + MinibatchIdx + 1;
   const MIL_INT NbMinibatchRemaining = std::max<MIL_INT>(NbMinibatch - NbMinibatchDone, 0);

   // Update the status line.
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);

   if(NbMinibatchDone >= NbMinibatch)
      {
      // Checked first so the message is also shown when the training ends
      // during the first epoch. The trailing blanks erase the longer text
      // that was previously drawn at the same position.
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, MIL_TEXT("Training completed!                         "));
      }
   else if(EpochIdx == 0)
      {
      // The first epoch implied data loading and cannot be used to estimate the
      // remaining time accurately.
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, MIL_TEXT("Estimated remaining time: N/A"));
      }
   else
      {
      // Average duration of one mini-batch, derived from the mean epoch duration.
      const MIL_DOUBLE MinibatchBenchMean = m_EpochBenchMean / (MIL_DOUBLE)NbBatchPerEpoch;
      const MIL_DOUBLE RemainingTime = MinibatchBenchMean * (MIL_DOUBLE)NbMinibatchRemaining;

      MIL_TEXT_CHAR RemainingTimeText[512];
      MosSprintf(RemainingTimeText, 512, MIL_TEXT("Estimated remaining time: %8.0lf seconds"), RemainingTime);
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, RemainingTimeText);
      }

   // Update the progression bar: first the full-width background...
   const MIL_INT ProgressionBarWidth = m_DashboardWidth - 2 * MARGIN;
   const MIL_INT ProgressionBarHeight = 30;
   const MIL_INT BarTop = YMargin + TextHeight;
   const MIL_INT BarBottom = BarTop + ProgressionBarHeight;

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MgraRectFill(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, BarTop, MARGIN + ProgressionBarWidth, BarBottom);

   // ...then the completed portion on top of it. The ratio is capped at 1 and
   // a zero total is treated as complete so the fill never leaves the bar and
   // no NaN/inf is converted to an integer width.
   MIL_DOUBLE PercentageComplete = 1.0;
   if(NbMinibatch > 0)
      PercentageComplete = std::min<MIL_DOUBLE>((MIL_DOUBLE)NbMinibatchDone / (MIL_DOUBLE)NbMinibatch, 1.0);

   const MIL_INT PercentageCompleteWidth = (MIL_INT)(PercentageComplete * ProgressionBarWidth);
   MgraColor(m_TheGraContext, COLOR_PROGRESS_BAR);
   MgraRectFill(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, BarTop, MARGIN + PercentageCompleteWidth, BarBottom);
   }

// Builds the prediction display: a black color image holding, on the right,
// one icon per class of TestDataset stacked vertically and outlined in red,
// and, on the left, a child buffer that receives the image being predicted.
// The image is then selected on MilDisplay and its overlay is enabled.
//
//   MilSystem   : system on which the buffers and graphic context are allocated.
//   MilDisplay  : display used to show the image and its overlay.
//   TestDataset : dataset whose class definitions provide the icons.
CPredictResultDisplay::CPredictResultDisplay(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TestDataset):
   m_MilSystem(MilSystem),
   m_MilDisplay(MilDisplay),
   m_MaxTrainImageSize(-1),
   m_MilDispImage(M_NULL),
   m_MilDispChild(M_NULL),
   m_MilOverlay(M_NULL),
   m_GraContext(M_NULL),
   COLOR_PREDICT_INFO(M_COLOR_GREEN),
   MARGIN(100)
   {
   // Collect every class icon and keep the largest width/height seen; that
   // value becomes the side of the square cell used for all the tiles.
   const MIL_INT ClassCount = MclassInquire(TestDataset, M_DEFAULT, M_NUMBER_OF_CLASSES, M_NULL);
   std::vector<MIL_ID> Icons(ClassCount);

   for(MIL_INT ClassIdx = 0; ClassIdx < ClassCount; ++ClassIdx)
      {
      const MIL_ID Icon = MclassInquire(TestDataset, M_CLASS_INDEX(ClassIdx), M_CLASS_ICON_ID + M_TYPE_MIL_ID, M_NULL);
      Icons[ClassIdx] = Icon;

      const MIL_INT IconSizeX = MbufInquire(Icon, M_SIZE_X, M_NULL);
      const MIL_INT IconSizeY = MbufInquire(Icon, M_SIZE_Y, M_NULL);
      m_MaxTrainImageSize = std::max<MIL_INT>(m_MaxTrainImageSize, std::max<MIL_INT>(IconSizeX, IconSizeY));
      }

   // Layout: two cells wide plus the margin, one cell high per class.
   const MIL_INT DispSizeX = 2 * m_MaxTrainImageSize + MARGIN;
   const MIL_INT DispSizeY = ClassCount * m_MaxTrainImageSize;
   const MIL_INT IconColumnX = m_MaxTrainImageSize + MARGIN;

   // Allocate the 3-band 8-bit display image, start from black, and carve out
   // the cell that will show the image to predict.
   m_MilDispImage = MbufAllocColor(m_MilSystem, 3, DispSizeX, DispSizeY, 8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(m_MilDispImage, M_COLOR_BLACK);
   m_MilDispChild = MbufChild2d(m_MilDispImage, MARGIN / 2, m_MaxTrainImageSize, m_MaxTrainImageSize, m_MaxTrainImageSize, M_UNIQUE_ID);

   // Graphic context used for the annotations; the icon frames are red.
   m_GraContext = MgraAlloc(MilSystem, M_UNIQUE_ID);
   MgraColor(m_GraContext, M_COLOR_RED);

   // Paste each icon in its own row of the right-hand column and frame it.
   for(MIL_INT Row = 0; Row < ClassCount; ++Row)
      {
      const MIL_INT RowTop = Row * m_MaxTrainImageSize;
      MbufCopyColor2d(Icons[Row], m_MilDispImage, M_ALL_BANDS, 0, 0, M_ALL_BANDS, IconColumnX, RowTop, m_MaxTrainImageSize, m_MaxTrainImageSize);
      MgraRect(m_GraContext, m_MilDispImage, IconColumnX, RowTop, IconColumnX + m_MaxTrainImageSize - 1, RowTop + m_MaxTrainImageSize - 1);
      }

   // Show the composed image.
   MdispSelect(m_MilDisplay, m_MilDispImage);

   // Enable the overlay and keep its identifier for the per-prediction annotations.
   MdispControl(m_MilDisplay, M_OVERLAY, M_ENABLE);
   m_MilOverlay = MdispInquire(MilDisplay, M_OVERLAY_ID, M_NULL);
   }

// Clears the graphic context and display buffer handles held by the object.
// NOTE(review): the constructor fills these members from M_UNIQUE_ID
// allocations, so they are presumably MIL unique-ID handles that free their
// object when reset to M_NULL - confirm against the class declaration.
CPredictResultDisplay::~CPredictResultDisplay()
   {
   m_GraContext = M_NULL;
   // The child buffer is reset before the image it was created from; keep
   // this order.
   m_MilDispChild = M_NULL;
   m_MilDispImage = M_NULL;
   }

// Shows one prediction: copies the predicted image into the left-hand cell,
// frames the icon of the winning class in the overlay and writes the score
// in the top-left corner of that frame.
//
//   ImageToPredict : image that was classified; copied into the display child.
//   BestIndex      : index of the predicted class; selects the icon row to frame.
//   BestScore      : score of the predicted class, printed with two decimals
//                    followed by a '%' sign.
void CPredictResultDisplay::Update(MIL_ID ImageToPredict, MIL_INT BestIndex, MIL_DOUBLE BestScore)
   {
   // Hold display updates while the image and the overlay are being modified.
   MdispControl(m_MilDisplay, M_UPDATE, M_DISABLE);
   MdispControl(m_MilDisplay, M_OVERLAY_CLEAR, M_TRANSPARENT_COLOR);
   MbufCopy(ImageToPredict, m_MilDispChild);

   // Same icon-column offset as the one used by the constructor when it
   // pastes the class icons (MARGIN rather than a duplicated literal).
   const MIL_INT RectOffsetX = m_MaxTrainImageSize + MARGIN;
   const MIL_INT RectOffsetY = BestIndex * m_MaxTrainImageSize;

   // Frame the icon of the predicted class.
   MgraColor(m_GraContext, COLOR_PREDICT_INFO);
   MgraRect(m_GraContext, m_MilOverlay, RectOffsetX, RectOffsetY, RectOffsetX + m_MaxTrainImageSize - 1, RectOffsetY + m_MaxTrainImageSize - 1);

   // Write the score just inside the frame, on an opaque background.
   MIL_TEXT_CHAR ScoreText[256];
   MosSprintf(ScoreText, 256, MIL_TEXT("%.2lf%%"), BestScore);
   MgraControl(m_GraContext, M_BACKGROUND_MODE, M_OPAQUE);
   MgraFont(m_GraContext, M_FONT_DEFAULT_SMALL);
   MgraText(m_GraContext, m_MilOverlay, RectOffsetX + 2, RectOffsetY + 2, ScoreText);

   MdispControl(m_MilDisplay, M_UPDATE, M_ENABLE);
   }