diff --git a/src/transformer/fused_infer_attention_score/ophost/fused_infer_attention_score_tiling.h b/src/transformer/fused_infer_attention_score/ophost/fused_infer_attention_score_tiling.h
index 9d8d5b5d3b43a874bde6d35a00a1e681451f83d8..fe492629a2e7a2053926c88a8b6cdce5100b0e09 100644
--- a/src/transformer/fused_infer_attention_score/ophost/fused_infer_attention_score_tiling.h
+++ b/src/transformer/fused_infer_attention_score/ophost/fused_infer_attention_score_tiling.h
@@ -90,6 +90,7 @@ REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000000000116, PromptF
 REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000000111112, PromptFlashAttentionTilingData)
 REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000000121112, PromptFlashAttentionTilingData)
 REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000000011112, PromptFlashAttentionTilingData)
+REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000002011112, PromptFlashAttentionTilingData)
 REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000000021112, PromptFlashAttentionTilingData)
 // PA tilingkey
 REGISTER_TILING_DATA_CLASS(FusedInferAttentionScore_1000000000010101612, PromptFlashAttentionTilingData)
diff --git a/src/transformer/prompt_flash_attention/ophost/prompt_flash_attention_tiling.cpp b/src/transformer/prompt_flash_attention/ophost/prompt_flash_attention_tiling.cpp
index a92ef7c361d6878b3b0450b18881d1c55100f5d1..01a4b241c23d8c42f767f93dae096a1bf264f7fa 100644
--- a/src/transformer/prompt_flash_attention/ophost/prompt_flash_attention_tiling.cpp
+++ b/src/transformer/prompt_flash_attention/ophost/prompt_flash_attention_tiling.cpp
@@ -643,9 +643,11 @@ bool PromptFlashAttentionTiling::EnableSplitSeqOneN(PromptFlashAttentionTilingDa
     bool enableLeftPadding = ((contextKeyParams.queryPaddingSize != nullptr) || (contextKeyParams.kvPaddingSize != nullptr));
     bool enableRingAttention = (contextKeyParams.isSoftMaxLseEnable == true);
 
+    bool flag1 = (inputType == ge::DT_FLOAT16) && (contextKeyParams.kDataType == ge::DT_FLOAT16) && (outputType == ge::DT_FLOAT16);
+    bool flag2 = (inputType == ge::DT_BF16) && (contextKeyParams.kDataType == ge::DT_BF16) && (outputType == ge::DT_BF16);
+
     GetPreNextTokensLeftUp(tilingData, actualSeqLength, actualSeqLengthKV, preTokensLeftUp, nextTokensLeftUp);
-    bool baseCond = (hDivN == MATMUL_NORM_MIN_HEADSIZE) && (inputType == ge::DT_FLOAT16) && (contextKeyParams.kDataType == ge::DT_FLOAT16) &&
-                    (outputType == ge::DT_FLOAT16) && (usePseShift == 0) && (inputLayout == InputLayout::BNSD);
+    bool baseCond = (hDivN == MATMUL_NORM_MIN_HEADSIZE) && (flag1 || flag2) && (usePseShift == 0) && (inputLayout == InputLayout::BNSD);
     bool seqMode0 = (baseParams->get_sparseMode() == SPARSE_MODE_BAND) && (contextKeyParams.maskDataType == ge::DT_BOOL) && (nextTokensLeftUp == 0) && actualSeqLength >= seq16K && (b * n >= 12);
     bool seqMode1 = (baseParams->get_sparseMode() == SPARSE_MODE_NO_MASK && contextKeyParams.attentionMask == nullptr) && actualSeqLength >= seq8K;
     if (baseCond && !isKVHasPrefix && !enableLeftPadding && !enableRingAttention && (seqMode0 || seqMode1) &&
@@ -3250,6 +3252,7 @@ ge::graphStatus PromptFlashAttentionTiling::RunBigKernelTilingWithParams(Context
     size_t* workspaces = contextKeyParams.workspaceSize;
     workspaces[0] = GetPFAWorkSpaceSize(tilingData);
     OPS_LOG_I(contextKeyParams.opName, "The Tiling key is %lu", tilingKey);
+    OPS_LOG_I(contextKeyParams.opName, "########## bf16 l1reuse enabled ##########");
     return ge::GRAPH_SUCCESS;
 }
 
diff --git a/src/transformer/prompt_flash_attention/prompt_flash_attention.cpp b/src/transformer/prompt_flash_attention/prompt_flash_attention.cpp
index 8725ffdc7048fdba7b3459813791871b1ea854bc..816b96363c29abce53408449395a2ed686258529 100644
--- a/src/transformer/prompt_flash_attention/prompt_flash_attention.cpp
+++ b/src/transformer/prompt_flash_attention/prompt_flash_attention.cpp
@@ -352,6 +352,9 @@ extern "C" __global__ __aicore__ void prompt_flash_attention_FIAS(__gm__ uint8_t
     } else if (TILING_KEY_IS(1000000000000011112)) {
         // BNSD layout bf16 cvdiff
         INVOKE_PFA_GENERAL_OP_IMPL(PromptFlashAttentionS1s2Bns1X910, PFAType);
+    } else if (TILING_KEY_IS(1000000000002011112)) {
+        // BNSD layout bf16 cvdiff
+        INVOKE_PFA_GENERAL_OP_IMPL(PromptFlashAttentionS1s2Bns1X910, PFAType);
     } else if (TILING_KEY_IS(1000000000010011112)) {
         // BNSD layout bf16 cvdiff, enable PA
         INVOKE_PFA_GENERAL_OP_IMPL(PromptFlashAttentionS1s2Bns1X910, PFAType);
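
For reference, the core host-side change in EnableSplitSeqOneN can be read in isolation: the fast path previously required fp16 for the query, key, and output dtypes, and the diff widens it to accept an all-fp16 or an all-bf16 pipeline (flag1 || flag2). The following is a minimal, self-contained sketch of that predicate only; DataType, DtypeParams, and MatchesSplitSeqDtype are illustrative stand-ins and not part of the operator's actual GE/CANN types, while the flag1/flag2 structure mirrors the diff.

```cpp
#include <iostream>

// Hypothetical stand-in for the ge::DataType values referenced in the diff.
enum class DataType { DT_FLOAT16, DT_BF16, DT_FLOAT };

// Illustrative container for the three dtypes the check inspects
// (inputType, contextKeyParams.kDataType, outputType in the diff).
struct DtypeParams {
    DataType inputType;
    DataType kDataType;
    DataType outputType;
};

// Before the change: only a uniform fp16 pipeline qualified.
// After the change: a uniform bf16 pipeline also qualifies.
bool MatchesSplitSeqDtype(const DtypeParams& p) {
    bool allFp16 = (p.inputType == DataType::DT_FLOAT16) &&
                   (p.kDataType == DataType::DT_FLOAT16) &&
                   (p.outputType == DataType::DT_FLOAT16);  // flag1 in the diff
    bool allBf16 = (p.inputType == DataType::DT_BF16) &&
                   (p.kDataType == DataType::DT_BF16) &&
                   (p.outputType == DataType::DT_BF16);     // flag2 in the diff
    return allFp16 || allBf16;                              // (flag1 || flag2)
}

int main() {
    DtypeParams bf16{DataType::DT_BF16, DataType::DT_BF16, DataType::DT_BF16};
    DtypeParams mixed{DataType::DT_BF16, DataType::DT_FLOAT16, DataType::DT_BF16};
    // Prints "1 0": uniform bf16 now passes, mixed dtypes still do not.
    std::cout << MatchesSplitSeqDtype(bf16) << " " << MatchesSplitSeqDtype(mixed) << "\n";
    return 0;
}
```

Note that the other conditions gating the path (head size, pse shift, BNSD layout, sparse mode, sequence-length thresholds) are unchanged by the diff; the new tiling key 1000000000002011112 is simply registered on the host side and dispatched to the same PromptFlashAttentionS1s2Bns1X910 kernel as the existing bf16 BNSD key.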