Skip to content

File epiassert-bones.hpp

File List > epiworld > epiassert-bones.hpp

Go to the documentation of this file

#ifndef EPIWORLD_EPIASSERT_HPP
#define EPIWORLD_EPIASSERT_HPP

namespace epiassert_detail {

    /**
     * @brief Detection trait: `true` when `std::begin` and `std::end` can be
     * called on a `const T &` (i.e. `T` can be walked with a range-for).
     */
    template<typename T, typename = void>
    struct is_iterable : std::false_type {};

    template<typename T>
    struct is_iterable<T, std::void_t<
        decltype(std::begin(std::declval<const T &>())),
        decltype(std::end(std::declval<const T &>()))
    >> : std::true_type {};

    /// `std::string` supports begin/end, but is explicitly reported as
    /// non-iterable so that it is handled as a single value, not as a
    /// sequence of characters.
    template<>
    struct is_iterable<std::string> : std::false_type {};

    /**
     * @brief Renders a value as text for inclusion in error messages.
     *
     * - One-byte integral types (`char`, `int8_t`, `uint8_t`, `bool`) are
     *   printed as numbers; streaming them directly would emit a raw
     *   (possibly non-printable) character.
     * - Floating-point values are printed with the fewest significant digits
     *   (from 6 up to 17) whose text reads back as the same value, so e.g.
     *   `1.0000001` is not shown as `1`. If no precision in that range
     *   round-trips (NaN, infinities, some `long double`s), the 17-digit
     *   text is used.
     * - Any other type is streamed with `operator<<` unchanged.
     *
     * @param v Value to format; must be streamable to `std::ostream`.
     * @return  Textual representation of @p v.
     */
    template<typename T>
    inline std::string to_str(const T & v)
    {
        if constexpr (std::is_integral_v<T> && (sizeof(T) == 1))
        {
            std::ostringstream oss;
            oss << static_cast<int>(v);
            return oss.str();
        }
        else if constexpr (std::is_floating_point_v<T>)
        {
            std::string text;
            for (std::streamsize digits = 6; digits <= 17; ++digits)
            {
                std::ostringstream oss;
                oss.precision(digits);
                oss << v;
                text = oss.str();

                // Stop at the first precision that parses back exactly.
                std::istringstream back(text);
                T parsed{};
                if ((back >> parsed) && (parsed == v))
                    break;
            }
            return text;
        }
        else
        {
            std::ostringstream oss;
            oss << v;
            return oss.str();
        }
    }

} // namespace epiassert_detail

class EpiAssert {
private:
    /**
     * @brief Builds the optional location suffix appended to messages.
     * @param caller Name of the calling function; may be empty.
     * @return `" (in '<caller>')"`, or an empty string when @p caller is empty.
     */
    static std::string fmt_location(const std::string & caller)
    {
        if (caller.empty())
            return "";
        return " (in '" + caller + "')";
    }

    /**
     * @brief Applies @p is_valid to a scalar, or to every element of an
     * iterable, and passes the first failing entry to @p on_failure.
     *
     * @param value      A scalar, or an iterable as detected by
     *                   `epiassert_detail::is_iterable`.
     * @param varname    Base name used to label a failing entry: `varname`
     *                   for a scalar, `varname[i]` for element `i`.
     * @param is_valid   Callable returning `true` when an entry is acceptable.
     * @param on_failure Callable `(const std::string & label, const auto & entry)`.
     *                   Every caller in this class passes one that throws,
     *                   so only the first invalid entry is reported.
     */
    template<typename T, typename IsValid, typename OnFailure>
    static void validate_each(
        const T &           value,
        const std::string & varname,
        IsValid             is_valid,
        OnFailure           on_failure
    )
    {
        if constexpr (epiassert_detail::is_iterable<T>::value)
        {
            size_t idx = 0;
            for (const auto & v : value)
            {
                if (!is_valid(v))
                    on_failure(varname + "[" + std::to_string(idx) + "]", v);
                ++idx;
            }
        }
        else
        {
            if (!is_valid(value))
                on_failure(varname, value);
        }
    }

public:
    // -----------------------------------------------------------------
    //  check_bounds – value(s) in [lower, upper]
    // -----------------------------------------------------------------

    /**
     * @brief Checks that a value (or every element of an iterable) lies in
     * the closed interval [lower, upper].
     *
     * Comparisons are written in negated form so that NaN values and NaN
     * bounds are rejected instead of silently passing.
     *
     * @param value   Scalar or iterable to check.
     * @param lower   Inclusive lower bound.
     * @param upper   Inclusive upper bound.
     * @param varname Name used in the error message.
     * @param caller  Optional name of the calling function.
     * @throws std::invalid_argument if `lower <= upper` does not hold.
     * @throws std::range_error if any checked value is outside the interval.
     */
    template<typename T, typename BoundT>
    static void check_bounds(
        const T &           value,
        const BoundT        lower,
        const BoundT        upper,
        const std::string & varname = "value",
        const std::string & caller  = ""
    )
    {
        if (!(lower <= upper))
        {
            throw std::invalid_argument(
                "check_bounds: 'lower' (" +
                epiassert_detail::to_str(lower) +
                ") must be <= 'upper' (" +
                epiassert_detail::to_str(upper) + ")" +
                fmt_location(caller)
            );
        }

        validate_each(
            value, varname,
            [&](const auto & v) { return (lower <= v) && (v <= upper); },
            [&](const std::string & label, const auto & v) {
                throw std::range_error(
                    "'" + label + "' must be in [" +
                    epiassert_detail::to_str(lower) + ", " +
                    epiassert_detail::to_str(upper) + "], but got " +
                    epiassert_detail::to_str(v) +
                    fmt_location(caller)
                );
            }
        );
    }

    // -----------------------------------------------------------------
    //  check_non_negative – value(s) >= 0
    // -----------------------------------------------------------------

    /**
     * @brief Checks that a value (or every element of an iterable) is >= 0.
     *
     * NaN is rejected.
     *
     * @param value   Scalar or iterable to check.
     * @param varname Name used in the error message.
     * @param caller  Optional name of the calling function.
     * @throws std::range_error if any checked value is not >= 0.
     */
    template<typename T>
    static void check_non_negative(
        const T &           value,
        const std::string & varname = "value",
        const std::string & caller  = ""
    )
    {
        validate_each(
            value, varname,
            [](const auto & v) { return v >= 0; },
            [&](const std::string & label, const auto & v) {
                throw std::range_error(
                    "'" + label + "' must be non-negative"
                    ", but got " + epiassert_detail::to_str(v) +
                    fmt_location(caller)
                );
            }
        );
    }

    // -----------------------------------------------------------------
    //  check_probability – value(s) in [0, 1]
    // -----------------------------------------------------------------

    /**
     * @brief Checks that a value (or every element of an iterable) is a
     * probability, i.e. lies in [0, 1].
     *
     * NaN is rejected.
     *
     * @param value   Scalar or iterable to check.
     * @param varname Name used in the error message.
     * @param caller  Optional name of the calling function.
     * @throws std::range_error if any checked value is outside [0, 1].
     */
    template<typename T>
    static void check_probability(
        const T &           value,
        const std::string & varname = "value",
        const std::string & caller  = ""
    )
    {
        validate_each(
            value, varname,
            [](const auto & v) { return (v >= 0.0) && (v <= 1.0); },
            [&](const std::string & label, const auto & v) {
                throw std::range_error(
                    "'" + label +
                    "' must be a probability in [0, 1]"
                    ", but got " + epiassert_detail::to_str(v) +
                    fmt_location(caller)
                );
            }
        );
    }

    // -----------------------------------------------------------------
    //  check_sum – container elements sum to target ± tolerance
    // -----------------------------------------------------------------

    /**
     * @brief Checks that the elements of an iterable sum to @p target within
     * @p tolerance.
     *
     * Elements are converted to `double` and summed. A sum that is NaN
     * (e.g. from a NaN element) fails the check.
     *
     * @param values    Iterable whose elements convert to `double`.
     * @param target    Expected sum.
     * @param varname   Name used in the error message.
     * @param caller    Optional name of the calling function.
     * @param tolerance Maximum allowed |sum - target|; must be >= 0.
     * @throws std::invalid_argument if @p tolerance is negative or NaN, or if
     *         the sum differs from @p target by more than @p tolerance.
     */
    template<typename T>
    static void check_sum(
        const T &           values,
        double              target,
        const std::string & varname   = "values",
        const std::string & caller    = "",
        double              tolerance = 1e-8
    )
    {
        static_assert(
            epiassert_detail::is_iterable<T>::value,
            "check_sum requires an iterable type."
        );

        if (!(tolerance >= 0.0))
        {
            throw std::invalid_argument(
                "check_sum: 'tolerance' must be non-negative, but got " +
                epiassert_detail::to_str(tolerance) +
                fmt_location(caller)
            );
        }

        double s = 0.0;
        for (const auto & v : values)
            s += static_cast<double>(v);

        // Negated form so that a NaN difference is reported, not accepted.
        if (!(std::abs(s - target) <= tolerance))
        {
            throw std::invalid_argument(
                "'" + varname + "' elements must sum to " +
                epiassert_detail::to_str(target) +
                " (tolerance " + epiassert_detail::to_str(tolerance) +
                "), but got " + epiassert_detail::to_str(s) +
                fmt_location(caller)
            );
        }
    }

    // -----------------------------------------------------------------
    //  check_size – container has expected number of elements
    // -----------------------------------------------------------------

    /**
     * @brief Checks that an iterable has exactly @p expected elements.
     *
     * @param values   Iterable to measure (via std::distance over begin/end).
     * @param expected Required number of elements.
     * @param varname  Name used in the error message.
     * @param caller   Optional name of the calling function.
     * @throws std::invalid_argument if the element count differs.
     */
    template<typename T>
    static void check_size(
        const T &           values,
        size_t              expected,
        const std::string & varname = "values",
        const std::string & caller  = ""
    )
    {
        static_assert(
            epiassert_detail::is_iterable<T>::value,
            "check_size requires an iterable type."
        );

        const auto actual_size = static_cast<size_t>(
            std::distance(std::begin(values), std::end(values))
        );

        if (actual_size != expected)
        {
            throw std::invalid_argument(
                "'" + varname + "' must have " +
                std::to_string(expected) + " elements, but got " +
                std::to_string(actual_size) +
                fmt_location(caller)
            );
        }
    }

    // -----------------------------------------------------------------
    //  check – custom predicate validation
    // -----------------------------------------------------------------

    /**
     * @brief Validates @p value with a user-supplied predicate.
     *
     * @param value   Value passed to @p pred.
     * @param pred    Callable returning a boolean-convertible result.
     * @param message Error message used when @p pred returns false.
     * @param caller  Optional name of the calling function.
     * @throws std::invalid_argument if `pred(value)` is false.
     */
    template<typename T, typename Predicate>
    static void check(
        const T &           value,
        Predicate           pred,
        const std::string & message,
        const std::string & caller = ""
    )
    {
        if (!pred(value))
        {
            throw std::invalid_argument(
                message + fmt_location(caller)
            );
        }
    }
};


#endif