/*
 * Copyright (C) 2019-2023 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 *
 */

#pragma once

#include "shared/source/utilities/stackvec.h"
#include "shared/test/common/cmd_parse/hw_parse.h"

#include <string>
#include <vector>

namespace NEO {

/// Interface for a single check applied to one parsed command.
/// An implementation returns true when the command passes; otherwise it returns false and
/// writes a description of the mismatch into outFailReason.
struct CmdValidator {
    CmdValidator() = default;
    virtual ~CmdValidator() = default;

    /// @param it            position of the command being checked
    /// @param numInSection  count supplied by the caller (MatchHwCmd passes its captured-command count)
    /// @param member        name of the checked member, used when composing the failure message
    /// @param outFailReason receives the failure description when the check fails
    virtual bool operator()(GenCmdList::iterator it, size_t numInSection, const std::string &member, std::string &outFailReason) = 0;
};

/// CRTP helper that gives every concrete validator type one shared instance, reachable via ChildT::get().
template <typename ChildT>
struct CmdValidatorWithStaticStorage : CmdValidator {
    /// Returns the shared ChildT instance; it is a function-local static, so it is constructed on the
    /// first call and destroyed at program exit.
    static ChildT *get() {
        static ChildT instance;
        return &instance;
    }
};

/// Validator asserting that the const getter `getter` of command type CmdT returns the compile-time
/// constant `expected`. The command under the iterator must be a CmdT (enforced with UNRECOVERABLE_IF).
template <typename CmdT, typename ReturnT, ReturnT (CmdT::*getter)() const, ReturnT expected>
struct GenericCmdValidator : CmdValidatorWithStaticStorage<GenericCmdValidator<CmdT, ReturnT, getter, expected>> {
    bool operator()(GenCmdList::iterator it, size_t numInSection, const std::string &member, std::string &outFailReason) override {
        const auto *command = genCmdCast<CmdT *>(*it);
        UNRECOVERABLE_IF(command == nullptr);

        const ReturnT actual = (command->*getter)();
        if (actual == expected) {
            return true;
        }

        // report both values so the mismatching field is obvious in the test log
        outFailReason = member + " - expected: " + std::to_string(expected) + ", got: " + std::to_string(actual);
        return false;
    }
};

/// A validator paired with the name of the member it checks; the name is forwarded to the
/// validator and appears in failure messages.
struct NamedValidator {
    /// Uses "Unspecified" as the member name.
    NamedValidator(CmdValidator *validator)
        : validator(validator) {
    }

    NamedValidator(CmdValidator *validator, const char *name)
        : validator(validator), name(name) {
    }

    CmdValidator *validator;
    const char *name = "Unspecified";
};

// Builds a NamedValidator that checks TYPE::FUNC() == EXPECTED on a matched command.
// FUNC must be a const member function of TYPE; its return type is deduced with std::invoke_result_t.
// EXPECTED becomes a non-type template argument, so it must be a compile-time constant.
// The stringified getter name (#FUNC) is used as the member name in failure messages.
#define EXPECT_MEMBER(TYPE, FUNC, EXPECTED) \
    NamedValidator { GenericCmdValidator<TYPE, std::invoke_result_t<decltype(&TYPE::FUNC), TYPE>, &TYPE::FUNC, EXPECTED>::get(), #FUNC }

// Ordered list of validators; MatchHwCmd runs them in sequence on each matched command.
using Expects = std::vector<NamedValidator>;

/// Base class for one entry of an expected command-buffer layout (consumed by expectCmdBuff).
/// `amount` is the number of consecutive commands the entry covers: a positive count, 0,
/// anyNumber or atLeastOne. `matchesAny` marks entries that accept any command type.
struct MatchCmd {
    MatchCmd(int amount)
        : amount(amount) {
    }

    MatchCmd(int amount, bool matchesAny)
        : amount(amount), matchesAny(matchesAny) {
    }

    virtual ~MatchCmd() = default;

    /// True when the command at `it` is of the kind this entry expects.
    virtual bool matches(GenCmdList::iterator it) const = 0;
    /// Extra checks on a matched command; returns false and fills outReason on failure.
    virtual bool validates(GenCmdList::iterator it, std::string &outReason) const = 0;
    /// Name used in diagnostic messages.
    virtual const char *getName() const = 0;
    /// Records the command at `it` (implementations append it to their captured list).
    virtual void capture(GenCmdList::iterator it) = 0;

    bool getMatchesAny() const {
        return matchesAny;
    }

    int getExpectedCount() const {
        return amount;
    }

  protected:
    int amount = 0;
    bool matchesAny = false;
};

// Special matcher counts (negative so they never collide with a real command count).
inline constexpr int32_t anyNumber = -1;  // zero or more commands
inline constexpr int32_t atLeastOne = -2; // one or more commands

/// Renders a matcher count for diagnostics: the special counts by name, anything else as a number.
inline std::string countToString(int32_t count) {
    switch (count) {
    case anyNumber:
        return "AnyNumber";
    case atLeastOne:
        return "AtLeastOne";
    default:
        return std::to_string(count);
    }
}

/// True for the open-ended counts (anyNumber, atLeastOne), false for exact counts.
inline bool notPreciseNumber(int32_t count) {
    switch (count) {
    case anyNumber:
    case atLeastOne:
        return true;
    default:
        return false;
    }
}

struct MatchAnyCmd : MatchCmd {
    MatchAnyCmd(int amount)
        : MatchCmd(amount, true) {
        if (amount > 0) {
            captured.reserve(amount);
        }
    }
    bool matches(GenCmdList::iterator it) const override {
        return true;
    }
    bool validates(GenCmdList::iterator it, std::string &outReason) const override {
        return true;
    }
    void capture(GenCmdList::iterator it) override {
        captured.push_back(*it);
    }
    const char *getName() const override {
        return "AnyCommand";
    }

  protected:
    StackVec<const void *, 16> captured;
};

template <typename FamilyType, typename CmdType>
struct MatchHwCmd : MatchCmd {
    MatchHwCmd(int amount)
        : MatchCmd(amount) {
        if (amount > 0) {
            captured.reserve(amount);
        }
    }

    MatchHwCmd(int amount, Expects &&validators)
        : MatchHwCmd(amount) {
        this->validators.swap(validators);
    }

    bool matches(GenCmdList::iterator it) const override {
        return nullptr != genCmdCast<CmdType *>(*it);
    }

    bool validates(GenCmdList::iterator it, std::string &outReason) const override {
        for (auto &v : validators) {
            if (false == (*v.validator)(it, captured.size(), v.name, outReason)) {
                return false;
            }
        }
        return true;
    }

    void capture(GenCmdList::iterator it) override {
        UNRECOVERABLE_IF(false == matches(it));
        UNRECOVERABLE_IF(captured.size() == static_cast<size_t>(amount));
        captured.push_back(genCmdCast<CmdType *>(*it));
    }

    const char *getName() const override {
        CmdType cmd;
        cmd.init();
        return HardwareParse::getCommandName<FamilyType>(&cmd);
    }

  protected:
    StackVec<const CmdType *, 16> captured;
    Expects validators;
};

/// Checks a parsed command range against an ordered list of matchers.
///
/// Matchers are consumed in order; each describes a run of consecutive commands:
///  - a positive count N : exactly N commands accepted by the matcher,
///  - 0                  : the command at this position must NOT be accepted by the matcher
///                         (the matcher consumes nothing and is then skipped),
///  - anyNumber          : zero or more accepted commands,
///  - atLeastOne         : one or more accepted commands.
/// A matches-any matcher with anyNumber (or with atLeastOne once satisfied) consumes commands greedily,
/// but hands over to the next matcher as soon as that matcher accepts the current command.
/// Every accepted command must also pass the matcher's validates() check.
///
/// @param begin, end               command range to check
/// @param expectedCmdBuffMatchers  matchers; this function takes ownership and deletes them before returning
/// @param outReason                optional; on failure receives a description of the mismatch followed by a
///                                 listing of the whole input range
/// @return true when the range matches. With no matchers: true only for an empty range (outReason untouched).
template <typename FamilyType>
inline bool expectCmdBuff(GenCmdList::iterator begin, GenCmdList::iterator end,
                          std::vector<MatchCmd *> &&expectedCmdBuffMatchers, std::string *outReason = nullptr) {
    // Release the owned matchers on every exit path, including exceptions propagating out of the matchers.
    struct MatchersOwner {
        std::vector<MatchCmd *> &matchers;
        ~MatchersOwner() {
            for (auto *owned : matchers) {
                delete owned;
            }
        }
    } matchersOwner{expectedCmdBuffMatchers};

    const size_t numMatchers = expectedCmdBuffMatchers.size();
    if (numMatchers == 0) {
        return begin == end;
    }

    bool failed = false;
    std::string failReason;
    auto it = begin;
    int cmdNum = 0;
    size_t currentMatcher = 0;
    int currentMatcherCount = 0;
    StackVec<std::pair<const char *, bool>, 32> matchedCommandNames;

    // "idx:NAME " (or "idx:ANY(NAME) " for matches-any matchers) for every command matched so far.
    auto matchedCommandsString = [&]() -> std::string {
        if (matchedCommandNames.size() == 0) {
            return "EMPTY";
        }
        std::string ret = "";
        for (size_t i = 0; i < matchedCommandNames.size(); ++i) {
            if (matchedCommandNames[i].second) {
                ret += std::to_string(i) + ":ANY(" + matchedCommandNames[i].first + ") ";
            } else {
                ret += std::to_string(i) + ":" + matchedCommandNames[i].first + " ";
            }
        }
        return ret;
    };

    // "num:NAME " for every command in [from, to), numbered starting at firstCmdNum.
    auto listCommands = [](GenCmdList::iterator from, GenCmdList::iterator to, int firstCmdNum) -> std::string {
        std::string ret;
        for (; from != to; ++from, ++firstCmdNum) {
            ret += std::to_string(firstCmdNum) + ":" + HardwareParse::getCommandName<FamilyType>(*from) + " ";
        }
        return ret;
    };

    // Failure message for a command the current matcher cannot accept.
    auto unmatchedReason = [&]() -> std::string {
        const MatchCmd *expectedMatcher = expectedCmdBuffMatchers[currentMatcher];
        return "Unmatched cmd#" + std::to_string(cmdNum) + ":" + HardwareParse::getCommandName<FamilyType>(*it) + " - expected " + std::string(expectedMatcher->getName()) + "(" + countToString(expectedMatcher->getExpectedCount()) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString();
    };

    while (it != end) {
        // Advance past matchers that must not, or need not, consume the current command. This repeats until
        // nothing moves, so consecutive optional / forbidden matchers are all handled for the same command.
        bool advanced = true;
        while (advanced && (currentMatcher < numMatchers)) {
            advanced = false;
            const MatchCmd *candidate = expectedCmdBuffMatchers[currentMatcher];
            const int32_t expectedCount = candidate->getExpectedCount();
            if (expectedCount == 0) {
                // forbidden at this position
                if (candidate->matches(it)) {
                    failed = true;
                    failReason = "Unmatched cmd#" + std::to_string(cmdNum) + " - expected anything but " + std::string(candidate->getName()) + "(" + countToString(expectedCount) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString();
                    break;
                }
                advanced = true;
            } else if (candidate->getMatchesAny() && ((expectedCount == anyNumber) || ((expectedCount == atLeastOne) && (currentMatcherCount > 0)))) {
                // eat as many as possible but proceed to next matcher when possible
                advanced = (currentMatcher + 1 < numMatchers) && expectedCmdBuffMatchers[currentMatcher + 1]->matches(it);
            } else if (notPreciseNumber(expectedCount) && (false == candidate->matches(it))) {
                // optional matcher not matching - move on, unless atLeastOne has not been satisfied yet
                if ((expectedCount == atLeastOne) && (currentMatcherCount < 1)) {
                    failed = true;
                    failReason = unmatchedReason();
                    break;
                }
                advanced = true;
            }
            if (advanced) {
                ++currentMatcher;
                currentMatcherCount = 0;
            }
        }
        if (failed) {
            // stop at the first failure so its reason is reported and `it` stays on the offending command
            break;
        }

        if (currentMatcher >= numMatchers) {
            failed = true;
            failReason = "Unexpected commands at the end of the command buffer : " + listCommands(it, end, cmdNum) + ", AFTER : " + matchedCommandsString();
            it = end; // all remaining commands are already listed above
            break;
        }

        const MatchCmd *matcher = expectedCmdBuffMatchers[currentMatcher];
        if (false == matcher->matches(it)) {
            failed = true;
            failReason = unmatchedReason();
            break;
        }

        if (false == matcher->validates(it, failReason)) {
            failReason = "cmd#" + std::to_string(cmdNum) + " (" + HardwareParse::getCommandName<FamilyType>(*it) + ") failed validation - reason : " + failReason + " after : " + matchedCommandsString();
            failed = true;
            break;
        }

        matchedCommandNames.push_back(std::make_pair(HardwareParse::getCommandName<FamilyType>(*it), matcher->getMatchesAny()));

        ++currentMatcherCount;
        if (currentMatcherCount == matcher->getExpectedCount()) {
            ++currentMatcher;
            currentMatcherCount = 0;
        }

        ++cmdNum;
        ++it;
    }

    if (failed == false) {
        // Skip matchers that are satisfied without further commands: forbidden (0), optional (anyNumber)
        // and atLeastOne matchers that have already accepted a command.
        while (currentMatcher < numMatchers) {
            const int32_t expectedCount = expectedCmdBuffMatchers[currentMatcher]->getExpectedCount();
            const bool satisfied = (expectedCount == 0) || (expectedCount == anyNumber) || ((expectedCount == atLeastOne) && (currentMatcherCount > 0));
            if (false == satisfied) {
                break;
            }
            ++currentMatcher;
            currentMatcherCount = 0;
        }

        if (currentMatcher < numMatchers) {
            // The buffer ended while a matcher still needed commands: report it with its progress
            // and every matcher after it with 0.
            const MatchCmd *pending = expectedCmdBuffMatchers[currentMatcher];
            std::string expectedMatchers = pending->getName() + std::string("(") + countToString(pending->getExpectedCount()) + " - " + std::to_string(currentMatcherCount) + ")";
            if (currentMatcher + 1 < numMatchers) {
                expectedMatchers += ", ";
                for (size_t i = currentMatcher + 1; i < numMatchers; ++i) {
                    const MatchCmd *later = expectedCmdBuffMatchers[i];
                    expectedMatchers += later->getName() + std::string("(") + countToString(later->getExpectedCount()) + " - 0), ";
                }
            }
            failReason = "Unexpected command buffer end at cmd#" + std::to_string(cmdNum) + " - expected " + expectedMatchers + " after : " + matchedCommandsString();
            failed = true;
        }
    } else if ((it != end) && (++it != end)) {
        // `it` was left on the failing command; list everything after it
        failReason += "\n Unconsumed commands after failed one : " + listCommands(it, end, cmdNum + 1);
    }

    if (failed && (outReason != nullptr)) {
        failReason += "\n Note : Input command buffer was : " + listCommands(begin, end, 0);
        *outReason = failReason;
    }

    return (failed == false);
}

/// Convenience overload: parses `commandStream` from `startOffset` and checks the resulting command list
/// with the iterator-based expectCmdBuff (which takes ownership of and deletes the matchers).
template <typename FamilyType>
inline bool expectCmdBuff(NEO::LinearStream &commandStream, size_t startOffset,
                          std::vector<MatchCmd *> &&expectedCmdBuffMatchers, std::string *outReason = nullptr) {
    HardwareParse parser;
    parser.parseCommands<FamilyType>(commandStream, startOffset);
    auto &parsedCommands = parser.cmdList;
    return expectCmdBuff<FamilyType>(parsedCommands.begin(), parsedCommands.end(), std::move(expectedCmdBuffMatchers), outReason);
}

} // namespace NEO
