transpose_vectors.hpp Source File

transpose_vectors.hpp Source File#

Composable Kernel: transpose_vectors.hpp Source File
tile/core/utility/transpose_vectors.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck_tile {
13
14// S: scalar type (or it can be non-scalar type)
15// NX: # of vector before transpose
16// NY: # of vector after transpose
17// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
18template <typename S_, index_t NX, index_t NY>
20{
21 static constexpr index_t s_per_x = NY;
22 static constexpr index_t s_per_y = NX;
23
25
28
30 {
31 };
33 {
34 };
36 {
37 };
39 {
40 };
41
42 CK_TILE_DEVICE static constexpr void
44 {
45 static_for<0, NY, 1>{}([&](auto iy) {
46 static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
47 });
48 }
49
50 CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
51 thread_buffer<VY, NY>& vy_tuple,
53 {
54 static_assert(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
55
56 constexpr auto I1 = number<1>{};
57 constexpr auto I2 = number<2>{};
58 using S2 = array<S, 2>;
59 // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
60 static_for<0, NY, 2>{}([&](auto iy) {
61 static_for<0, NX, 2>{}([&](auto ix) {
62 // 2 16bitx2 data from vx_tuple to be transposed
63 const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
64 const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
65
66 // transpose 2x2 16bit
67 // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
68 // -- -- -- -- -- -- -- -- - - - -
69 // index 7 6 5 4 3 2 1 0 33 77 44 88
70 // index is reversed because of little endianness (least significant bits first)
71 const S2 y_s2_0 = bit_cast<S2>(
72 __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
73 bit_cast<uint32_t>(x_s2_1),
74 // (A0.B0.C0.D0.A1.B1.C1.D1)[1, 0, 5, 4] = (C1.D1.C0.D0)
75 0x01'00'05'04));
76 const S2 y_s2_1 = bit_cast<S2>(
77 __builtin_amdgcn_perm(bit_cast<uint32_t>(x_s2_0),
78 bit_cast<uint32_t>(x_s2_1),
79 // (A0.B0.C0.D0.A1.B1.C1.D1)[3, 2, 7, 6] = (A1.B1.A0.B0)
80 0x03'02'07'06));
81
82 // write transposed 2x2 result:
83 // write (C1.D1.C0.D0)
84 vy_tuple(iy).set_as(ix / I2, y_s2_0);
85 // write (A1.B1.A0.B0)
86 vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
87 });
88 });
89 }
90
91 CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
92 thread_buffer<VY, NY>& vy_tuple,
94 {
95 static_assert(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0, "wrong!");
96
97 constexpr auto I1 = number<1>{};
98 constexpr auto I2 = number<2>{};
99 constexpr auto I3 = number<3>{};
100 constexpr auto I4 = number<4>{};
101 using S4 = array<S, 4>;
102 // loop over 4x4 tiles and transpose data from vx_tuple into vy_tuple
103 static_for<0, NY, 4>{}([&](auto iy) {
104 static_for<0, NX, 4>{}([&](auto ix) {
105 // read A0.B0.C0.D0
106 const S4 x_s4_0 = vx_tuple[ix].template get_as<S4>(iy / I4);
107 // read A1.B1.C1.D1
108 const S4 x_s4_1 = vx_tuple[ix + I1].template get_as<S4>(iy / I4);
109 // read A2.B2.C2.D2
110 const S4 x_s4_2 = vx_tuple[ix + I2].template get_as<S4>(iy / I4);
111 // read A3.B3.C3.D3
112 const S4 x_s4_3 = vx_tuple[ix + I3].template get_as<S4>(iy / I4);
113
114 // (A1.B1.C1.D1.A0.B0.C0.D0)[5, 1, 4, 0] = (C1.C0.D1.D0)
115 uint32_t t_s4_0 = __builtin_amdgcn_perm(
116 bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x05'01'04'00);
117 // (A3.B3.C3.D3.A2.B2.C2.D2)[5, 1, 4, 0] = (C3.C2.D3.D2)
118 uint32_t t_s4_1 = __builtin_amdgcn_perm(
119 bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x05'01'04'00);
120 // (C3.C2.D3.D2.C1.C0.D1.D0)[5, 4, 1, 0] = (D3.D2.D1.D0)
121 const S4 y_s4_0 =
122 bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
123 // (C3.C2.D3.D2.C1.C0.D1.D0)[7, 6, 3, 2] = (C3.C2.C1.C0)
124 const S4 y_s4_1 =
125 bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
126 // (A1.B1.C1.D1.A0.B0.C0.D0)[7, 3, 6, 2] = (A1.A0.B1.B0)
127 t_s4_0 = __builtin_amdgcn_perm(
128 bit_cast<uint32_t>(x_s4_1), bit_cast<uint32_t>(x_s4_0), 0x07'03'06'02);
129 // (A3.B3.C3.D3.A2.B2.C2.D2)[7, 3, 6, 2] = (A3.A2.B3.B2)
130 t_s4_1 = __builtin_amdgcn_perm(
131 bit_cast<uint32_t>(x_s4_3), bit_cast<uint32_t>(x_s4_2), 0x07'03'06'02);
132 // (A3.A2.B3.B2.A1.A0.B1.B0)[5, 4, 1, 0] = (B3.B2.B1.B0)
133 const S4 y_s4_2 =
134 bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x05'04'01'00));
135 // (A3.A2.B3.B2.A1.A0.B1.B0)[7, 6, 3, 2] = (A3.A2.A1.A0)
136 const S4 y_s4_3 =
137 bit_cast<S4>(__builtin_amdgcn_perm(t_s4_1, t_s4_0, 0x07'06'03'02));
138
139 // write transposed 4x4 result:
140 // write (D3.D2.D1.D0)
141 vy_tuple(iy).set_as(ix / I4, y_s4_0);
142 // write (C3.C2.C1.C0)
143 vy_tuple(iy + I1).set_as(ix / I4, y_s4_1);
144 // write (B3.B2.B1.B0)
145 vy_tuple(iy + I2).set_as(ix / I4, y_s4_2);
146 // write (A3.A2.A1.A0)
147 vy_tuple(iy + I3).set_as(ix / I4, y_s4_3);
148 });
149 });
150 }
151
152 CK_TILE_DEVICE static constexpr void apply_impl(const thread_buffer<VX, NX>& vx_tuple,
153 thread_buffer<VY, NY>& vy_tuple,
155 {
156 static_assert(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0, "wrong!");
157
158 constexpr auto I1 = number<1>{};
159 constexpr auto I2 = number<2>{};
160 using S2 = array<S, 2>;
161 // loop over 2x2 tiles and transpose data from vx_tuple into vy_tuple
162 static_for<0, NY, 2>{}([&](auto iy) {
163 static_for<0, NX, 2>{}([&](auto ix) {
164 // read A0.B0
165 const S2 x_s2_0 = vx_tuple[ix].template get_as<S2>(iy / I2);
166 // read A1.B1
167 const S2 x_s2_1 = vx_tuple[ix + I1].template get_as<S2>(iy / I2);
168
169 // v_perm_b32: pick 4 bytes from 8 bytes in (input0.input1) using the mask
170 const S2 y_s2_0 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
171 static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
172 static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
173 // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 0, 4] = (00.00.B1.B0)
174 0x0C'0C'00'04)));
175
176 const S2 y_s2_1 = bit_cast<S2>(static_cast<uint16_t>(__builtin_amdgcn_perm(
177 static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_0)),
178 static_cast<uint32_t>(bit_cast<uint16_t>(x_s2_1)),
179 // (XX.XX.A0.B0.XX.XX.A1.B1)[clear, clear, 1, 5] = (00.00.A1.A0)
180 0x0C'0C'01'05)));
181
182 // write transposed 2x2 result:
183 // write (B1.B0)
184 vy_tuple(iy).set_as(ix / I2, y_s2_0);
185 // write (A1.A0)
186 vy_tuple(iy + I1).set_as(ix / I2, y_s2_1);
187 });
188 });
189 }
190
191 CK_TILE_DEVICE static constexpr auto tag_dispatch()
192 {
193 if constexpr(sizeof(S) == 2 && NX % 2 == 0 && NY % 2 == 0)
194 {
195 return bytesize2_2x2_tag{};
196 }
197 else if constexpr(sizeof(S) == 1 && NX % 4 == 0 && NY % 4 == 0)
198 {
199 return bytesize1_4x4_tag{};
200 }
201 else if constexpr(sizeof(S) == 1 && NX % 2 == 0 && NY % 2 == 0)
202 {
203 return bytesize1_2x2_tag{};
204 }
205 else
206 {
207 return generic_tag{};
208 }
209 }
210
212 thread_buffer<VY, NY>& vy_tuple) const
213 {
214 apply_impl(vx_tuple, vy_tuple, tag_dispatch());
215 }
216};
217
218} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
Definition tile/core/utility/transpose_vectors.hpp:39
Definition tile/core/utility/transpose_vectors.hpp:36
Definition tile/core/utility/transpose_vectors.hpp:33
Definition tile/core/utility/transpose_vectors.hpp:30
Definition tile/core/utility/transpose_vectors.hpp:20
static constexpr index_t s_per_y
Definition tile/core/utility/transpose_vectors.hpp:22
static CK_TILE_DEVICE constexpr void apply_impl(const thread_buffer< VX, NX > &vx_tuple, thread_buffer< VY, NY > &vy_tuple, bytesize1_4x4_tag)
Definition tile/core/utility/transpose_vectors.hpp:91
static CK_TILE_DEVICE constexpr void apply_impl(const thread_buffer< VX, NX > &vx_tuple, thread_buffer< VY, NY > &vy_tuple, generic_tag)
Definition tile/core/utility/transpose_vectors.hpp:43
static CK_TILE_DEVICE constexpr auto tag_dispatch()
Definition tile/core/utility/transpose_vectors.hpp:191
remove_cvref_t< S_ > S
Definition tile/core/utility/transpose_vectors.hpp:24
static CK_TILE_DEVICE constexpr void apply_impl(const thread_buffer< VX, NX > &vx_tuple, thread_buffer< VY, NY > &vy_tuple, bytesize2_2x2_tag)
Definition tile/core/utility/transpose_vectors.hpp:50
array< S, s_per_x > VX
Definition tile/core/utility/transpose_vectors.hpp:26
static constexpr index_t s_per_x
Definition tile/core/utility/transpose_vectors.hpp:21
array< S, s_per_y > VY
Definition tile/core/utility/transpose_vectors.hpp:27
static CK_TILE_DEVICE constexpr void apply_impl(const thread_buffer< VX, NX > &vx_tuple, thread_buffer< VY, NY > &vy_tuple, bytesize1_2x2_tag)
Definition tile/core/utility/transpose_vectors.hpp:152
CK_TILE_DEVICE void operator()(const thread_buffer< VX, NX > &vx_tuple, thread_buffer< VY, NY > &vy_tuple) const
Definition tile/core/utility/transpose_vectors.hpp:211