From e24f94cd586a8651ca3893948ecdf48eb6f92602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=8D=E9=87=91=E5=BC=88?= Date: Mon, 23 Sep 2024 10:16:30 +0800 Subject: [PATCH] basic ut for new fp8 types --- tests/ut/exe_graph/tiling_data_unittest.cc | 88 +++++++++++++++++++ .../ut/register/testcase/register_unittest.cc | 65 ++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/tests/ut/exe_graph/tiling_data_unittest.cc b/tests/ut/exe_graph/tiling_data_unittest.cc index d9ed660d2..b9ec1081a 100644 --- a/tests/ut/exe_graph/tiling_data_unittest.cc +++ b/tests/ut/exe_graph/tiling_data_unittest.cc @@ -7,6 +7,7 @@ * See LICENSE in the root of the software repository for the full text of the License. * ===================================================================================================================*/ +#include #include "exe_graph/runtime/tiling_data.h" #include "common/util/tiling_utils.h" #include "faker/kernel_run_context_faker.h" @@ -39,7 +40,44 @@ FakeKernelContextHolder BuildTestContext() { return holder; } +template +class Fp8 { +public: + Fp8(uint8_t u8) : u8_(u8) {} + bool IsNanOrInf() const; + uint8_t Exp() const { + return (u8_ & kExpMask) >> kExpOffset; + } + uint8_t Mantissa() const { + return u8_ & kMantissaMask; + } + explicit operator float() const { + float sign = (u8_ & kSignMask) != 0 ? -1.0 : 1.0; + bool normal = Exp() != 0; + int8_t exp = Exp() - bias; + uint8_t mantissa = Mantissa(); + float f32 = 0.0; + for (uint8_t off = 0; off < kExpOffset; off++) { + if ((mantissa & (1U << off)) == 0) { + continue; + } + f32 += std::pow(2, static_cast(off - kExpOffset)); + } + if (normal) { + return sign * std::pow(2, exp) * (1 + f32); + } else { + return sign * std::pow(2, -bias + 1) * f32; + } + } +private: + static constexpr uint8_t kSignMask = 0x80U; + static constexpr uint8_t kMantissaMask = 0x7FU >> width; + static constexpr uint8_t kExpMask = 0x7FU & ~kMantissaMask; + static constexpr uint8_t kExpOffset = 7 - width; + uint8_t u8_; +}; } // namespace + TEST_F(TilingDataUT, AppendSameTypesOk) { auto data = TilingData::CreateCap(2048); auto tiling_data = reinterpret_cast(data.get()); @@ -212,6 +250,56 @@ TEST_F(TilingDataUT, AppendAttrFloat32ToBfloat16Ok) { EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint16_t)); } +using Fp8E5m2 = Fp8<5, 15>; +template<> +bool Fp8E5m2::IsNanOrInf() const { + return Exp() == 0b11111; +} + +using Fp8E4m3fn = Fp8<4, 7>; +template<> +bool Fp8E4m3fn::IsNanOrInf() const { + return Exp() == 0b1111 && Mantissa() == 0b111; +} + +TEST_F(TilingDataUT, FloatToFp8_Ok) { + for (uint8_t u8 = 0; u8 < 255; u8++) { + Fp8E5m2 e5m2(u8); + Fp8E4m3fn e4m3fn(u8); + + if (!e5m2.IsNanOrInf()) { + EXPECT_EQ(u8, optiling::FloatToF8E5m2(static_cast(e5m2))); + } + if (!e4m3fn.IsNanOrInf()) { + EXPECT_EQ(u8, optiling::FloatToF8E4m3fn(static_cast(e4m3fn))); + } + } +} + +TEST_F(TilingDataUT, AppendAttrFloat32ToFloat8E5m2Ok) { + auto data = TilingData::CreateCap(20); + auto tiling_data = reinterpret_cast(data.get()); + auto holder = BuildTestContext(); + auto context = holder.GetContext(); + EXPECT_NE(context, nullptr); + EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 3, AttrDataType::kFloat32, AttrDataType::KFloat8E5m2), + ge::GRAPH_SUCCESS); + EXPECT_EQ(*reinterpret_cast(tiling_data->GetData()), optiling::FloatToF8E5m2(10.101)); + EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint8_t)); +} + +TEST_F(TilingDataUT, AppendAttrFloat32ToFloat8E4m3fnOk) { + auto data = TilingData::CreateCap(20); + auto tiling_data = reinterpret_cast(data.get()); + auto holder = BuildTestContext(); + auto context = holder.GetContext(); + EXPECT_NE(context, nullptr); + EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 3, AttrDataType::kFloat32, AttrDataType::KFloat8E4m3fn), + ge::GRAPH_SUCCESS); + EXPECT_EQ(*reinterpret_cast(tiling_data->GetData()), optiling::FloatToF8E4m3fn(10.101)); + EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint8_t)); +} + TEST_F(TilingDataUT, AppendAttrFloat32ToInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); diff --git a/tests/ut/register/testcase/register_unittest.cc b/tests/ut/register/testcase/register_unittest.cc index c54a7f5be..8ac4f27cc 100644 --- a/tests/ut/register/testcase/register_unittest.cc +++ b/tests/ut/register/testcase/register_unittest.cc @@ -278,6 +278,11 @@ extern "C" const char *DoOpTilingForCompile(const char *optype, const char *comp size_t run_info_len, uint64_t *elapse, const char *extra_info); + +extern "C" int OpTilingForCompile(const char *optype, const char *compile_info, const char *compile_info_hash, + const char *inputs, const char *outputs, const char *attrs, char *run_info_json, + size_t run_info_len, uint64_t *elapse, const char *extra_info); + bool op_tiling_stub_v2(const Operator &op, const utils::OpCompileInfo &compile_info, utils::OpRunInfo &run_info) { return true; } @@ -2341,4 +2346,64 @@ TEST_F(UtestRegister, new_optiling_py_interface_ok_with_bf16_data) { default_space_registry->op_impl_registries_.clear(); } +UINT32 OpTilingStubFp8(gert::TilingContext *kernel_context) { + auto tensor0 = kernel_context->GetInputTensor(0); + auto tensor1 = kernel_context->GetInputTensor(1); + std::vector real_data = {1.1, 2.1, 3.1, 4.1}; + for (size_t i = 0UL; i < 4UL; ++i) { + EXPECT_EQ((tensor0->GetData())[i], optiling::FloatToF8E5m2(real_data[i])); + } + for (size_t i = 0UL; i < 4UL; ++i) { + EXPECT_EQ((tensor1->GetData())[i], optiling::FloatToF8E4m3fn(real_data[i])); + } + return ge::GRAPH_SUCCESS; +} +TEST_F(UtestRegister, new_optiling_py_interface_ok_with_fp8_data) { + const nlohmann::json input = R"([ + { + "dtype": "float8_e5m2", + "const_value": [1.1, 2.1, 3.1, 4.1], + "shape": [4, 4, 4, 4], + "ori_shape": [4, 4, 4, 4], + "format": "ND" + }, + { + "dtype": "float8_e4m3fn", + "const_value": [1.1, 2.1, 3.1, 4.1], + "shape": [4, 4, 4, 4], + "ori_shape": [4, 4, 4, 4], + "format": "ND" + }])"_json; + const nlohmann::json output = R"([ + { + "dtype": "int8", + "shape": [4, 4, 4, 4], + "ori_shape": [4, 4, 4, 4], + "format": "ND", + "ori_format": "ND" + }])"_json; + const char *op_type = "DummyFp8Op"; + std::string runinfo(130U, 'a'); + const nlohmann::json attrs = R"([{"name": "op_para_size", "dtype": "int", "value": 50}])"_json; + const size_t max_tiling_size = 50U; + + auto space_registry = std::make_shared(); + auto registry_holder = std::make_shared(); + gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; + op_impl_func.tiling = OpTilingStubFp8; + op_impl_func.tiling_parse = OpTilingParseStubV5; + op_impl_func.compile_info_creator = CreateCompileInfo; + op_impl_func.compile_info_deleter = DeleteCompileInfo; + op_impl_func.max_tiling_data_size = max_tiling_size; + registry_holder->AddTypesToImpl(op_type, op_impl_func); + space_registry->AddRegistry(registry_holder); + gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); + + EXPECT_EQ(OpTilingForCompile(op_type, "", "", input.dump().c_str(), output.dump().c_str(), + attrs.dump().c_str(), const_cast(runinfo.c_str()), runinfo.length(), nullptr, nullptr), + 1); + auto default_space_registry = gert::DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(); + default_space_registry->merged_types_to_impl_.clear(); + default_space_registry->op_impl_registries_.clear(); +} -- Gitee