//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ModelAccuracyChecker.hpp"
#include <boost/filesystem.hpp>
#include <boost/log/trivial.hpp>
#include <map>
#include <vector>

namespace armnnUtils
{

armnnUtils::ModelAccuracyChecker::ModelAccuracyChecker(const std::map<std::string, std::string>& validationLabels,
                                                       const std::vector<LabelCategoryNames>& modelOutputLabels)
    : m_GroundTruthLabelSet(validationLabels)
    , m_ModelOutputLabels(modelOutputLabels)
{}

float ModelAccuracyChecker::GetAccuracy(unsigned int k)
{
    if (k > 10)
    {
        BOOST_LOG_TRIVIAL(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
                                      "Printing Top 10 Accuracy result!";
        k = 10;
    }
    unsigned int total = 0;
    for (unsigned int i = k; i > 0; --i)
    {
        total += m_TopK[i];
    }
    return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}

// Split originalString into the substrings separated by delimiter (which may be several characters long).
// Empty tokens are dropped unless includeEmptyToken is true. Empty tokens come from adjacent delimiters,
// a delimiter at either end, or an empty input.
// An empty delimiter cannot separate anything. In that case the whole string is returned as a single
// token, subject to the same empty-token rule. Without this guard, find("", cur) always returns cur,
// so the loop would never advance.
std::vector<std::string>
    SplitBy(const std::string& originalString, const std::string& delimiter, bool includeEmptyToken)
{
    std::vector<std::string> tokens;
    size_t cur = 0;
    if (!delimiter.empty())
    {
        size_t next = 0;
        while ((next = originalString.find(delimiter, cur)) != std::string::npos)
        {
            // Skip empty tokens, unless explicitly stated to include them.
            if (next > cur || includeEmptyToken)
            {
                tokens.push_back(originalString.substr(cur, next - cur));
            }
            cur = next + delimiter.size();
        }
    }
    // The remaining token: whatever follows the last delimiter, or the whole string if none was found.
    // cur never exceeds originalString.size(), so `cur < size()` means the remainder is non-empty.
    if (cur < originalString.size() || includeEmptyToken)
    {
        tokens.push_back(originalString.substr(cur));
    }
    return tokens;
}

// Trim from both ends of originalString every character that appears in characterSet and return
// what remains. Characters from the set that occur in the middle of the string are kept.
// If originalString is empty or consists only of characters from the set, "" is returned.
// characterSet is asserted to be non-empty.
std::string Strip(const std::string& originalString, const std::string& characterSet)
{
    BOOST_ASSERT(!characterSet.empty());
    const auto begin = originalString.find_first_not_of(characterSet);
    if (begin == std::string::npos)
    {
        // Nothing survives trimming.
        return "";
    }
    // A character outside the set exists, so the backward search is guaranteed to find one too.
    const auto last = originalString.find_last_not_of(characterSet);
    return originalString.substr(begin, last - begin + 1);
}
}    // namespace armnnUtils