#include <windows.h>
#include <mil.h>
#include <string>
#include <algorithm>
#include <random>
#include <numeric>
#include <map>
#include <math.h>
// Number of fabric classes handled by the example; sizes every per-class table below.
static constexpr MIL_INT NUMBER_OF_CLASSES = 3;
// Prints the example banner (name, synopsis, modules used) and waits for a
// key press before the example proceeds.
void PrintHeader()
{
   MosPrintf(MIL_TEXT("[EXAMPLE NAME]\n")
             MIL_TEXT("ClassCNNCompleteTrain\n\n")
             MIL_TEXT("[SYNOPSIS]\n")
             MIL_TEXT("This example trains a CNN model to classify the %d shown fabrics.\n")
             MIL_TEXT("The first step prepares the datasets needed for the training.\n")
             MIL_TEXT("The second step trains a context and displays the train evolution.\n")
             MIL_TEXT("The final step performs predictions on test data using the trained\n")
             MIL_TEXT("CNN model as a final check of the expected model performance.\n\n")
             MIL_TEXT("[MODULES USED]\n")
             MIL_TEXT("Modules used: application, system, display, buffer,\n")
             MIL_TEXT("graphic, classification.\n\n"),
             // %d consumes an int; MIL_INT may be 64 bits wide, so narrow explicitly.
             static_cast<int>(NUMBER_OF_CLASSES));
   MosPrintf(MIL_TEXT("Press <Enter> to continue.\n\n"));
   MosGetch();
}
// Folder holding the fabric classification images installed with MIL.
#define EXAMPLE_IMAGE_ROOT_PATH    M_IMAGE_PATH MIL_TEXT("Classification/Fabrics/")
// Folder holding the original images, one sub-folder per class.
#define EXAMPLE_ORIGINAL_DATA_PATH EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("OriginalData/")
// Working folder (relative path) into which the original images are copied,
// then augmented and cropped.
#define EXAMPLE_DATA_PATH          MIL_TEXT("Data/")

// Name of each class; also used as the class sub-folder name under the data folders.
MIL_STRING FABRICS_CLASS_NAME[NUMBER_OF_CLASSES] =
   {
   MIL_TEXT("Fabric1"),
   MIL_TEXT("Fabric2"),
   MIL_TEXT("Fabric3")
   };

// Number of original images per class (files 0001.mim .. NNNN.mim are copied).
MIL_INT FABRICS_CLASS_NB_IMAGES[NUMBER_OF_CLASSES] = { 200, 200, 200 };

// Icon image file of each class, shown on the display and attached to the datasets.
MIL_STRING FABRICS_CLASS_ICON[NUMBER_OF_CLASSES] =
   {
   EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric1_Icon.mim"),
   EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric2_Icon.mim"),
   EXAMPLE_IMAGE_ROOT_PATH MIL_TEXT("Fabric3_Icon.mim")
   };

// Side, in pixels, of the square images used for training; every dataset
// image is center-cropped to this size.
static constexpr MIL_INT TRAIN_IMAGE_SIZE = 99;
// Number of augmented images generated from each image of the train dataset.
static constexpr MIL_INT NB_AUGMENTATION_PER_IMAGE = 2;
// ---------------------------------------------------------------------------
// Forward declarations of the helpers defined later in this file.
// ---------------------------------------------------------------------------

// Environment and file-system helpers.
MIL_INT CnnTrainEngineDLLInstalled(MIL_ID MilSystem);
MIL_STRING GetExampleCurrentDirectory();
const std::vector<MIL_INT> CreateShuffledIndex(MIL_INT NbEntries, unsigned int Seed);
void DeleteFiles(const std::vector<MIL_STRING>& Files);
void ListFilesInFolder(const MIL_STRING& FolderName,
                       std::vector<MIL_STRING>& FilesInFolder);
void CopyOriginalDataToExampleDataFolder(const MIL_STRING* FabricsClassName,
                                         MIL_INT NumberOfClasses,
                                         const MIL_STRING& OriginalDataPath,
                                         const MIL_STRING& ExampleDataPath);
void DeleteFilesInFolder(const MIL_STRING& FolderName);

// Dataset preparation.
void AddClassDescription(MIL_ID MilSystem,
                         MIL_ID Dataset,
                         const MIL_STRING* FabricsClassName,
                         const MIL_STRING* FabricsClassIcon,
                         MIL_INT NumberOfClasses);
void PrepareTheDatasets(MIL_ID MilSystem,
                        const MIL_STRING* FabricsClassName,
                        const MIL_STRING* FabricsClassIcon,
                        MIL_INT NumberOfClasses,
                        const MIL_STRING& OriginalDataPath,
                        const MIL_STRING& ExampleDataPath,
                        MIL_ID TrainDataset,
                        MIL_ID DevDataset,
                        MIL_ID TestDataset);
void PrepareExampleDataFolder(const MIL_STRING& ExampleDataPath,
                              const MIL_STRING* FabricsClassName,
                              MIL_INT NumberOfClasses);
void AddClassToDataset(MIL_INT ClassIndex,
                       const MIL_STRING& DataToTrainPath,
                       const MIL_STRING& FabricName,
                       MIL_ID Dataset);
void AugmentDataset(MIL_ID System, MIL_ID Dataset, MIL_INT NbAugmentPerImage);
void CropDatasetImages(MIL_ID MilSystem, MIL_ID Dataset, MIL_INT FinalImageSize);
MIL_UNIQUE_BUF_ID CreateImageOfAllClasses(MIL_ID MilSystem,
                                          const MIL_STRING* FabricClassIcon,
                                          MIL_INT NumberOfClasses);

// Training and prediction.
MIL_UNIQUE_CLASS_ID TrainTheModel(MIL_ID MilSystem,
                                  MIL_ID TrainDataset,
                                  MIL_ID DevDataset,
                                  MIL_ID MilDisplay);
void PredictUsingTrainedContext(MIL_ID MilSystem,
                                MIL_ID MilDisplay,
                                MIL_ID TrainedCtx,
                                MIL_ID TestDataset);
// Dashboard image reporting the progress of a CNN training: per-epoch train/dev
// error rates, per-mini-batch loss and overall progression. It is fed from the
// MclassTrain() hooks (HookFuncEpoch calls AddEpochData, HookFuncMiniBatch calls
// AddMiniBatchData) and shown by selecting GetDashboardBufId() on a display.
// The member functions are defined outside this class declaration.
class CTrainEvolutionDashboard
{
public:
// Receives the training settings that TrainTheModel() passes (max epoch,
// mini-batch size, learning rate, train image size, dataset sizes, train engine).
// NOTE(review): presumably these are drawn once as general info — confirm in the definition.
CTrainEvolutionDashboard(MIL_ID MilSystem, MIL_INT MaxEpoch, MIL_INT MinibatchSize,
MIL_DOUBLE LearningRate,
MIL_INT TrainImageSizeX, MIL_INT TrainImageSizeY,
MIL_INT TrainDatasetSize, MIL_INT DevDatasetSize,
MIL_INT TrainEngineUsed, MIL_STRING& TrainEngineDescription);
~CTrainEvolutionDashboard();
// Called once per trained epoch with the train/dev error rates, the epoch index,
// whether the trained parameters were updated at this epoch, and the mean
// elapsed time per epoch.
void AddEpochData(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate,
MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow,
MIL_DOUBLE EpochBenchMean);
// Called once per trained mini-batch with its loss and its position in the training.
void AddMiniBatchData(MIL_DOUBLE LossError, MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);
// Identifier of the dashboard image buffer (still owned by this object).
MIL_ID GetDashboardBufId()
{
return m_DashboardBufId;
}
protected:
// Drawing helpers used by AddEpochData/AddMiniBatchData and the constructor.
void UpdateEpochInfo(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow);
void UpdateLoss(MIL_DOUBLE LossError);
void UpdateEpochGraph(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch);
void UpdateLossGraph(MIL_DOUBLE LossError, MIL_INT MiniBatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);
void UpdateProgression(MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch);
void DrawSectionSeparators();
void DrawBufferFrame(MIL_ID BufId, MIL_INT FrameThickness);
void InitializeEpochGraph();
void InitializeLossGraph();
void WriteGeneralTrainInfo(MIL_INT MinibatchSize,
MIL_INT TrainImageSizeX,
MIL_INT TrainImageSizeY,
MIL_INT TrainDatasetSize,
MIL_INT DevDatasetSize,
MIL_DOUBLE LearningRate,
MIL_INT TrainEngineUsed,
MIL_STRING& TrainEngineDescription);
// Whole dashboard image and graphics context.
// NOTE(review): the buffers below presumably are child regions of the dashboard,
// one per section — confirm in the constructor.
MIL_UNIQUE_BUF_ID m_DashboardBufId;
MIL_UNIQUE_GRA_ID m_TheGraContext;
MIL_UNIQUE_BUF_ID m_EpochInfoBufId;
MIL_UNIQUE_BUF_ID m_EpochGraphBufId;
MIL_UNIQUE_BUF_ID m_LossInfoBufId;
MIL_UNIQUE_BUF_ID m_LossGraphBufId;
MIL_UNIQUE_BUF_ID m_ProgressionInfoBufId;
MIL_INT m_MaxEpoch;
MIL_INT m_DashboardWidth;
// Presumably the last plotted graph positions, used to connect new points — TODO confirm.
MIL_INT m_LastTrainPosX;
MIL_INT m_LastTrainPosY;
MIL_INT m_LastDevPosX;
MIL_INT m_LastDevPosY;
MIL_INT m_LastTrainMinibatchPosX;
MIL_INT m_LastTrainMinibatchPosY;
MIL_INT m_YPositionForLossText;
MIL_DOUBLE m_EpochBenchMean;
// Layout and color settings. Despite the upper-case names these are non-static
// members, so their values are set outside this declaration (presumably by the constructor).
MIL_INT GRAPH_SIZE_X;
MIL_INT GRAPH_SIZE_Y;
MIL_INT GRAPH_TOP_MARGIN;
MIL_INT MARGIN;
MIL_INT EPOCH_AND_MINIBATCH_REGION_HEIGHT;
MIL_INT PROGRESSION_INFO_REGION_HEIGHT;
MIL_INT LOSS_EXPONENT_MAX;
MIL_INT LOSS_EXPONENT_MIN;
MIL_DOUBLE COLOR_GENERAL_INFO;
MIL_DOUBLE COLOR_DEV_SET_INFO;
MIL_DOUBLE COLOR_TRAIN_SET_INFO;
MIL_DOUBLE COLOR_PROGRESS_BAR;
};
// User data handed to HookFuncEpoch through MclassHookFunction(M_EPOCH_TRAINED).
struct HookEpochData
{
   // Dashboard receiving the per-epoch results; not owned. Defaults to nullptr
   // so an instance is never left holding an indeterminate pointer.
   CTrainEvolutionDashboard* TheDashboard = nullptr;
};
// User data handed to HookFuncMiniBatch through MclassHookFunction(M_MINI_BATCH_TRAINED).
struct HookMiniBatchData
{
   // Dashboard receiving the per-mini-batch results; not owned. Defaults to
   // nullptr so an instance is never left holding an indeterminate pointer.
   CTrainEvolutionDashboard* TheDashboard = nullptr;
};
// Callbacks registered with MclassHookFunction() in TrainTheModel():
// HookFuncEpoch runs on M_EPOCH_TRAINED and expects a HookEpochData as UserData;
// HookFuncMiniBatch runs on M_MINI_BATCH_TRAINED and expects a HookMiniBatchData.
MIL_INT MFTYPE HookFuncEpoch(MIL_INT HookType,
                             MIL_ID EventId,
                             void* UserData);
MIL_INT MFTYPE HookFuncMiniBatch(MIL_INT HookType,
                                 MIL_ID EventId,
                                 void* UserData);
// Display helper for the prediction step. It is built from the system, the
// display and the test dataset, and is updated with each shown test image
// together with its best predicted class index and score (see
// PredictUsingTrainedContext). The member functions are defined outside this
// class declaration.
class CPredictResultDisplay
{
public:
CPredictResultDisplay(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TestDataset);
~CPredictResultDisplay();
// Shows ImageToPredict with its best class index and score.
void Update(MIL_ID ImageToPredict, MIL_INT BestIndex, MIL_DOUBLE BestScore);
protected:
// System and display are borrowed (plain MIL_ID), not owned.
MIL_ID m_MilSystem;
MIL_ID m_MilDisplay;
MIL_INT m_MaxTrainImageSize;
// NOTE(review): presumably the displayed image and a child region where the test
// image is copied — confirm in the constructor.
MIL_UNIQUE_BUF_ID m_MilDispImage;
MIL_UNIQUE_BUF_ID m_MilDispChild;
// Plain MIL_ID, so not freed by this object; presumably the display overlay — TODO confirm.
MIL_ID m_MilOverlay;
MIL_UNIQUE_GRA_ID m_GraContext;
// const members: must be set in the constructor's initializer list.
const MIL_DOUBLE COLOR_PREDICT_INFO;
const MIL_INT MARGIN;
};
// Example entry point. Prepares the fabric datasets, trains a CNN classifier on
// them and, if the training completes, runs predictions on the test dataset.
// Returns -1 when training cannot run (non 64-bit build or no train engine
// installed), and 0 otherwise.
int MosMain()
{
#if !M_MIL_USE_64BIT
   MosPrintf(MIL_TEXT("\n***** MclassTrain() is not available with a non 64-bit platform. *****\n"));
   MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
   MosGetch();
   return -1;
#else
   PrintHeader();

   MIL_UNIQUE_APP_ID MilApplication = MappAlloc(M_NULL, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_SYS_ID MilSystem = MsysAlloc(M_DEFAULT, M_SYSTEM_HOST, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);

   // MclassTrain() needs an installed CNN train engine; stop early otherwise.
   if(CnnTrainEngineDLLInstalled(MilSystem) != M_TRUE)
   {
      MosPrintf(MIL_TEXT("\n***** No train engine installed, MclassTrain() cannot run! *****\n"));
      MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
      MosGetch();
      return -1;
   }

   MIL_UNIQUE_DISP_ID MilDisplay = MdispAlloc(MilSystem, M_DEFAULT, MIL_TEXT("M_DEFAULT"), M_DEFAULT, M_UNIQUE_ID);

   // Show the icons of all the classes while the datasets are being prepared.
   MIL_UNIQUE_BUF_ID AllClassesImage = CreateImageOfAllClasses(MilSystem, FABRICS_CLASS_ICON, NUMBER_OF_CLASSES);
   MdispSelect(MilDisplay, AllClassesImage);

   // All three datasets use the current directory as root path; query it once.
   const MIL_STRING ExampleCurrentDirectory = GetExampleCurrentDirectory();
   MIL_UNIQUE_CLASS_ID TrainDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_CLASS_ID DevDataset   = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_CLASS_ID TestDataset  = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassControl(TrainDataset, M_CONTEXT, M_ROOT_PATH, ExampleCurrentDirectory);
   MclassControl(DevDataset  , M_CONTEXT, M_ROOT_PATH, ExampleCurrentDirectory);
   MclassControl(TestDataset , M_CONTEXT, M_ROOT_PATH, ExampleCurrentDirectory);

   MosPrintf(MIL_TEXT("\n*******************************************************\n"));
   MosPrintf(MIL_TEXT("PREPARING THE DATASETS... THIS WILL TAKE SOME TIME...\n"));
   MosPrintf(MIL_TEXT("*******************************************************\n"));
   PrepareTheDatasets(MilSystem, FABRICS_CLASS_NAME, FABRICS_CLASS_ICON, NUMBER_OF_CLASSES, EXAMPLE_ORIGINAL_DATA_PATH, EXAMPLE_DATA_PATH,
                      TrainDataset, DevDataset, TestDataset);

   MosPrintf(MIL_TEXT("\n*******************************************************\n"));
   MosPrintf(MIL_TEXT("TRAINING... THIS WILL TAKE SOME TIME...\n"));
   MosPrintf(MIL_TEXT("*******************************************************\n"));
   MosPrintf(MIL_TEXT("\nThe training has started.\n"));
   MosPrintf(MIL_TEXT("It can be paused at any time by pressing 'p'.\n"));
   MosPrintf(MIL_TEXT("It can then be stopped or continued.\n"));
   MosPrintf(MIL_TEXT("\nDuring training, you can observe the displayed error rate of the train\n"));
   MosPrintf(MIL_TEXT("and dev datasets together with the evolution of the loss value...\n"));

   // An empty context is returned when the training did not complete.
   MIL_UNIQUE_CLASS_ID TrainedCtx = TrainTheModel(MilSystem, TrainDataset, DevDataset, MilDisplay);
   if(TrainedCtx)
   {
      MosPrintf(MIL_TEXT("\n*******************************************************\n"));
      MosPrintf(MIL_TEXT("PREDICTING USING THE TRAINED CONTEXT...\n"));
      MosPrintf(MIL_TEXT("*******************************************************\n"));
      PredictUsingTrainedContext(MilSystem, MilDisplay, TrainedCtx, TestDataset);
   }
   else
   {
      MosPrintf(MIL_TEXT("\nTraining has not completed properly !!!!!!!!!!!!!!\n"));
      MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
      MosGetch();
   }
   return 0;
#endif
}
// Reports whether a CNN train engine is installed, i.e. whether MclassTrain()
// can run. A temporary M_TRAIN_CNN context is allocated just to ask. Returns the
// value reported by MclassInquire, or M_FALSE if nothing was written.
MIL_INT CnnTrainEngineDLLInstalled(MIL_ID MilSystem)
{
   MIL_UNIQUE_CLASS_ID ProbeContext = MclassAlloc(MilSystem, M_TRAIN_CNN, M_DEFAULT, M_UNIQUE_ID);

   MIL_INT EngineInstalled = M_FALSE;
   MclassInquire(ProbeContext, M_DEFAULT, M_CNN_TRAIN_ENGINE_IS_INSTALLED + M_TYPE_MIL_INT, &EngineInstalled);

   return EngineInstalled;
}
// Returns the process's current working directory as reported by the Win32
// GetCurrentDirectory() API. The buffer is zero-filled, so a failed query
// yields an empty string.
MIL_STRING GetExampleCurrentDirectory()
{
   // The first call only asks for the required length (terminator included).
   const DWORD BufferLength = GetCurrentDirectory(0, NULL) + 1;
   std::vector<MIL_TEXT_CHAR> Buffer(BufferLength, 0);

   GetCurrentDirectory(BufferLength, (LPTSTR)Buffer.data());

   return MIL_STRING(Buffer.data());
}
// Restores each class icon and tiles them left to right in one 3-band, 8-bit
// unsigned image, framing each icon with a blue rectangle. The result is as
// wide as the sum of the icon widths and as tall as the tallest icon; pixels
// not covered by an icon stay at 0.
MIL_UNIQUE_BUF_ID CreateImageOfAllClasses(MIL_ID MilSystem, const MIL_STRING* FabricClassIcon, MIL_INT NumberOfClasses)
{
   std::vector<MIL_UNIQUE_BUF_ID> Icons;
   MIL_INT TotalWidth = 0;
   MIL_INT TallestHeight = MIL_INT_MIN;

   // First pass: load the icons and measure the mosaic.
   for(MIL_INT ClassIdx = 0; ClassIdx < NumberOfClasses; ++ClassIdx)
   {
      Icons.push_back(MbufRestore(FabricClassIcon[ClassIdx], MilSystem, M_UNIQUE_ID));
      const MIL_INT IconWidth  = MbufInquire(Icons.back(), M_SIZE_X, M_NULL);
      const MIL_INT IconHeight = MbufInquire(Icons.back(), M_SIZE_Y, M_NULL);
      TotalWidth += IconWidth;
      TallestHeight = std::max<MIL_INT>(IconHeight, TallestHeight);
   }

   MIL_UNIQUE_BUF_ID Mosaic = MbufAllocColor(MilSystem, 3, TotalWidth, TallestHeight, 8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(Mosaic, 0.0);

   MIL_UNIQUE_GRA_ID FrameGra = MgraAlloc(MilSystem, M_UNIQUE_ID);
   MgraColor(FrameGra, M_COLOR_BLUE);

   // Second pass: copy each icon at its horizontal offset and frame it.
   MIL_INT OffsetX = 0;
   for(std::size_t IconIdx = 0; IconIdx < Icons.size(); ++IconIdx)
   {
      const MIL_UNIQUE_BUF_ID& Icon = Icons[IconIdx];
      const MIL_INT IconWidth  = MbufInquire(Icon, M_SIZE_X, M_NULL);
      const MIL_INT IconHeight = MbufInquire(Icon, M_SIZE_Y, M_NULL);
      MbufCopyColor2d(Icon, Mosaic, M_ALL_BANDS, 0, 0, M_ALL_BANDS, OffsetX, 0, IconWidth, IconHeight);
      MgraRect(FrameGra, Mosaic, OffsetX, 0, OffsetX + IconWidth - 1, IconHeight - 1);
      OffsetX += IconWidth;
   }

   return Mosaic;
}
// Returns the indices 0 .. NbEntries-1 in a pseudo-random order. The order is
// fully determined by Seed (std::mt19937 + std::shuffle), so equal seeds give
// equal permutations with the same standard library.
const std::vector<MIL_INT> CreateShuffledIndex(MIL_INT NbEntries, unsigned int Seed)
{
   std::vector<MIL_INT> Permutation(static_cast<std::size_t>(NbEntries));
   for(std::size_t Pos = 0; Pos < Permutation.size(); ++Pos)
      Permutation[Pos] = static_cast<MIL_INT>(Pos);

   std::mt19937 Engine(Seed);
   std::shuffle(Permutation.begin(), Permutation.end(), Engine);
   return Permutation;
}
void DeleteFiles(const std::vector<MIL_STRING>& Files)
{
for(const auto& FileName : Files)
{
MappFileOperation(M_DEFAULT, FileName, M_NULL, M_NULL, M_FILE_DELETE, M_DEFAULT, M_NULL);
}
}
// Appends to FilesInFolder the path (FolderName + file name) of every
// non-directory entry found directly in FolderName; sub-folders are not
// recursed into. FolderName is expected to end with a path separator because
// names are concatenated as-is. If the folder cannot be enumerated, an error
// is printed and FilesInFolder is left unchanged.
void ListFilesInFolder(const MIL_STRING& FolderName, std::vector<MIL_STRING>& FilesInFolder)
{
   const MIL_STRING FileToSearch = FolderName + MIL_TEXT("*.*");
   WIN32_FIND_DATA FindFileData;
   HANDLE hFind = FindFirstFile(FileToSearch.c_str(), &FindFileData);
   if(hFind == INVALID_HANDLE_VALUE)
   {
      // GetLastError() returns a DWORD (unsigned long): print it with %lu, not %d.
      MosPrintf(MIL_TEXT("FindFirstFile failed (%lu)\n"), static_cast<unsigned long>(GetLastError()));
      return;
   }
   do
   {
      if(!(FindFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY))
      {
         FilesInFolder.push_back(FolderName + FindFileData.cFileName);
      }
   } while(FindNextFile(hFind, &FindFileData) != 0);
   FindClose(hFind);
}
// Adds the NumberOfClasses class definitions to Dataset. Each class gets its
// name, then the icon restored from FabricsClassIcon[i] is attached to class
// index i.
// NOTE(review): this assumes Dataset has no class yet, so the i-th added class
// gets index i — confirm before reusing it on a populated dataset.
void AddClassDescription(MIL_ID MilSystem,
                         MIL_ID Dataset,
                         const MIL_STRING* FabricsClassName,
                         const MIL_STRING* FabricsClassIcon,
                         MIL_INT NumberOfClasses)
{
   for(MIL_INT ClassIdx = 0; ClassIdx < NumberOfClasses; ++ClassIdx)
   {
      const MIL_STRING& ClassName = FabricsClassName[ClassIdx];
      MclassControl(Dataset, M_DEFAULT, M_CLASS_ADD, ClassName);

      // The restored icon only lives for this iteration.
      MIL_UNIQUE_BUF_ID ClassIcon = MbufRestore(FabricsClassIcon[ClassIdx], MilSystem, M_UNIQUE_ID);
      MclassControl(Dataset, M_CLASS_INDEX(ClassIdx), M_CLASS_ICON_ID, ClassIcon);
   }
}
// Builds the train, dev and test datasets from the original fabric images:
//  1. resets ExampleDataPath and copies the original images into it;
//  2. creates a dataset holding every copied image with its ground-truth class;
//  3. splits it (fixed seed) into 70% train, 20% dev and 10% test;
//  4. augments the train dataset with NB_AUGMENTATION_PER_IMAGE images per entry;
//  5. center-crops every image of the three datasets to TRAIN_IMAGE_SIZE.
// TrainDataset, DevDataset and TestDataset are allocated by the caller.
// Prints a summary and waits for a key press before returning.
void PrepareTheDatasets(MIL_ID MilSystem,
                        const MIL_STRING* FabricsClassName,
                        const MIL_STRING* FabricsClassIcon,
                        MIL_INT NumberOfClasses,
                        const MIL_STRING& OriginalDataPath,
                        const MIL_STRING& ExampleDataPath,
                        MIL_ID TrainDataset,
                        MIL_ID DevDataset,
                        MIL_ID TestDataset)
{
   PrepareExampleDataFolder(ExampleDataPath, FabricsClassName, NumberOfClasses);
   CopyOriginalDataToExampleDataFolder(FabricsClassName, NumberOfClasses, OriginalDataPath, ExampleDataPath);

   // Both intermediate datasets use the same root path; query it once.
   const MIL_STRING ExampleCurrentDirectory = GetExampleCurrentDirectory();

   MosPrintf(MIL_TEXT("\nCreating the dataset containing all the data...\n"));
   MIL_UNIQUE_CLASS_ID FullDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassControl(FullDataset, M_CONTEXT, M_ROOT_PATH, ExampleCurrentDirectory);
   AddClassDescription(MilSystem, FullDataset, FabricsClassName, FabricsClassIcon, NumberOfClasses);
   for(MIL_INT i = 0; i < NumberOfClasses; i++)
   {
      AddClassToDataset(i, ExampleDataPath, FabricsClassName[i], FullDataset);
   }

   MosPrintf(MIL_TEXT("\nSplitting the dataset to train/dev/test datasets...\n"));
   MIL_UNIQUE_CLASS_ID WorkingDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassControl(WorkingDataset, M_CONTEXT, M_ROOT_PATH, ExampleCurrentDirectory);
   const MIL_DOUBLE PERCENTAGE_IN_TRAIN_DATASET = 70.0;
   const MIL_DOUBLE PERCENTAGE_IN_DEV_DATASET = 20.0;
   const MIL_DOUBLE PERCENTAGE_IN_TEST_DATASET = 10.0;
   // First split off the test images, then split the remainder into train/dev.
   // The second percentage is relative to the remainder, so it is rescaled.
   MclassSplitDataset(M_SPLIT_CONTEXT_FIXED_SEED, FullDataset, WorkingDataset, TestDataset,
                      100.0 - PERCENTAGE_IN_TEST_DATASET, M_NULL, M_DEFAULT);
   const MIL_DOUBLE PERCENTAGE_TO_USE = PERCENTAGE_IN_TRAIN_DATASET / (100.0 - PERCENTAGE_IN_TEST_DATASET) * 100.0;
   MclassSplitDataset(M_SPLIT_CONTEXT_FIXED_SEED, WorkingDataset, TrainDataset, DevDataset,
                      PERCENTAGE_TO_USE, M_NULL, M_DEFAULT);

   MosPrintf(MIL_TEXT("\nAugmenting the train dataset...\n"));
   AugmentDataset(MilSystem, TrainDataset, NB_AUGMENTATION_PER_IMAGE);

   MosPrintf(MIL_TEXT("\nCropping images from the train/dev/test datasets...\n"));
   MosPrintf(MIL_TEXT("Cropping images from the train dataset...\n"));
   CropDatasetImages(MilSystem, TrainDataset, TRAIN_IMAGE_SIZE);
   MosPrintf(MIL_TEXT("Cropping images from the dev dataset...\n"));
   CropDatasetImages(MilSystem, DevDataset, TRAIN_IMAGE_SIZE);
   MosPrintf(MIL_TEXT("Cropping images from the test dataset...\n"));
   CropDatasetImages(MilSystem, TestDataset , TRAIN_IMAGE_SIZE);

   MosPrintf(MIL_TEXT("\nData preparation was successful.\n"));
   MosPrintf(MIL_TEXT("\nA train dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_TRAIN_DATASET);
   // %d consumes an int; MIL_INT may be 64 bits wide, so narrow explicitly.
   MosPrintf(MIL_TEXT("Augmented data (%d augmented images for each image in the dataset)\nwas also added to that set\n"), static_cast<int>(NB_AUGMENTATION_PER_IMAGE));
   MosPrintf(MIL_TEXT("to ensure enough data for training and to be tolerant to small changes\nin the later acquisition setup.\n"));
   MosPrintf(MIL_TEXT("\nA dev dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_DEV_DATASET);
   MosPrintf(MIL_TEXT("\nA test dataset was created using %.lf%% of the original images.\n"), PERCENTAGE_IN_TEST_DATASET);
   MosPrintf(MIL_TEXT("\nOriginal images are a little larger than those in the final application.\n"));
   MosPrintf(MIL_TEXT("This ensures augmented images do not contain overscan pixels.\n"));
   MosPrintf(MIL_TEXT("In the final datasets, the images were cropped to meet the final \n"));
   MosPrintf(MIL_TEXT("application's image size requirement.\n"));
   MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
   MosGetch();
}
void CopyOriginalDataToExampleDataFolder(const MIL_STRING* FabricsClassName,
MIL_INT NumberOfClasses,
const MIL_STRING& OriginalDataPath,
const MIL_STRING& ExampleDataPath)
{
MosPrintf(MIL_TEXT("\nCopying original images\nfrom %s folder\nto %s folder...\n"), OriginalDataPath.c_str(), ExampleDataPath.c_str());
for(MIL_INT FabricIndex=0; FabricIndex<NumberOfClasses; FabricIndex++)
{
MIL_INT NbImages = FABRICS_CLASS_NB_IMAGES[FabricIndex];
for(MIL_INT i = 1; i <= NbImages; i++)
{
MIL_TEXT_CHAR OriginalFileName[512];
MosSprintf(OriginalFileName, 512, MIL_TEXT("%s%s/%04d.mim"), OriginalDataPath.c_str(), FabricsClassName[FabricIndex].c_str(), i);
MIL_TEXT_CHAR DestFileName[512];
MosSprintf(DestFileName, 512, MIL_TEXT("%s%s/%04d.mim"), ExampleDataPath.c_str(), FabricsClassName[FabricIndex].c_str(), i);
MappFileOperation(M_DEFAULT, OriginalFileName, M_DEFAULT, DestFileName, M_FILE_COPY, M_DEFAULT, M_NULL);
}
}
}
void DeleteFilesInFolder(const MIL_STRING& FolderName)
{
std::vector<MIL_STRING> FilesInFolder;
ListFilesInFolder(FolderName, FilesInFolder);
DeleteFiles(FilesInFolder);
}
void PrepareExampleDataFolder(const MIL_STRING& ExampleDataPath, const MIL_STRING* FabricsClassName, MIL_INT NumberOfClasses)
{
MIL_INT FileExists;
MappFileOperation(M_DEFAULT, ExampleDataPath, M_NULL, M_NULL, M_FILE_EXISTS, M_DEFAULT, &FileExists);
if(FileExists != M_YES)
{
MosPrintf(MIL_TEXT("\nCreating the %s folder and a sub folder for each class...\n"), ExampleDataPath.c_str());
MappFileOperation(M_DEFAULT, ExampleDataPath, M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
for(MIL_INT i = 0; i < NumberOfClasses; i++)
{
MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
}
}
else
{
MosPrintf(MIL_TEXT("\nDeleting files in the %s folder to ensure example repeatability...\n"), ExampleDataPath.c_str());
for(MIL_INT i = 0; i < NumberOfClasses; i++)
{
MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_EXISTS, M_DEFAULT, &FileExists);
if(FileExists)
DeleteFilesInFolder(ExampleDataPath + FabricsClassName[i] + MIL_TEXT("/"));
else
MappFileOperation(M_DEFAULT, ExampleDataPath + FabricsClassName[i], M_NULL, M_NULL, M_FILE_MAKE_DIR, M_DEFAULT, M_NULL);
}
}
}
// Appends one entry to Dataset for each file found in
// <DataToTrainPath><FabricName>/. Each new entry gets ClassIndex as the
// ground-truth class of region 0 and the file as its path. The new entries are
// placed after any entries Dataset already holds.
void AddClassToDataset(MIL_INT ClassIndex, const MIL_STRING& DataToTrainPath, const MIL_STRING& FabricName, MIL_ID Dataset)
{
   MIL_INT FirstNewEntry;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &FirstNewEntry);

   std::vector<MIL_STRING> ClassFiles;
   ListFilesInFolder(DataToTrainPath + FabricName + MIL_TEXT("/"), ClassFiles);

   for(std::size_t FileIdx = 0; FileIdx < ClassFiles.size(); ++FileIdx)
   {
      const MIL_INT EntryIndex = FirstNewEntry + static_cast<MIL_INT>(FileIdx);
      const MIL_STRING& EntryFile = ClassFiles[FileIdx];
      MclassControl(Dataset, M_DEFAULT, M_ENTRY_ADD, M_DEFAULT);
      MclassControlEntry(Dataset, EntryIndex, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH, ClassIndex, M_NULL, M_DEFAULT);
      MclassControlEntry(Dataset, EntryIndex, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, M_DEFAULT, EntryFile, M_DEFAULT);
   }
}
// Appends NbAugmentPerImage augmented copies of every entry currently in
// Dataset. Each copy is produced by MimAugment:
//  - random X/Y translation of up to 2;
//  - scale factor in [0.97, 1.03];
//  - rotation with angle delta 5.0;
//  - fixed RNG seed 42.
// Each copy is saved next to its source file as <name>_Aug_<k><extension>. It is
// then added as a new entry carrying the source's ground-truth class, with
// M_AUGMENTATION_SOURCE set to the source entry index.
void AugmentDataset(MIL_ID System, MIL_ID Dataset, MIL_INT NbAugmentPerImage)
{
   auto AugmentContext = MimAlloc(System, M_AUGMENTATION_CONTEXT, M_DEFAULT, M_UNIQUE_ID);

   // Fixed seed so the same augmented images are generated on every run.
   MimControl(AugmentContext, M_AUG_SEED_MODE, M_RNG_INIT_VALUE);
   MimControl(AugmentContext, M_AUG_RNG_INIT_VALUE, 42);
   MimControl(AugmentContext, M_AUG_TRANSLATION_X_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_TRANSLATION_X_OP_MAX, 2);
   MimControl(AugmentContext, M_AUG_TRANSLATION_Y_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_TRANSLATION_Y_OP_MAX, 2);
   MimControl(AugmentContext, M_AUG_SCALE_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_SCALE_OP_FACTOR_MIN, 0.97);
   MimControl(AugmentContext, M_AUG_SCALE_OP_FACTOR_MAX, 1.03);
   MimControl(AugmentContext, M_AUG_ROTATION_OP, M_ENABLE);
   MimControl(AugmentContext, M_AUG_ROTATION_OP_ANGLE_DELTA, 5.0);

   // The entry count is read once, so the augmented entries appended below are
   // not augmented again.
   MIL_INT NbEntries = 0;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NbEntries);
   MIL_INT PosInAugmentDataset = NbEntries;
   for(MIL_INT i = 0; i < NbEntries; i++)
   {
      // %d consumes an int; MIL_INT may be 64 bits wide, so narrow explicitly.
      MosPrintf(MIL_TEXT("%d of %d completed\r"), static_cast<int>(i + 1), static_cast<int>(NbEntries));

      MIL_STRING FilePath;
      MclassInquireEntry(Dataset, i, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
      MIL_INT GroundTruthIndex = 0;
      MclassInquireEntry(Dataset, i, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH + M_TYPE_MIL_INT, &GroundTruthIndex);

      MIL_UNIQUE_BUF_ID OriginalImage = MbufRestore(FilePath, System, M_UNIQUE_ID);
      MIL_UNIQUE_BUF_ID AugmentedImage = MbufClone(OriginalImage, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);

      // Where the "_Aug_<k>" suffix goes: before the file-name extension.
      // Append it instead when the file name has no extension; a dot belonging
      // to a folder name must not be mistaken for one. This also avoids
      // insert(npos) throwing std::out_of_range.
      const std::size_t SeparatorPos = FilePath.find_last_of(MIL_TEXT("/\\"));
      std::size_t SuffixPos = FilePath.rfind(MIL_TEXT("."));
      if(SuffixPos == MIL_STRING::npos || (SeparatorPos != MIL_STRING::npos && SuffixPos < SeparatorPos))
         SuffixPos = FilePath.size();

      for(MIL_INT AugIndex = 0; AugIndex < NbAugmentPerImage; AugIndex++)
      {
         MbufClear(AugmentedImage, 0.0);
         MimAugment(AugmentContext, OriginalImage, AugmentedImage, M_DEFAULT, M_DEFAULT);

         MIL_TEXT_CHAR Suffix[128];
         MosSprintf(Suffix, 128, MIL_TEXT("_Aug_%d"), static_cast<int>(AugIndex));
         MIL_STRING AugFileName = FilePath;
         AugFileName.insert(SuffixPos, Suffix);
         MbufSave(AugFileName, AugmentedImage);

         MclassControl(Dataset, M_DEFAULT, M_ENTRY_ADD, M_DEFAULT);
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH, GroundTruthIndex, M_NULL, M_DEFAULT);
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, M_DEFAULT, AugFileName, M_DEFAULT);
         MclassControlEntry(Dataset, PosInAugmentDataset, M_DEFAULT_KEY, M_DEFAULT, M_AUGMENTATION_SOURCE, i, M_NULL, M_DEFAULT);
         PosInAugmentDataset++;
      }
   }
   MosPrintf(MIL_TEXT("\n"));
}
// Center-crops every image referenced by Dataset to FinalImageSize x
// FinalImageSize and overwrites the image file in place. Images are assumed to
// be at least FinalImageSize in both dimensions; otherwise the crop offsets
// would be negative.
void CropDatasetImages(MIL_ID MilSystem, MIL_ID Dataset, MIL_INT FinalImageSize)
{
   MIL_INT NbEntries = 0;
   MclassInquire(Dataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NbEntries);
   for(MIL_INT i = 0; i < NbEntries; i++)
   {
      // %d consumes an int; MIL_INT may be 64 bits wide, so narrow explicitly.
      MosPrintf(MIL_TEXT("%d of %d completed\r"), static_cast<int>(i + 1), static_cast<int>(NbEntries));

      MIL_STRING FilePath;
      MclassInquireEntry(Dataset, i, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
      MIL_UNIQUE_BUF_ID OriginalImage = MbufRestore(FilePath, MilSystem, M_UNIQUE_ID);

      const MIL_INT ImageSizeX = MbufInquire(OriginalImage, M_SIZE_X, M_NULL);
      const MIL_INT ImageSizeY = MbufInquire(OriginalImage, M_SIZE_Y, M_NULL);
      const MIL_INT OffsetX = (ImageSizeX - FinalImageSize) / 2;
      const MIL_INT OffsetY = (ImageSizeY - FinalImageSize) / 2;

      MIL_UNIQUE_BUF_ID CroppedImage = MbufClone(OriginalImage, M_DEFAULT, FinalImageSize, FinalImageSize, M_DEFAULT, M_DEFAULT, M_DEFAULT, M_UNIQUE_ID);
      MbufCopyColor2d(OriginalImage, CroppedImage, M_ALL_BANDS, OffsetX, OffsetY, M_ALL_BANDS, 0, 0, FinalImageSize, FinalImageSize);
      MbufSave(FilePath, CroppedImage);
   }
   MosPrintf(MIL_TEXT("\n"));
}
// M_EPOCH_TRAINED hook. Collects the epoch's train/dev error rates, whether the
// trained CNN parameters were updated at this epoch (reported to the dashboard
// as "best epoch so far"), and the mean elapsed time per epoch, then forwards
// them to the dashboard. UserData must point to a HookEpochData.
MIL_INT MFTYPE HookFuncEpoch(MIL_INT /*HookType*/, MIL_ID EventId, void* UserData)
{
   auto* HookData = static_cast<HookEpochData*>(UserData);

   MIL_INT CurEpochIndex = 0;
   MclassGetHookInfo(EventId, M_EPOCH_INDEX + M_TYPE_MIL_INT, &CurEpochIndex);

   // HookFuncMiniBatch resets the timer at the first mini-batch of epoch 0, so
   // the elapsed time covers CurEpochIndex + 1 epochs.
   MIL_DOUBLE CurBench = 0.0;
   MappTimer(M_DEFAULT, M_TIMER_READ, &CurBench);
   const MIL_DOUBLE EpochBenchMean = CurBench / (CurEpochIndex + 1);

   MIL_DOUBLE TrainErrorRate = 0;
   MclassGetHookInfo(EventId, M_TRAIN_DATASET_ERROR_RATE, &TrainErrorRate);
   MIL_DOUBLE DevErrorRate = 0;
   MclassGetHookInfo(EventId, M_DEV_DATASET_ERROR_RATE, &DevErrorRate);

   MIL_INT AreTrainedCNNParametersUpdated = M_FALSE;
   MclassGetHookInfo(EventId,
                     M_TRAINED_CNN_PARAMETERS_UPDATED + M_TYPE_MIL_INT,
                     &AreTrainedCNNParametersUpdated);
   const bool TheEpochIsTheBestUpToNow = (AreTrainedCNNParametersUpdated == M_TRUE);

   HookData->TheDashboard->AddEpochData(TrainErrorRate, DevErrorRate,
                                        CurEpochIndex, TheEpochIsTheBestUpToNow, EpochBenchMean);
   return M_NULL;
}
// M_MINI_BATCH_TRAINED hook. Forwards the mini-batch loss and position to the
// dashboard and resets the timer at the very first mini-batch (HookFuncEpoch
// reads that timer). It also handles the interactive pause: pressing 'p' shows
// a prompt, after which the next key either stops the training ('s') or resumes
// it (any other key). UserData must point to a HookMiniBatchData.
MIL_INT MFTYPE HookFuncMiniBatch(MIL_INT /*HookType*/, MIL_ID EventId, void* UserData)
{
   auto* HookData = static_cast<HookMiniBatchData*>(UserData);

   MIL_DOUBLE LossError = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_LOSS, &LossError);
   MIL_INT MiniBatchIdx = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_INDEX + M_TYPE_MIL_INT, &MiniBatchIdx);
   MIL_INT EpochIdx = 0;
   MclassGetHookInfo(EventId, M_EPOCH_INDEX + M_TYPE_MIL_INT, &EpochIdx);
   MIL_INT NbMiniBatchPerEpoch = 0;
   MclassGetHookInfo(EventId, M_MINI_BATCH_PER_EPOCH + M_TYPE_MIL_INT, &NbMiniBatchPerEpoch);

   if(EpochIdx == 0 && MiniBatchIdx == 0)
   {
      MappTimer(M_DEFAULT, M_TIMER_RESET, M_NULL);
   }

   HookData->TheDashboard->AddMiniBatchData(LossError, MiniBatchIdx, EpochIdx, NbMiniBatchPerEpoch);

   // Only consume a key when one is pending, so training never blocks here
   // unless the user asked for a pause.
   if(MosKbhit() != 0 && static_cast<char>(MosGetch()) == 'p')
   {
      MosPrintf(MIL_TEXT("\nPress 's' to stop the training or any other key to continue.\n"));

      // Always wait for and read the answer, even if a key is already buffered,
      // so a quickly typed 's' is not silently dropped.
      const char Answer = static_cast<char>(MosGetch());
      if(Answer == 's')
      {
         MIL_ID HookInfoTrainResId = M_NULL;
         MclassGetHookInfo(EventId, M_RESULT_ID + M_TYPE_MIL_ID, &HookInfoTrainResId);
         MclassControl(HookInfoTrainResId, M_DEFAULT, M_STOP_TRAIN, M_DEFAULT);
         MosPrintf(MIL_TEXT("The training has been stopped.\n"));
      }
      else
      {
         MosPrintf(MIL_TEXT("The training will continue.\n"));
      }
   }
   return M_NULL;
}
// Trains the predefined M_FCNET_M classifier on TrainDataset, using DevDataset
// for validation: at most 10 epochs, mini-batch size 64, initial learning rate
// 0.001. A train-evolution dashboard is selected on MilDisplay and refreshed by
// the epoch/mini-batch hooks. When the training status is M_COMPLETE, this
// function:
//  - exports "TrainReport.csv";
//  - prints the best epoch's error rates;
//  - returns the trained classifier context.
// Otherwise it returns an empty context.
MIL_UNIQUE_CLASS_ID TrainTheModel(MIL_ID MilSystem, MIL_ID TrainDataset, MIL_ID DevDataset, MIL_ID MilDisplay)
{
   MIL_INT TrainDatasetNbImages = MclassInquire(TrainDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, M_NULL);
   MIL_INT DevDatasetNbImages = MclassInquire(DevDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, M_NULL);

   MIL_UNIQUE_CLASS_ID TrainCtx = MclassAlloc(MilSystem, M_TRAIN_CNN, M_DEFAULT, M_UNIQUE_ID);
   MIL_UNIQUE_CLASS_ID TrainRes = MclassAllocResult(MilSystem, M_TRAIN_CNN_RESULT, M_DEFAULT, M_UNIQUE_ID);

   const MIL_INT MAX_NUMBER_OF_EPOCH = 10;
   const MIL_INT MINI_BATCH_SIZE = 64;
   const MIL_DOUBLE LEARNING_RATE = 0.001;
   MclassControl(TrainCtx, M_DEFAULT, M_MAX_EPOCH            , MAX_NUMBER_OF_EPOCH);
   MclassControl(TrainCtx, M_DEFAULT, M_MINI_BATCH_SIZE      , MINI_BATCH_SIZE);
   MclassControl(TrainCtx, M_DEFAULT, M_INITIAL_LEARNING_RATE, LEARNING_RATE);
   MclassPreprocess(TrainCtx, M_DEFAULT);

   // Report which train engine will run and warn about degraded GPU setups.
   MIL_INT TrainEngineUsed;
   MclassInquire(TrainCtx, M_CONTEXT, M_CNN_TRAIN_ENGINE_USED + M_TYPE_MIL_INT, &TrainEngineUsed);
   MIL_INT GpuTrainEngineStatus;
   MclassInquire(TrainCtx, M_CONTEXT, M_GPU_TRAIN_ENGINE_LOAD_STATUS + M_TYPE_MIL_INT, &GpuTrainEngineStatus);
   if(TrainEngineUsed == M_GPU && GpuTrainEngineStatus == M_JIT_COMPILATION_REQUIRED)
   {
      MosPrintf(MIL_TEXT("\nWarning :: The training might not be optimal for the current system.\n"));
      MosPrintf(MIL_TEXT("Use the CNN Train Engine Test under Classification in MILConfig for more information.\n"));
      MosPrintf(MIL_TEXT("It may take some time before displaying the first results...\n"));
   }
   else if(GpuTrainEngineStatus != M_SUCCESS)
   {
      MosPrintf(MIL_TEXT("\nWarning :: The training is being done on the CPU.\n"));
      MosPrintf(MIL_TEXT("If a training on GPU was expected, use the CNN Train Engine Test under Classification in MILConfig for more information.\n"));
   }

   MIL_STRING TrainEngineDescription;
   MclassInquire(TrainCtx, M_CONTEXT, M_CNN_TRAIN_ENGINE_USED_DESCRIPTION, TrainEngineDescription);

   CTrainEvolutionDashboard TheTrainEvolutionDashboard(MilSystem, MAX_NUMBER_OF_EPOCH,
                                                       MINI_BATCH_SIZE, LEARNING_RATE,
                                                       TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE,
                                                       TrainDatasetNbImages, DevDatasetNbImages,
                                                       TrainEngineUsed, TrainEngineDescription);
   MdispSelect(MilDisplay, TheTrainEvolutionDashboard.GetDashboardBufId());

   // The hook data are locals: they must (and do) outlive MclassTrain() below.
   HookEpochData TheHookEpochData;
   TheHookEpochData.TheDashboard = &TheTrainEvolutionDashboard;
   MclassHookFunction(TrainCtx, M_EPOCH_TRAINED, HookFuncEpoch, &TheHookEpochData);
   HookMiniBatchData TheHookMiniBatchData;
   TheHookMiniBatchData.TheDashboard = &TheTrainEvolutionDashboard;
   MclassHookFunction(TrainCtx, M_MINI_BATCH_TRAINED, HookFuncMiniBatch, &TheHookMiniBatchData);

   MIL_UNIQUE_CLASS_ID PretrainedCtx = MclassAlloc(MilSystem, M_CLASSIFIER_CNN_PREDEFINED, M_FCNET_M, M_UNIQUE_ID);
   MclassTrain(TrainCtx, PretrainedCtx, TrainDataset, DevDataset, TrainRes, M_DEFAULT);

   MIL_UNIQUE_CLASS_ID TrainedCtx;
   MIL_INT Status = -1;
   MclassGetResult(TrainRes, M_DEFAULT, M_STATUS + M_TYPE_MIL_INT, &Status);
   if(Status == M_COMPLETE)
   {
      MosPrintf(MIL_TEXT("\nTraining was successful.\n"));

      MIL_INT NbErrorImage = -1;
      MclassGetResult(TrainRes, M_DEFAULT, M_TRAIN_DATASET_ERROR_ENTRIES + M_NB_ELEMENTS + M_TYPE_MIL_INT, &NbErrorImage);
      if(NbErrorImage != 0)
      {
         // %d consumes an int; MIL_INT may be 64 bits wide, so narrow explicitly.
         MosPrintf(MIL_TEXT("Warning :: few images (%d) were missing at some part of the training.\n"), static_cast<int>(NbErrorImage));
      }

      TrainedCtx = MclassAlloc(MilSystem, M_CLASSIFIER_CNN_PREDEFINED, M_DEFAULT, M_UNIQUE_ID);
      MclassCopyResult(TrainRes, M_DEFAULT, TrainedCtx, M_DEFAULT, M_TRAINED_CLASSIFIER_CNN, M_DEFAULT);

      MosPrintf(MIL_TEXT("A training report was saved: \"TrainReport.csv\".\n"));
      MclassExport(MIL_TEXT("TrainReport.csv"), M_FORMAT_TXT, TrainRes, M_DEFAULT, M_TRAIN_REPORT, M_DEFAULT);

      MIL_DOUBLE TrainErrorRate = 0;
      MclassGetResult(TrainRes, M_DEFAULT, M_TRAIN_DATASET_ERROR_RATE, &TrainErrorRate);
      MIL_DOUBLE DevErrorRate = 0;
      MclassGetResult(TrainRes, M_DEFAULT, M_DEV_DATASET_ERROR_RATE, &DevErrorRate);
      MIL_INT LastUpdatedEpochIndex = 0;
      MclassGetResult(TrainRes, M_DEFAULT, M_LAST_EPOCH_UPDATED_PARAMETERS + M_TYPE_MIL_INT, &LastUpdatedEpochIndex);

      MosPrintf(MIL_TEXT("\nThe best epoch was epoch %d with an error rate on the dev dataset of %.8lf.\n"), static_cast<int>(LastUpdatedEpochIndex), DevErrorRate);
      MosPrintf(MIL_TEXT("The associated train error rate is %.8lf.\n"), TrainErrorRate);
      MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
      MosGetch();
   }
   return TrainedCtx;
}
// Runs the trained CNN context on every entry of TestDataset, shows the first
// (up to 10) successful predictions in the display, prints the accuracy and the
// average predicted score on the test dataset, and saves the trained context to
// "FabricsNet_Gray.mclass".
//
// MilSystem   : system used for allocations.
// MilDisplay  : display used to show the predictions.
// TrainedCtx  : trained CNN classification context (preprocessed here).
// TestDataset : images dataset with ground-truth class indices.
void PredictUsingTrainedContext(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TrainedCtx, MIL_ID TestDataset)
{
   CPredictResultDisplay ThePredictResultDisplay(MilSystem, MilDisplay, TestDataset);

   // Preprocess the context, then predict the whole test dataset in one call.
   MclassPreprocess(TrainedCtx, M_DEFAULT);
   MIL_UNIQUE_CLASS_ID PredictedDataset = MclassAlloc(MilSystem, M_DATASET_IMAGES, M_DEFAULT, M_UNIQUE_ID);
   MclassPredict(TrainedCtx, TestDataset, PredictedDataset, M_DEFAULT);

   MIL_INT NbEntries = 0;
   MIL_INT NbEntriesPredicted = 0;
   MIL_DOUBLE PredAvg = 0;
   MclassInquire(PredictedDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES + M_TYPE_MIL_INT, &NbEntries);
   MclassInquire(PredictedDataset, M_DEFAULT, M_NUMBER_OF_ENTRIES_PREDICTED + M_TYPE_MIL_INT, &NbEntriesPredicted);
   MclassInquire(PredictedDataset, M_DEFAULT, M_PREDICTED_SCORE_AVERAGE, &PredAvg);

   // Visit the entries in a fixed-seed shuffled order. The index vector is held
   // by value because CreateShuffledIndex returns a temporary.
   const unsigned int ShuffledIndexSeed = 49;
   const std::vector<MIL_INT> ShuffledIndex = CreateShuffledIndex(NbEntries, ShuffledIndexSeed);
   const MIL_INT NbPredictionToShow = std::min<MIL_INT>(10, NbEntriesPredicted);

   MosPrintf(MIL_TEXT("\nPredictions will be performed on the test dataset as a final check\nof the trained CNN model.\n"));
   MosPrintf(MIL_TEXT("The test dataset contains %d images.\n"), NbEntries);
   MosPrintf(MIL_TEXT("The prediction results will be shown for the first %d images.\n"), NbPredictionToShow);

   MIL_INT NbGoodPredictions = 0;
   for(MIL_INT i = 0; i < NbEntries; i++)
   {
      const MIL_INT EntryIdx = ShuffledIndex[i];

      // Exactly one predicted class index means the entry was predicted successfully.
      MIL_INT EntryPredicted = 0;
      MclassInquireEntry(PredictedDataset, EntryIdx, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_PREDICTED + M_NB_ELEMENTS + M_TYPE_MIL_INT, &EntryPredicted);
      if(EntryPredicted == 1)
      {
         MIL_INT GroundTruthIndex = 0;
         MclassInquireEntry(TestDataset, EntryIdx, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_GROUND_TRUTH + M_TYPE_MIL_INT, &GroundTruthIndex);
         MIL_INT PredIndex = 0;
         MclassInquireEntry(PredictedDataset, EntryIdx, M_DEFAULT_KEY, M_REGION_INDEX(0), M_CLASS_INDEX_PREDICTED + M_TYPE_MIL_INT, &PredIndex);
         std::vector<MIL_DOUBLE> PredScores;
         MclassInquireEntry(PredictedDataset, EntryIdx, M_DEFAULT_KEY, M_REGION_INDEX(0), M_PREDICTED_CLASS_SCORES, PredScores);

         if(PredIndex == GroundTruthIndex)
            NbGoodPredictions++;

         if(i < NbPredictionToShow)
         {
            MIL_STRING FilePath;
            MclassInquireEntry(TestDataset, EntryIdx, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
            MIL_UNIQUE_BUF_ID ImageToPredict = MbufRestore(FilePath, MilSystem, M_UNIQUE_ID);
            ThePredictResultDisplay.Update(ImageToPredict, PredIndex, PredScores[PredIndex]);
            MosPrintf(MIL_TEXT("The predicted index is %d and the predicted score is %.2lf%s (Ground truth=%d)\n"), PredIndex, PredScores[PredIndex], MIL_TEXT("%"), GroundTruthIndex);
            MosPrintf(MIL_TEXT("Press <enter> to continue...\n"));
            MosGetch();
         }
      }
      else
      {
         MIL_STRING FilePath;
         MclassInquireEntry(TestDataset, EntryIdx, M_DEFAULT_KEY, M_DEFAULT, M_FILE_PATH, FilePath);
         MosPrintf(MIL_TEXT("The image \"%s\" failed to be predicted.\n"), FilePath.c_str());
      }
   }

   MclassSave(MIL_TEXT("FabricsNet_Gray.mclass"), TrainedCtx, M_DEFAULT);

   // Guard the ratio: when no entry could be predicted, the accuracy is undefined.
   if(NbEntriesPredicted > 0)
      MosPrintf(MIL_TEXT("The accuracy on the test dataset using the trained context is %.2lf%s.\n"), ((MIL_DOUBLE)NbGoodPredictions / (MIL_DOUBLE)NbEntriesPredicted)*100.0, MIL_TEXT("%"));
   else
      MosPrintf(MIL_TEXT("No test image was predicted; the accuracy on the test dataset is not available.\n"));
   MosPrintf(MIL_TEXT("The average predicted score on the test dataset using the trained\ncontext is %.2lf%s.\n"), PredAvg, MIL_TEXT("%"));
   MosPrintf(MIL_TEXT("The trained context was saved: \"FabricsNet_Gray.mclass\".\n"));
   MosPrintf(MIL_TEXT("Press <enter> to end...\n"));
   MosGetch();
}
// Builds the training dashboard: allocates one color buffer split into child
// regions, draws the frames and empty graphs, and writes the general training
// information (engine, image size, dataset sizes, epochs, minibatch, learning rate).
CTrainEvolutionDashboard::CTrainEvolutionDashboard(MIL_ID MilSystem, MIL_INT MaxEpoch, MIL_INT MinibatchSize,
                                                   MIL_DOUBLE LearningRate,
                                                   MIL_INT TrainImageSizeX, MIL_INT TrainImageSizeY,
                                                   MIL_INT TrainDatasetSize, MIL_INT DevDatasetSize,
                                                   MIL_INT TrainEngineUsed, MIL_STRING& TrainEngineDescription):
   m_DashboardBufId(M_NULL),
   m_TheGraContext(M_NULL),
   m_EpochInfoBufId(M_NULL),
   m_EpochGraphBufId(M_NULL),
   m_LossInfoBufId(M_NULL),
   m_LossGraphBufId(M_NULL),
   m_ProgressionInfoBufId(M_NULL),
   m_MaxEpoch(MaxEpoch),
   m_DashboardWidth(0),
   m_LastTrainPosX(0),
   m_LastTrainPosY(0),
   m_LastDevPosX(0),
   m_LastDevPosY(0),
   m_LastTrainMinibatchPosX(0),
   m_LastTrainMinibatchPosY(0),
   m_YPositionForLossText(0),
   m_EpochBenchMean(-1.0),
   GRAPH_SIZE_X(400),
   GRAPH_SIZE_Y(400),
   GRAPH_TOP_MARGIN(30),
   MARGIN(50),
   EPOCH_AND_MINIBATCH_REGION_HEIGHT(190),
   PROGRESSION_INFO_REGION_HEIGHT(100),
   LOSS_EXPONENT_MAX(0),
   LOSS_EXPONENT_MIN(-5),
   COLOR_GENERAL_INFO(M_RGB888(0, 176, 255)),
   COLOR_DEV_SET_INFO(M_COLOR_MAGENTA),
   COLOR_TRAIN_SET_INFO(M_COLOR_GREEN),
   COLOR_PROGRESS_BAR(M_COLOR_DARK_GREEN)
{
   // Layout, two equal-width columns:
   //   info row : epoch info (left)  | general + loss info (right)
   //   graph row: epoch graph (left) | loss graph (right)
   //   progression info spanning both columns at the bottom
   const MIL_INT ColumnWidth    = GRAPH_SIZE_X + 2 * MARGIN;
   const MIL_INT InfoRowHeight  = EPOCH_AND_MINIBATCH_REGION_HEIGHT;
   const MIL_INT GraphRowHeight = GRAPH_SIZE_Y + MARGIN + GRAPH_TOP_MARGIN;
   const MIL_INT GraphRowTop    = InfoRowHeight;
   const MIL_INT ProgressRowTop = GraphRowTop + GraphRowHeight;
   const MIL_INT TotalHeight    = ProgressRowTop + PROGRESSION_INFO_REGION_HEIGHT;
   m_DashboardWidth = 2 * ColumnWidth;

   // Parent buffer (cleared to black) and the drawing context.
   m_DashboardBufId = MbufAllocColor(MilSystem, 3, m_DashboardWidth, TotalHeight,
                                     8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(m_DashboardBufId, M_COLOR_BLACK);
   m_TheGraContext = MgraAlloc(MilSystem, M_UNIQUE_ID);

   // One child buffer per region so every region is drawn in local coordinates.
   m_EpochInfoBufId       = MbufChild2d(m_DashboardBufId, 0,           0,              ColumnWidth,      InfoRowHeight,                  M_UNIQUE_ID);
   m_LossInfoBufId        = MbufChild2d(m_DashboardBufId, ColumnWidth, 0,              ColumnWidth,      InfoRowHeight,                  M_UNIQUE_ID);
   m_EpochGraphBufId      = MbufChild2d(m_DashboardBufId, 0,           GraphRowTop,    ColumnWidth,      GraphRowHeight,                 M_UNIQUE_ID);
   m_LossGraphBufId       = MbufChild2d(m_DashboardBufId, ColumnWidth, GraphRowTop,    ColumnWidth,      GraphRowHeight,                 M_UNIQUE_ID);
   m_ProgressionInfoBufId = MbufChild2d(m_DashboardBufId, 0,           ProgressRowTop, m_DashboardWidth, PROGRESSION_INFO_REGION_HEIGHT, M_UNIQUE_ID);

   // Static content: frames, empty graphs, then the general training info.
   DrawSectionSeparators();
   InitializeEpochGraph();
   InitializeLossGraph();
   WriteGeneralTrainInfo(MinibatchSize, TrainImageSizeX, TrainImageSizeY, TrainDatasetSize,
                         DevDatasetSize, LearningRate, TrainEngineUsed, TrainEngineDescription);
}
// Releases the dashboard's MIL objects in an explicit order by assigning M_NULL
// to each member.
// NOTE(review): the members are assigned from M_UNIQUE_ID allocations in the
// constructor, so they are presumably MIL_UNIQUE_*_ID wrappers whose assignment
// of M_NULL frees the held object — confirm against the class declaration.
// The child buffers (created with MbufChild2d on m_DashboardBufId) are released
// before m_DashboardBufId, their parent; keep m_DashboardBufId last.
CTrainEvolutionDashboard::~CTrainEvolutionDashboard()
{
m_TheGraContext = M_NULL;
m_EpochInfoBufId = M_NULL;
m_LossInfoBufId = M_NULL;
m_EpochGraphBufId = M_NULL;
m_LossGraphBufId = M_NULL;
m_ProgressionInfoBufId = M_NULL;
// Parent buffer last: all children above must already be released.
m_DashboardBufId = M_NULL;
}
// Draws a filled frame, FrameThickness pixels wide, along the four inner edges
// of BufId using COLOR_GENERAL_INFO. Does nothing for a non-positive thickness.
void CTrainEvolutionDashboard::DrawBufferFrame(MIL_ID BufId, MIL_INT FrameThickness)
{
   // A non-positive thickness would produce inverted rectangles.
   if(FrameThickness <= 0)
      return;

   // Buffer dimensions are sizes, not identifiers: query them as MIL_INT.
   const MIL_INT SizeX = MbufInquire(BufId, M_SIZE_X, M_NULL);
   const MIL_INT SizeY = MbufInquire(BufId, M_SIZE_Y, M_NULL);

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   // Top, right, bottom and left bands.
   MgraRectFill(m_TheGraContext, BufId, 0, 0, SizeX - 1, FrameThickness - 1);
   MgraRectFill(m_TheGraContext, BufId, SizeX - FrameThickness, 0, SizeX - 1, SizeY - 1);
   MgraRectFill(m_TheGraContext, BufId, 0, SizeY - FrameThickness, SizeX - 1, SizeY - 1);
   MgraRectFill(m_TheGraContext, BufId, 0, 0, FrameThickness - 1, SizeY - 1);
}
void CTrainEvolutionDashboard::DrawSectionSeparators()
{
DrawBufferFrame(m_DashboardBufId, 4);
DrawBufferFrame(m_EpochInfoBufId, 2);
DrawBufferFrame(m_EpochGraphBufId, 2);
DrawBufferFrame(m_LossInfoBufId, 2);
DrawBufferFrame(m_LossGraphBufId, 2);
DrawBufferFrame(m_ProgressionInfoBufId, 2);
}
// Draws the empty error-rate graph: its frame, the Y-axis labels and ticks
// (100 % at the top down to 0 % at the bottom), and up to 10 epoch ticks on the
// X axis. Epoch ticks are skipped when m_MaxEpoch is not positive.
void CTrainEvolutionDashboard::InitializeEpochGraph()
{
   MgraColor(m_TheGraContext, M_COLOR_WHITE);
   MgraRect(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN, MARGIN + GRAPH_SIZE_X, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);

   // Y axis: right-aligned labels just left of the graph.
   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_RIGHT);
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN, MIL_TEXT("100"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)), MIL_TEXT("75"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)), MIL_TEXT("50"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)), MIL_TEXT("25"));
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y, MIL_TEXT("0"));
   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.25*GRAPH_SIZE_Y)));
   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.50*GRAPH_SIZE_Y)));
   MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)), MARGIN + 5, GRAPH_TOP_MARGIN + ((MIL_INT)(0.75*GRAPH_SIZE_Y)));

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   // X axis: no epoch to label, and m_MaxEpoch / NbTick below would divide by zero.
   if(m_MaxEpoch <= 0)
      return;

   // Epoch e (0-based) is plotted at abscissa (e + 1) / m_MaxEpoch, hence the CurTick-1 label.
   const MIL_INT NbTick = std::min<MIL_INT>(m_MaxEpoch, 10);
   const MIL_INT EpochTickValue = m_MaxEpoch / NbTick;
   for(MIL_INT CurTick = 1; CurTick <= m_MaxEpoch; CurTick += EpochTickValue)
   {
      MIL_DOUBLE Percentage = (MIL_DOUBLE)CurTick / (MIL_DOUBLE)m_MaxEpoch;
      MIL_INT XOffset = (MIL_INT)(Percentage * GRAPH_SIZE_X);
      MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 5, M_TO_STRING(CurTick-1));
      MgraLine(m_TheGraContext, m_EpochGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y - 5, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);
   }
}
// Draws the empty loss graph: its frame, one log-scale Y label per power of ten
// from 1e(LOSS_EXPONENT_MAX) at the top to 1e(LOSS_EXPONENT_MIN) at the bottom,
// and up to 10 epoch ticks on the X axis (skipped when m_MaxEpoch is not positive).
void CTrainEvolutionDashboard::InitializeLossGraph()
{
   MgraColor(m_TheGraContext, M_COLOR_WHITE);
   MgraRect(m_TheGraContext, m_LossGraphBufId, MARGIN, GRAPH_TOP_MARGIN, MARGIN + GRAPH_SIZE_X, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);

   // Y axis: labels right-aligned left of the graph, inner ticks only (the frame covers the ends).
   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_RIGHT);
   const MIL_INT NbLossValueTick = LOSS_EXPONENT_MAX - LOSS_EXPONENT_MIN;
   const MIL_DOUBLE TickRatio = 1.0/(MIL_DOUBLE)NbLossValueTick;
   MIL_DOUBLE TickNum = 0.0;
   for(MIL_INT i = LOSS_EXPONENT_MAX; i >= LOSS_EXPONENT_MIN; i--)
   {
      MIL_TEXT_CHAR CurTickText[128];
      MosSprintf(CurTickText, 128, MIL_TEXT("1e%d"), i);
      MIL_INT TickYPos = (MIL_INT)(TickNum*TickRatio*GRAPH_SIZE_Y);
      MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN - 5, GRAPH_TOP_MARGIN + TickYPos, CurTickText);
      if(i!=LOSS_EXPONENT_MAX && i!= LOSS_EXPONENT_MIN)
      {
         MgraLine(m_TheGraContext, m_LossGraphBufId, MARGIN, GRAPH_TOP_MARGIN + TickYPos, MARGIN + 5, GRAPH_TOP_MARGIN + TickYPos);
      }
      TickNum = TickNum + 1.0;
   }

   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   // X axis: no epoch to label, and m_MaxEpoch / NbEpochTick below would divide by zero.
   if(m_MaxEpoch <= 0)
      return;

   const MIL_INT NbEpochTick = std::min<MIL_INT>(m_MaxEpoch, 10);
   const MIL_INT EpochTickValue = m_MaxEpoch / NbEpochTick;
   for(MIL_INT CurTick = 1; CurTick <= m_MaxEpoch; CurTick += EpochTickValue)
   {
      MIL_DOUBLE Percentage = (MIL_DOUBLE)CurTick / (MIL_DOUBLE)m_MaxEpoch;
      MIL_INT XOffset = (MIL_INT)(Percentage * GRAPH_SIZE_X);
      MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 5, M_TO_STRING(CurTick-1));
      MgraLine(m_TheGraContext, m_LossGraphBufId, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y - 5, MARGIN + XOffset, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y);
   }
}
// Writes the static training description into the loss-info region, one line
// every 20 pixels starting at y = 15, and records in m_YPositionForLossText the
// row just below the last line (used later for the current loss value).
// Also switches the graphics context to opaque black-background text.
void CTrainEvolutionDashboard::WriteGeneralTrainInfo(MIL_INT MinibatchSize,
                                                     MIL_INT TrainImageSizeX,
                                                     MIL_INT TrainImageSizeY,
                                                     MIL_INT TrainDatasetSize,
                                                     MIL_INT DevDatasetSize,
                                                     MIL_DOUBLE LearningRate,
                                                     MIL_INT TrainEngineUsed,
                                                     MIL_STRING& TrainEngineDescription)
{
   MgraControl(m_TheGraContext, M_BACKGROUND_MODE, M_OPAQUE);
   MgraControl(m_TheGraContext, M_BACKCOLOR, M_COLOR_BLACK);
   MgraControl(m_TheGraContext, M_TEXT_ALIGN_HORIZONTAL, M_LEFT);

   const MIL_INT LineLeft = MARGIN - 10;
   const MIL_INT LineStep = 20;
   MIL_INT LineTop = 15;

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);

   // Draws the current content of Line, then moves down one line.
   MIL_TEXT_CHAR Line[512];
   auto EmitLine = [&]()
      {
      MgraText(m_TheGraContext, m_LossInfoBufId, LineLeft, LineTop, Line);
      LineTop += LineStep;
      };

   MosSprintf(Line, 512, (TrainEngineUsed == M_CPU) ? MIL_TEXT("Training is being performed on the CPU")
                                                    : MIL_TEXT("Training is being performed on the GPU"));
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Training engine: %s"), TrainEngineDescription.c_str());
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Train image size: %dx%d"), TrainImageSizeX, TrainImageSizeY);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Train and Dev dataset size: %d and %d images"), TrainDatasetSize, DevDatasetSize);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Max number of epochs: %d"), m_MaxEpoch);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Minibatch size: %d"), MinibatchSize);
   EmitLine();

   MosSprintf(Line, 512, MIL_TEXT("Learning rate: %.2e"), LearningRate);
   EmitLine();

   m_YPositionForLossText = LineTop;
}
// Records the end of an epoch: stores the epoch timing (read by
// UpdateProgression for the remaining-time estimate), refreshes the epoch
// text summary, then extends the error-rate curves.
void CTrainEvolutionDashboard::AddEpochData(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate,
                                            MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow,
                                            MIL_DOUBLE EpochBenchMean)
{
   m_EpochBenchMean = EpochBenchMean;

   // Text first, then graph.
   UpdateEpochInfo(TrainErrorRate, DevErrorRate,
                   CurEpoch, TheEpochIsTheBestUpToNow);
   UpdateEpochGraph(TrainErrorRate, DevErrorRate,
                    CurEpoch);
}
// Records the end of a minibatch: prints the current loss, extends the loss
// curve, and refreshes the progression bar / remaining-time text.
void CTrainEvolutionDashboard::AddMiniBatchData(MIL_DOUBLE LossError, MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
{
   // Loss value text.
   UpdateLoss(LossError);
   // Loss curve.
   UpdateLossGraph(LossError,
                   MinibatchIdx, EpochIdx, NbBatchPerEpoch);
   // Overall progression.
   UpdateProgression(MinibatchIdx, EpochIdx,
                     NbBatchPerEpoch);
}
// Rewrites the epoch-info text: current dev and train error rates on rows 0-1
// and, when this epoch is the best so far, the best dev error (with its epoch
// number) and the matching train error on rows 2-3.
void CTrainEvolutionDashboard::UpdateEpochInfo(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch, bool TheEpochIsTheBestUpToNow)
{
   const MIL_INT Left = MARGIN - 10;
   const MIL_INT FirstRowY = 15;
   const MIL_INT RowHeight = 20;
   MIL_TEXT_CHAR Text[512];

   // Draws the already formatted Text on the given row in the given color.
   auto DrawRow = [&](MIL_INT Row, MIL_DOUBLE Color)
      {
      MgraColor(m_TheGraContext, Color);
      MgraText(m_TheGraContext, m_EpochInfoBufId, Left, FirstRowY + Row * RowHeight, Text);
      };

   MosSprintf(Text, 512, MIL_TEXT("Current Dev error rate: %7.4lf %%"), DevErrorRate);
   DrawRow(0, COLOR_DEV_SET_INFO);

   MosSprintf(Text, 512, MIL_TEXT("Current Train error rate: %7.4lf %%"), TrainErrorRate);
   DrawRow(1, COLOR_TRAIN_SET_INFO);

   if(!TheEpochIsTheBestUpToNow)
      return;

   MosSprintf(Text, 512, MIL_TEXT("Best epoch Dev error rate: %7.4lf %% (Epoch %d)"), DevErrorRate, CurEpoch);
   DrawRow(2, COLOR_DEV_SET_INFO);

   MosSprintf(Text, 512, MIL_TEXT("Train error rate for the best epoch: %7.4lf %%"), TrainErrorRate);
   DrawRow(3, COLOR_TRAIN_SET_INFO);
}
// Writes the current loss value in the loss-info region, on the row reserved
// by WriteGeneralTrainInfo (m_YPositionForLossText).
void CTrainEvolutionDashboard::UpdateLoss(MIL_DOUBLE LossError)
{
   MIL_TEXT_CHAR Label[512];
   MosSprintf(Label, 512, MIL_TEXT("Current loss value: %11.7lf"), LossError);

   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   MgraText(m_TheGraContext, m_LossInfoBufId, MARGIN - 10, m_YPositionForLossText, Label);
}
// Adds the train and dev error rates (in percent) of epoch CurEpoch to the
// epoch graph. The first epoch is drawn as a small filled disk; later epochs
// are connected to the previous point by a line. Also updates the
// "Epoch N completed" caption below the graph.
void CTrainEvolutionDashboard::UpdateEpochGraph(MIL_DOUBLE TrainErrorRate, MIL_DOUBLE DevErrorRate, MIL_INT CurEpoch)
{
   // 0 % maps to the bottom of the graph, 100 % to the top.
   auto RateToY = [this](MIL_DOUBLE RatePercent) -> MIL_INT
      {
      return GRAPH_TOP_MARGIN + (MIL_INT)((MIL_DOUBLE)GRAPH_SIZE_Y*(1.0 - RatePercent * 0.01));
      };

   // Epoch e (0-based) sits at abscissa (e + 1) / m_MaxEpoch of the graph width.
   const MIL_INT PointX = MARGIN + (MIL_INT)((MIL_DOUBLE)(CurEpoch + 1) / (MIL_DOUBLE)(m_MaxEpoch)*(MIL_DOUBLE)GRAPH_SIZE_X);
   const MIL_INT TrainY = RateToY(TrainErrorRate);
   const MIL_INT DevY = RateToY(DevErrorRate);
   const bool IsFirstEpoch = (CurEpoch == 0);

   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   if(IsFirstEpoch)
      MgraArcFill(m_TheGraContext, m_EpochGraphBufId, PointX, TrainY, 2, 2, 0, 360);
   else
      MgraLine(m_TheGraContext, m_EpochGraphBufId, m_LastTrainPosX, m_LastTrainPosY, PointX, TrainY);

   MgraColor(m_TheGraContext, COLOR_DEV_SET_INFO);
   if(IsFirstEpoch)
      MgraArcFill(m_TheGraContext, m_EpochGraphBufId, PointX, DevY, 2, 2, 0, 360);
   else
      MgraLine(m_TheGraContext, m_EpochGraphBufId, m_LastDevPosX, m_LastDevPosY, PointX, DevY);

   // Remember this point as the start of the next segments.
   m_LastTrainPosX = PointX;
   m_LastTrainPosY = TrainY;
   m_LastDevPosX = PointX;
   m_LastDevPosY = DevY;

   MIL_TEXT_CHAR Caption[128];
   MosSprintf(Caption, 128, MIL_TEXT("Epoch %d completed"), CurEpoch);
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MgraText(m_TheGraContext, m_EpochGraphBufId, MARGIN, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 25, Caption);
}
// Adds one minibatch loss value to the loss graph (log10 Y scale spanning
// 1e(LOSS_EXPONENT_MIN)..1e(LOSS_EXPONENT_MAX), X = fraction of all minibatches
// of the training) and rewrites the "Epoch e :: Minibatch m" caption.
//
// Out-of-range values are clamped into the graph: NaN (e.g. diverged training)
// is pinned to the top, and values <= 1e(LOSS_EXPONENT_MIN), including zero and
// negatives where log10 is undefined, to the bottom.
void CTrainEvolutionDashboard::UpdateLossGraph(MIL_DOUBLE LossError, MIL_INT MiniBatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
{
   const MIL_INT NBMiniBatch = m_MaxEpoch * NbBatchPerEpoch;
   const MIL_INT CurMiniBatch = EpochIdx * NbBatchPerEpoch + MiniBatchIdx;
   const MIL_DOUBLE XRatio = (NBMiniBatch > 0) ? ((MIL_DOUBLE)CurMiniBatch / (MIL_DOUBLE)(NBMiniBatch)) : 0.0;
   const MIL_INT CurTrainMBPosX = MARGIN + (MIL_INT)(XRatio * (MIL_DOUBLE)GRAPH_SIZE_X);

   const MIL_DOUBLE MaxVal = std::pow(10.0, LOSS_EXPONENT_MAX);
   const MIL_DOUBLE MinVal = std::pow(10.0, LOSS_EXPONENT_MIN);
   const MIL_INT NbTick = LOSS_EXPONENT_MAX - LOSS_EXPONENT_MIN;

   // NaN compares unequal to itself and would pass through min/max unchanged,
   // ending in a NaN-to-integer conversion below.
   if(LossError != LossError)
      LossError = MaxVal;
   LossError = std::min<MIL_DOUBLE>(std::max<MIL_DOUBLE>(LossError, MinVal), MaxVal);

   const MIL_DOUBLE Log10RemapPos = std::max<MIL_DOUBLE>(log10(LossError) + (-LOSS_EXPONENT_MIN), 0.0);
   const MIL_DOUBLE YRatio = Log10RemapPos / (MIL_DOUBLE)NbTick;
   const MIL_INT CurTrainMBPosY = GRAPH_TOP_MARGIN + (MIL_INT)((MIL_DOUBLE)GRAPH_SIZE_Y*(1.0 - YRatio));

   MgraColor(m_TheGraContext, COLOR_TRAIN_SET_INFO);
   if(EpochIdx == 0 && MiniBatchIdx == 0)
      MgraDot(m_TheGraContext, m_LossGraphBufId, CurTrainMBPosX, CurTrainMBPosY);
   else
      MgraLine(m_TheGraContext, m_LossGraphBufId, m_LastTrainMinibatchPosX, m_LastTrainMinibatchPosY, CurTrainMBPosX, CurTrainMBPosY);
   m_LastTrainMinibatchPosX = CurTrainMBPosX;
   m_LastTrainMinibatchPosY = CurTrainMBPosY;

   // The trailing spaces overwrite leftover characters when the caption gets
   // shorter (e.g. "Minibatch 10" -> "Minibatch 0"). This relies on the opaque text
   // background set in WriteGeneralTrainInfo; a single " " only covered one character.
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MIL_TEXT_CHAR EpochText[512];
   MosSprintf(EpochText, 512, MIL_TEXT("Epoch %d :: Minibatch %d        "), EpochIdx, MiniBatchIdx);
   MgraText(m_TheGraContext, m_LossGraphBufId, MARGIN, GRAPH_TOP_MARGIN + GRAPH_SIZE_Y + 25, EpochText);
}
// Refreshes the progression region after a minibatch: a status line
// (completion message, "N/A" during the first epoch, or the estimated remaining
// time from the mean epoch duration) and a progress bar proportional to the
// number of minibatches done over the whole training.
void CTrainEvolutionDashboard::UpdateProgression(MIL_INT MinibatchIdx, MIL_INT EpochIdx, MIL_INT NbBatchPerEpoch)
{
   const MIL_INT YMargin = 20;
   const MIL_INT TextHeight = 30;
   const MIL_INT NbMinibatch = m_MaxEpoch * NbBatchPerEpoch;
   // Includes the minibatch that just completed.
   const MIL_INT NbMinibatchDone = EpochIdx * NbBatchPerEpoch + MinibatchIdx + 1;
   // NbMinibatchDone already counts the current minibatch: no extra -1 here.
   const MIL_INT NbMinibatchRemaining = std::max<MIL_INT>(NbMinibatch - NbMinibatchDone, 0);

   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   if(NbMinibatchDone >= NbMinibatch)
   {
      // Checked first so completion is reported even when training ends during
      // epoch 0. Padded so it fully covers the longer remaining-time text drawn
      // before it (opaque background set in WriteGeneralTrainInfo).
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, MIL_TEXT("Training completed!                              "));
   }
   else if(EpochIdx == 0)
   {
      // No epoch timing is available yet.
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, MIL_TEXT("Estimated remaining time: N/A"));
   }
   else
   {
      const MIL_DOUBLE MinibatchBenchMean = m_EpochBenchMean/(MIL_DOUBLE)NbBatchPerEpoch;
      const MIL_DOUBLE RemainingTime = MinibatchBenchMean * (MIL_DOUBLE)NbMinibatchRemaining;
      MIL_TEXT_CHAR RemainingTimeText[512];
      MosSprintf(RemainingTimeText, 512, MIL_TEXT("Estimated remaining time: %8.0lf seconds"), RemainingTime);
      MgraText(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin, RemainingTimeText);
   }

   // Progress bar: full-width background, then the completed part on top.
   const MIL_INT ProgressionBarWidth = m_DashboardWidth - 2 * MARGIN;
   const MIL_INT ProgressionBarHeight = 30;
   MgraColor(m_TheGraContext, COLOR_GENERAL_INFO);
   MgraRectFill(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin + TextHeight, MARGIN + ProgressionBarWidth, YMargin + TextHeight + ProgressionBarHeight);

   // Guard against an empty training plan and clamp to 100 %.
   const MIL_DOUBLE PercentageComplete = (NbMinibatch > 0)
      ? std::min<MIL_DOUBLE>((MIL_DOUBLE)(NbMinibatchDone) / (MIL_DOUBLE)(NbMinibatch), 1.0)
      : 1.0;
   const MIL_INT PercentageCompleteWidth = (MIL_INT)(PercentageComplete*ProgressionBarWidth);
   MgraColor(m_TheGraContext, COLOR_PROGRESS_BAR);
   MgraRectFill(m_TheGraContext, m_ProgressionInfoBufId, MARGIN, YMargin + TextHeight, MARGIN + PercentageCompleteWidth, YMargin + TextHeight + ProgressionBarHeight);
}
// Builds the prediction display. The left column (at x = MARGIN / 2, row 1)
// receives the image under prediction; the right column (starting at
// x = m_MaxTrainImageSize + MARGIN) shows each class icon of TestDataset in its
// own square cell of side m_MaxTrainImageSize (the largest icon dimension),
// framed in red. The display overlay is enabled for the per-prediction annotations.
CPredictResultDisplay::CPredictResultDisplay(MIL_ID MilSystem, MIL_ID MilDisplay, MIL_ID TestDataset):
   m_MilSystem(MilSystem),
   m_MilDisplay(MilDisplay),
   m_MaxTrainImageSize(-1),
   m_MilDispImage(M_NULL),
   m_MilDispChild(M_NULL),
   m_MilOverlay(M_NULL),
   m_GraContext(M_NULL),
   COLOR_PREDICT_INFO(M_COLOR_GREEN),
   MARGIN(100)
{
   // Collect the class icons and the largest icon dimension.
   MIL_INT NbClassDef = MclassInquire(TestDataset, M_DEFAULT, M_NUMBER_OF_CLASSES, M_NULL);
   std::vector<MIL_ID> ClassImages(NbClassDef);
   for(MIL_INT i = 0; i < NbClassDef; i++)
   {
      ClassImages[i] = MclassInquire(TestDataset, M_CLASS_INDEX(i), M_CLASS_ICON_ID + M_TYPE_MIL_ID, M_NULL);
      MIL_INT SizeX = MbufInquire(ClassImages[i], M_SIZE_X, M_NULL);
      MIL_INT SizeY = MbufInquire(ClassImages[i], M_SIZE_Y, M_NULL);
      m_MaxTrainImageSize = std::max<MIL_INT>(m_MaxTrainImageSize, SizeX);
      m_MaxTrainImageSize = std::max<MIL_INT>(m_MaxTrainImageSize, SizeY);
   }

   // The predicted-image child lives in row 1, so at least two rows are needed
   // even when the dataset defines a single class.
   const MIL_INT NbDisplayRows = std::max<MIL_INT>(NbClassDef, 2);
   m_MilDispImage = MbufAllocColor(m_MilSystem, 3, 2 * m_MaxTrainImageSize + MARGIN, NbDisplayRows*m_MaxTrainImageSize, 8 + M_UNSIGNED, M_IMAGE + M_PROC + M_DISP, M_UNIQUE_ID);
   MbufClear(m_MilDispImage, M_COLOR_BLACK);
   m_MilDispChild = MbufChild2d(m_MilDispImage, MARGIN / 2, m_MaxTrainImageSize, m_MaxTrainImageSize, m_MaxTrainImageSize, M_UNIQUE_ID);

   m_GraContext = MgraAlloc(MilSystem, M_UNIQUE_ID);
   MgraColor(m_GraContext, M_COLOR_RED);

   MIL_INT PosY = 0;
   for(const auto& Image : ClassImages)
   {
      // Copy only the icon's own extent: an icon smaller than the cell would
      // otherwise be read outside its bounds.
      const MIL_INT IconSizeX = MbufInquire(Image, M_SIZE_X, M_NULL);
      const MIL_INT IconSizeY = MbufInquire(Image, M_SIZE_Y, M_NULL);
      MbufCopyColor2d(Image, m_MilDispImage, M_ALL_BANDS, 0, 0, M_ALL_BANDS, m_MaxTrainImageSize + MARGIN, PosY, IconSizeX, IconSizeY);
      MgraRect(m_GraContext, m_MilDispImage, m_MaxTrainImageSize + MARGIN, PosY, m_MaxTrainImageSize + MARGIN + m_MaxTrainImageSize - 1, PosY + m_MaxTrainImageSize - 1);
      PosY += m_MaxTrainImageSize;
   }

   MdispSelect(m_MilDisplay, m_MilDispImage);
   MdispControl(m_MilDisplay, M_OVERLAY, M_ENABLE);
   m_MilOverlay = MdispInquire(MilDisplay, M_OVERLAY_ID, M_NULL);
}
// Releases the display's MIL objects in an explicit order by assigning M_NULL.
// NOTE(review): these members are assigned from M_UNIQUE_ID allocations in the
// constructor, so they are presumably MIL_UNIQUE_*_ID wrappers freed on M_NULL
// assignment — confirm against the class declaration.
// The child buffer is released before its parent m_MilDispImage. m_MilOverlay is
// not reset: it was obtained with MdispInquire, not allocated by this class.
CPredictResultDisplay::~CPredictResultDisplay()
{
m_GraContext = M_NULL;
m_MilDispChild = M_NULL;
// Parent buffer last, after its child.
m_MilDispImage = M_NULL;
}
// Shows one prediction:
// - copies ImageToPredict into the left-column child buffer;
// - in the overlay, frames the icon cell of the predicted class BestIndex;
// - writes BestScore, formatted as "xx.xx%", in the frame's top-left corner.
// Display updates are disabled while the frame is rebuilt and re-enabled at the end.
void CPredictResultDisplay::Update(MIL_ID ImageToPredict, MIL_INT BestIndex, MIL_DOUBLE BestScore)
{
   MdispControl(m_MilDisplay, M_UPDATE, M_DISABLE);
   MdispControl(m_MilDisplay, M_OVERLAY_CLEAR, M_TRANSPARENT_COLOR);
   MbufCopy(ImageToPredict, m_MilDispChild);

   // Icon column starts at m_MaxTrainImageSize + MARGIN, matching the constructor
   // layout (previously hard-coded as 100, which only matched while MARGIN == 100).
   const MIL_INT RectOffsetX = m_MaxTrainImageSize + MARGIN;
   const MIL_INT RectOffsetY = BestIndex * m_MaxTrainImageSize;
   MgraColor(m_GraContext, COLOR_PREDICT_INFO);
   MgraRect(m_GraContext, m_MilOverlay, RectOffsetX, RectOffsetY, RectOffsetX + m_MaxTrainImageSize - 1, RectOffsetY + m_MaxTrainImageSize - 1);

   MIL_TEXT_CHAR Accuracy_text[256];
   MosSprintf(Accuracy_text, 256, MIL_TEXT("%.2lf%%"), BestScore);
   MgraControl(m_GraContext, M_BACKGROUND_MODE, M_OPAQUE);
   MgraFont(m_GraContext, M_FONT_DEFAULT_SMALL);
   MgraText(m_GraContext, m_MilOverlay, RectOffsetX + 2, RectOffsetY + 2, Accuracy_text);

   MdispControl(m_MilDisplay, M_UPDATE, M_ENABLE);
}