25 bool local_expert_masking,
26 bool skip_experts_with_zero_token =
true)
29 const index_t num_token = tokens;
32 std::vector<std::vector<IndexType>> expert_tokens(
37 std::vector<IndexType>(unit_size, num_token));
39 std::vector<std::vector<WeightType>> expert_token_weights(
40 experts, std::vector<WeightType>(unit_size, 0));
42 std::vector<IndexType> expert_slices(experts, 1);
44 std::vector<IndexType> expert_slice_idxs(experts, 0);
47 for(
index_t t = 0; t < num_token; t++)
49 for(
index_t k = 0; k < topk; k++)
51 IndexType e = topk_ids(t, k);
52 WeightType w = weights(t, k);
53 index_t idx = expert_slice_idxs[e];
54 if(idx > expert_slices[e] * unit_size - 1)
57 index_t new_size = expert_slices[e] * unit_size;
58 expert_tokens[e].resize(new_size);
59 expert_token_weights[e].resize(new_size);
60 for(
index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
62#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
65 expert_tokens[e][i] = num_token;
67 expert_token_weights[e][i] = 0;
70#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
73 expert_tokens[e][idx] = t;
75 expert_token_weights[e][idx] = w;
76 expert_slice_idxs[e]++;
80 IndexType* out_tokens = p_sorted_token_ids.
data();
81 WeightType* out_weights = sorted_weight.
data();
82 IndexType* out_expert_id = sorted_expert_ids.
data();
83 int curr_expert_id = 0;
84 for(
index_t e = 0; e < experts; e++)
86 if(local_expert_masking)
88 if(local_expert_mask(e) == 0)
91 if(skip_experts_with_zero_token)
93 if(expert_slice_idxs[e] == 0)
100 memcpy(out_tokens, expert_tokens[e].data(),
sizeof(
index_t) * expert_slices[e] * unit_size);
101 out_tokens += expert_slices[e] * unit_size;
103 expert_token_weights[e].data(),
104 sizeof(WeightType) * expert_slices[e] * unit_size);
105 out_weights += expert_slices[e] * unit_size;
107 for(
index_t s = 0; s < expert_slices[e]; s++)
109 out_expert_id[s] = curr_expert_id;
112 out_expert_id += expert_slices[e];
115 unit_cnt *= unit_size;
CK_TILE_HOST void reference_moe_sorting(const HostTensor< IndexType > &topk_ids, const HostTensor< WeightType > &weights, const HostTensor< IndexType > &local_expert_mask, HostTensor< IndexType > &p_sorted_token_ids, HostTensor< WeightType > &sorted_weight, HostTensor< IndexType > &sorted_expert_ids, index_t &unit_cnt, const index_t experts, const index_t unit_size, const index_t tokens, bool local_expert_masking, bool skip_experts_with_zero_token=true)
Definition reference_moe_sorting.hpp:15