masking_specialization.hpp Source File

masking_specialization.hpp Source File#

Composable Kernel: masking_specialization.hpp Source File
masking_specialization.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck {
7namespace tensor_operation {
8namespace device {
9
15
16#ifndef __HIPCC_RTC__
18{
19 switch(s)
20 {
21 case MaskingSpecialization::MaskDisabled: return "MaskDisabled";
22 case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle";
23 default: return "Unrecognized specialization!";
24 }
25}
26#endif
27
29{
30 __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const
31 {
32 return false;
33 };
34
35 __host__ __device__ constexpr bool
36 IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const
37 {
38 return false;
39 }
40};
41
43{
44 __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; }
45
46 __host__ __device__ constexpr bool
47 IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const
48 {
49 return operator()(m + m_tile - 1, n);
50 }
51};
52
53// to track the points which need to be set to -inf on C0
54// Note: no need to reset M padding value, because they will not be stored out.
55template <typename MaskOutPredicate>
57{
58 __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
59 : NRaw_(NRaw), predicate_(MaskOutPredicate{})
60 {
61 }
62
63 __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
64 {
65 return n >= NRaw_;
66 }
67
68 __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
69 {
70 return predicate_(m, n) || IsNOutOfBound(n);
71 }
72
73 __host__ __device__ constexpr bool
74 IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
75 {
76 return predicate_.IsTileSkippable(m, n, m_tile, n_tile);
77 }
78
79 private:
80 // index_t MRaw_;
81 index_t NRaw_;
82 MaskOutPredicate predicate_;
83};
84
85} // namespace device
86} // namespace tensor_operation
87} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const
Definition masking_specialization.hpp:74
__host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
Definition masking_specialization.hpp:58
__host__ __device__ constexpr bool IsNOutOfBound(index_t n) const
Definition masking_specialization.hpp:63
__host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const
Definition masking_specialization.hpp:68
Definition masking_specialization.hpp:29
__host__ __device__ constexpr bool IsTileSkippable(index_t, index_t, index_t, index_t) const
Definition masking_specialization.hpp:36
__host__ __device__ constexpr bool operator()(index_t, index_t) const
Definition masking_specialization.hpp:30
Definition masking_specialization.hpp:43
__host__ __device__ constexpr bool IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t) const
Definition masking_specialization.hpp:47
__host__ __device__ constexpr bool operator()(index_t m, index_t n) const
Definition masking_specialization.hpp:44