Revolutionizing Mixture of Experts Performance: 10x Speedup on AMD Instinct GPUs with Optimized Align & Sort
Apr 10, 2025

Author: Lei Wang
Efficient MoE Align & Sort design in SGLang Fused MoE
Figure: human-brain cortex, from an Oxford University research paper (archived from the internet)
The first truly workable version in CUDA was SwitchTransformer[1], later improved by Mistral[2] by upcycling dense models. The MoE Align & Sort step consists of two parts (sketched after the list):
- alignment: conduct the traditional alignment-based offset computation for the radix sorting algorithm within a single block;
- placement: place tokens according to the computed offsets, in multiple blocks.
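For concreteness, here is a minimal host-side NumPy sketch of these two parts. The function name, the argument layout, and the padding-sentinel convention are illustrative assumptions, not the actual kernel API:

```python
import numpy as np

def moe_align_block_size_ref(topk_ids: np.ndarray, num_experts: int, block_size: int):
    """Host-side reference of the two parts above (illustrative, not the kernel API)."""
    flat = topk_ids.reshape(-1)

    # alignment: count tokens per expert, pad each expert segment to a
    # multiple of block_size, and turn the counts into exclusive offsets
    counts = np.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    offsets = np.concatenate(([0], np.cumsum(padded)))

    # placement: scatter every token slot into its expert's padded segment
    sorted_ids = np.full(offsets[-1], flat.size, dtype=np.int64)  # flat.size marks padding
    cursor = offsets[:-1].copy()
    for slot, e in enumerate(flat):
        sorted_ids[cursor[e]] = slot
        cursor[e] += 1

    # expert id owning each block_size-sized chunk of sorted_ids
    expert_ids = np.repeat(np.arange(num_experts), padded // block_size)
    return sorted_ids, expert_ids, offsets[-1]

# 6 token slots, top-1 routing over 4 experts, block_size = 4
print(moe_align_block_size_ref(np.array([[1], [0], [1], [3], [1], [0]]), 4, 4))
```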
We also analyzed the roofline model of MoE Align & Sort; it shows that the kernel's performance drops in the memory-bound region.
In the section AMD Compute Profile, we give details of the profiling data and an analysis of our algorithm design on the ROCm platform.
The fundamental rule is that synchronization among XCDs (accelerated computing dies) is costly; it is better to make full use of the XCDs and their L2 cache locality affinity to increase performance.
We should avoid expensive synchronization by either using the lowest-speed compute die (XCD7 for MI300X, XCD5 for MI300A) when the grid size is smaller than the number of XCDs per chip (8 for MI300X, 6 for MI300A), or adapting the grid size to a multiple of the number of XCDs per chip when it exceeds that threshold.
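A sketch of that grid-size rule follows; `pick_grid_size` is a hypothetical host-side helper, not code from the kernel, and which XCD a small grid actually lands on is ultimately up to the runtime:

```python
def pick_grid_size(requested_blocks: int, num_xcds: int = 8) -> int:
    """Grid-size policy sketch for a multi-XCD chip (8 XCDs on MI300X, 6 on MI300A).

    Small grids stay small so all blocks can land on one die and avoid
    cross-die synchronization; larger grids are rounded up to a multiple of
    the XCD count so work spreads evenly and per-die L2 locality is kept."""
    if requested_blocks <= num_xcds:
        return requested_blocks
    return ((requested_blocks + num_xcds - 1) // num_xcds) * num_xcds

assert pick_grid_size(5) == 5    # below the threshold: leave the grid as-is
assert pick_grid_size(66) == 72  # above the threshold: round up to a multiple of 8
```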
Our implementation uses 66 active CUs in multi-block execution spanning two dies, so die-to-die exchange is inevitable in the block-wise reduction. We will submit a further V4 optimization to SGLang later this quarter.
Details are further discussed in the profiling section.
Review of Fused MoE in SGLang
Before the kernel launch, the MoE Align & Sort algorithm is applied. The MoE Align & Sort Triton kernel is split into four phases, where direct accesses to DRAM without shared memory are employed, in contrast to the vectorized Triton version.
The CUDA implementation is instead split into two phases, and only the second phase is accelerated across multiple blocks.
MoE Align & Sort CUDA Algorithms in Other Open-Source Platforms
Hence fusion is largely (by cost) limited to top-k gating softmax and biased top-k gating softmax, which were later incorporated into SGLang.
Megatron
Before the publication of this article, Megatron, for FP16/BF16, largely used the FasterTransformer approach but added the gradient operations of permute/unpermute to facilitate training workloads.
That means MoE is also not efficiently fused in Megatron.
vLLM
SGLang uses many vLLM kernels, but vLLM's Fused MoE was initially contributed by the SGLang team. Hence they deploy the same approach.
CK
The first version of AMD-friendly fused MoE was proposed in CK#1634 on Nov 26, 2024. Later, MoE Align & Sort was added in CK#1771 and CK#1840.
The high-level idea is to fuse MoE sorting with Group GEMM. MoE Align & Sort in CK largely employs the SGLang team's approach, except for the CK pipeliner and partitioner.
CK fused MoE High Level Idea[9]
Fusion of per_group_token_quant (for online FP8 quantization), MoE sorting, and Group GEMM can be immediately resolved by incorporating the radix sort computing logic into the Group GEMM pipeliner: count occurrences to compute offsets, followed by parallel placement.
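For readers unfamiliar with the quantization step, the following is a rough NumPy sketch of what a per-group token quantization pass looks like; the function name, the default group size, and the round-and-clip step are simplifying assumptions, and CK's actual fused implementation performs a proper FP8 cast inside the pipeline:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # dynamic range of the e4m3 format

def per_group_token_quant_ref(x: np.ndarray, group_size: int = 128):
    """Per-group token quantization sketch: each row of activations is split
    into groups of `group_size` values that share one scale derived from the
    group's absolute maximum. Round-and-clip stands in for the real FP8 cast."""
    num_tokens, hidden = x.shape
    assert hidden % group_size == 0
    groups = x.reshape(num_tokens, hidden // group_size, group_size)
    scales = np.maximum(np.abs(groups).max(axis=-1, keepdims=True) / FP8_E4M3_MAX, 1e-12)
    quantized = np.clip(np.round(groups / scales), -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return quantized.reshape(num_tokens, hidden), scales.squeeze(-1)

x = np.random.default_rng(0).standard_normal((4, 256)).astype(np.float32)
q, s = per_group_token_quant_ref(x)
print(q.shape, s.shape)  # (4, 256) (4, 2)
```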
One of the most critical problems is how the two kinds of workloads (radix sorting and Group GEMM) are balanced.
In AMD data center chips, the Group GEMM fragments are more likely to be evenly distributed to all the available blocks in an XCD, while the data exchange among blocks in different CUs goes through the lower-speed L2 cache and the L2 cache fabric if multiple XCDs are involved.
Writing CK kernels requires writing the host-side CK solution launcher:
the device entry of the kernel, the tile partitioner, and the stages pipeliner.
The AMD CK partitioner and stages pipeliner for fused MoE are also very interesting in terms of how they map to the final assembly, yet that is out of the scope of this article.
But just remember that its MoE Align & Sort is part of the producer code.
So MoE Align & Sort in the AMD CK solution almost matches the SGLang main implementation, except for the partitioner and pipeliner.
Note that the implementation does not always promise the best performance on the AMD platform (see the asm MoE in AITER).
The AMD CDNA3 architecture does not support Graphcore-like on-chip shuffling magic (we abstracted and generalized on-chip shuffling as the Remapping Op of PopART[12] & PopRT in 2023), which is now supported on NVIDIA H100/H200/B200 through highly efficient on-chip SM-to-SM communication.
As a result, adapting the data layout cheaply among blocks to its best will be a very interesting topic for the AMD open-source solution.
Hence, in philosophy, tiling-based fusion code for these two different workloads may not always exceed the non-fused version. Details of this research will be presented in our V4 release.
AITER
AI Tensor Engine For ROCm[10]
AITER was introduced early this year to incorporate the LLM kernels used in different projects. It supports fused MoE via ck moe, an asm version of MoE via hipModule, and Triton fused MoE.
Hence it is only partially open source, given the opaque assembly and the development schedule for MI300X developers.
The alleged 3x acceleration[10] of fused MoE in AITER was verified by Bruce Xu[13] and essentially comes from the acceleration observed in group GEMMs of different shapes: GEMMs where each expert's FFN weights multiply a block of hidden states of tokens.
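A small NumPy sketch of that grouped GEMM pattern, with top-1 routing for brevity; shapes and names are illustrative only, not AITER's API:

```python
import numpy as np

def grouped_gemm_ref(x: np.ndarray, topk_ids: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Reference of the per-expert group GEMM pattern: tokens routed to the
    same expert are gathered into one contiguous block and multiplied by that
    expert's FFN weight in a single GEMM.

    x: [num_tokens, hidden]; topk_ids: [num_tokens] (top-1 for simplicity);
    w: [num_experts, hidden, ffn_dim]."""
    out = np.zeros((x.shape[0], w.shape[2]), dtype=x.dtype)
    for e in range(w.shape[0]):
        rows = np.nonzero(topk_ids == e)[0]
        if rows.size:                      # one GEMM per expert's token block
            out[rows] = x[rows] @ w[e]
    return out

# Example: 8 tokens, hidden=16, ffn=32, 4 experts, top-1 routing
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16)).astype(np.float32)
ids = rng.integers(0, 4, size=8)
w = rng.standard_normal((4, 16, 32)).astype(np.float32)
print(grouped_gemm_ref(x, ids, w).shape)  # (8, 32)
```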
The proof is that the asm GEMM generates almost 3x improvements in PR#199, with AMD-specific intrinsics such as:
- __builtin_nontemporal_load
- __builtin_amdgcn_ds_swizzle
- __builtin_amdgcn_ds_permute / __builtin_amdgcn_ds_bpermute
- __builtin_amdgcn_mov_dpp
Cutlass v3.8
Fused MoE is not publicly supported in NVIDIA Cutlass 3.8.0 at the time of writing this article.
Hence no MoE Align & Sort is available in this repo.
TRT-LLM
Before v0.16.0, TRT-LLM basically followed the FasterTransformer approach. Since v0.17.0, the MoE part has been disclosed.
Make an AMD-Friendly CUDA Implementation with 3x ~ 7x Acceleration
The algorithm employs a multi-block execution scheme and consists of 3 different sections (D-C-P), sketched after the list:
- Distributed concurrent counting
- Computing cumsum
  - parallel unaligned local cumsum
  - reduce unaligned cumsum
  - align global cumsum
  - store global cumsum
- Parallel placement
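Below is a host-side NumPy simulation of the D-C-P structure, walking the "thread blocks" sequentially; the buffer layout and names are illustrative assumptions, not the kernel's actual symbols:

```python
import numpy as np

def moe_align_sort_dcp_ref(flat_expert_ids: np.ndarray, num_experts: int,
                           block_size: int, num_blocks: int = 4):
    """Host-side simulation of the D-C-P structure (illustrative only)."""
    chunks = np.array_split(flat_expert_ids, num_blocks)

    # D -- distributed counting: each simulated block histograms its own slice
    # into a [num_blocks, num_experts] buffer (a pre-allocated HBM buffer in the kernel)
    counts = np.stack([np.bincount(c, minlength=num_experts) for c in chunks])

    # C -- cumsum: reduce the per-block counts, pad each expert segment to a
    # multiple of block_size, and build/store the aligned global cumsum
    totals = counts.sum(axis=0)
    padded = ((totals + block_size - 1) // block_size) * block_size
    global_offsets = np.concatenate(([0], np.cumsum(padded)))

    # per-(block, expert) write cursor: the expert's global offset plus the
    # number of same-expert tokens owned by earlier blocks
    prior = np.vstack([np.zeros(num_experts, dtype=np.int64),
                       np.cumsum(counts, axis=0)[:-1]])
    cursors = global_offsets[:-1] + prior

    # P -- parallel placement: every simulated block scatters its own tokens
    sorted_ids = np.full(global_offsets[-1], flat_expert_ids.size, dtype=np.int64)
    base = 0
    for b, chunk in enumerate(chunks):
        cur = cursors[b].copy()
        for i, e in enumerate(chunk):
            sorted_ids[cur[e]] = base + i
            cur[e] += 1
        base += len(chunk)
    return sorted_ids, global_offsets

flat = np.random.default_rng(0).integers(0, 8, size=64)
print(moe_align_sort_dcp_ref(flat, num_experts=8, block_size=16)[1])
```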
Parallel unaligned local cumsum
Our proposed parallel local unaligned cumsum
The algorithm was first proposed and implemented by us in PR#2970.
We load-balanced the cumsum execution in each block across kElementsPerThr (16) threads, where roughly kElementsPerThr + kElementsPerThr + threadIdx.x add operations need to be processed in each thread.
Hence the wavefront completes faster than in the single-thread version in the current repo, and we hereby observed a 30% improvement in this version of the implementation.
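A compact sketch of that chunked cumsum, with each Python iteration standing in for one thread's chunk of kElementsPerThr counters; the two-pass structure follows the description above, everything else is illustrative:

```python
import numpy as np

def chunked_cumsum_ref(values: np.ndarray, k_elements_per_thr: int = 16) -> np.ndarray:
    """Load-balanced cumsum sketch: split the counters into chunks of
    k_elements_per_thr, let each (simulated) thread produce a local running
    sum of its chunk, then shift every chunk by the totals of earlier chunks."""
    chunks = [values[i:i + k_elements_per_thr]
              for i in range(0, len(values), k_elements_per_thr)]
    local = [np.cumsum(c) for c in chunks]                 # pass 1: per-thread local cumsum
    chunk_offsets = np.concatenate(([0], np.cumsum([c[-1] for c in local])))[:-1]
    return np.concatenate([l + off for l, off in zip(local, chunk_offsets)])

# sanity check against a single-thread cumsum over 40 counters
v = np.arange(1, 41)
assert np.array_equal(chunked_cumsum_ref(v), np.cumsum(v))
```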
Reduce unaligned cumsum
Once we get the local unaligned cumsum in each block, we proceed to a block-wise reduction among the cumsums stored in the pre-allocated HBM buffer.
We chose FRAG_SIZE_M (16) x FRAG_SIZE_N (16) x FRAGS_PER_BLOCK (4) SRAM fragments for the block-wise reduction, and FRAGS_PER_BLOCK is tunable:
Block-Wise Reduction
On the AMD platform, the calculation is performed on a 1-warp-to-load / 1-warp-to-compute basis, while on the NVIDIA platform it is 2 warps to load and 1 warp to compute.
The design takes full advantage of AMD's 64 SIMD lanes in the CDNA3 architecture, and the number of blocks is always a multiple of the number of XCDs on this multi-die chip.
FRAGS_PER_BLOCK was set to 4 to facilitate re-use of SMEM in multiple rounds.
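A host-side stand-in for that fragment-based reduction is sketched below; the constants mirror the ones named above, while the actual kernel stages fragments through LDS/SMEM, which a NumPy loop cannot show:

```python
import numpy as np

FRAG_SIZE_M, FRAG_SIZE_N = 16, 16
FRAGS_PER_BLOCK = 4  # governs SMEM re-use per round in the kernel; no host-side analogue

def blockwise_reduce_ref(counts_hbm: np.ndarray) -> np.ndarray:
    """Fragment-based block-wise reduction sketch: the per-block expert
    counters (shape [num_blocks, num_experts], the HBM buffer) are walked in
    FRAG_SIZE_M x FRAG_SIZE_N tiles and reduced along the block dimension."""
    num_blocks, num_experts = counts_hbm.shape
    totals = np.zeros(num_experts, dtype=counts_hbm.dtype)
    for col in range(0, num_experts, FRAG_SIZE_N):        # walk expert columns
        for row in range(0, num_blocks, FRAG_SIZE_M):     # walk producer-block rows
            frag = counts_hbm[row:row + FRAG_SIZE_M, col:col + FRAG_SIZE_N]
            totals[col:col + frag.shape[1]] += frag.sum(axis=0)  # reduce one fragment
    return totals

# 64 producer blocks x 256 experts, a DeepSeek-V3-sized expert count
counts = np.random.default_rng(0).integers(0, 32, size=(64, 256))
assert np.array_equal(blockwise_reduce_ref(counts), counts.sum(axis=0))
```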
Align global cumsum & store global cumsum
We improved the vectorization code and take care of loop tails when the input data size is not aligned with the kElementsPerAccess constant.
The benchmarks show the coalescing rate is improved but still limited to 30%. We will work on it in the V4 release.
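The loop-tail handling reads roughly like the sketch below; the vector width used here is an arbitrary illustrative value standing in for kElementsPerAccess, not the kernel's actual constant:

```python
import numpy as np

K_ELEMENTS_PER_ACCESS = 4  # illustrative vector width

def copy_with_tail_ref(src: np.ndarray, dst: np.ndarray) -> None:
    """Vectorized main loop plus scalar tail: the bulk of the data is moved in
    chunks of K_ELEMENTS_PER_ACCESS, and the remainder that does not fill a
    full vector is handled element by element."""
    n = src.size
    main = (n // K_ELEMENTS_PER_ACCESS) * K_ELEMENTS_PER_ACCESS
    for i in range(0, main, K_ELEMENTS_PER_ACCESS):   # "vectorized" body
        dst[i:i + K_ELEMENTS_PER_ACCESS] = src[i:i + K_ELEMENTS_PER_ACCESS]
    for i in range(main, n):                          # scalar loop tail
        dst[i] = src[i]

src = np.arange(10); dst = np.empty_like(src)
copy_with_tail_ref(src, dst)
assert np.array_equal(src, dst)
```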
Writing AMD friendly CUDA
Writing a PyTorch extension enables automatic translation of the CUDA kernel into a HIP kernel with the ROCm SDK.
However, there are cases where the HIP kernel works differently from the CUDA kernel:
- Warp size is an architecture-dependent global variable, defined in the ROCm SDK as warpSize; on the CDNA3 arch, warpSize is 64.
- Device function signatures may not perfectly align with CUDA and need conditional compilation to support these symbols.
- Be aware of L2 cache optimization on multi-die chip architectures.
Benchmarks
We conducted extensive tests without CUDA graph capture for large workloads of DeepSeek V3 models. Hence the number of experts was set to 256. The algorithm currently does not support running under CUDA graph capture, and we will resolve this issue in the V4 release.
Due to the virtualization of GPU machines and the number of CPUs allocated to the test node, performance may vary from time to time compared to bare-metal tests.
Hence, we use the Triton implementation as the baseline to demonstrate the acceleration multiple and the efficiency of our proposed MoE Align & Sort algorithm.
Each test was verified before benchmarking. During the benchmarks, we observed that Triton on the AMD platform runs significantly longer than on NVIDIA at the time we tested. We hence recommend further optimization of the Triton MLIR lowering process to make it as efficient as NVIDIA Triton's.
For AMD Triton, we observed that MI300X is roughly 1.5x faster than MI100, hence the improvement multiple on MI300X is not as significant as on MI100. Moreover, even though MI300X is generally believed to be faster than MI100, in our tests the algorithm performs better on MI100 than on MI300X.
This is partially attributed to the fact that, for a memory-bound op, the communication among the dies of a multi-die chip lowers the execution speed.
On both platforms we observed significant improvements after applying our proposed algorithm, whereas the existing CUDA implementation costs almost the same time as Triton.
AMD system preparation
In order to make the best use of an AMD heterogeneous system, it is recommended to do some checks:
- On both NVIDIA Grace CPU and AMD EPYC 9004 systems, it is generally recommended to disable NUMA auto-balancing when working with GPUs; however, there are cases where it is not.
- When virtualization is enabled, IOMMU pass-through mode is recommended to eliminate DMA translation and hence bring performance improvements.
Benchmark on MI100
git clone https://github.com/yiakwy-xpu-ml-framework-team/AMD-sglang-benchmark-fork.git -b optimize_moe_align_v3 && cd sgl-kernel && python setup_rocm.py install
Feasibility across different combinations of numbers of input tokens and experts can be verified:
cd ../benchmark/kernels/fused_moe_triton && python benchmark_deepseekv3_moe_align_blocks.py --verify
| num_tokens | experts | SGLang | Triton (AMD) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 79.36 | 426.71 | MI100 |
| 16384 | 256 | 86.4 | 681.12 | MI100 |
| 16384 x 128 | 256 | 3047.68 | 62442.85 | MI100 |
| 32768 x 128 | 256 | 7211.37 | 129388.43 | MI100 |
Benchmark on A100
| num_tokens | experts | SGLang | Triton (NV) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 77.44 | 124.92 | A100 |
| 16384 | 256 | | | A100 |
| 16384 x 128 | 256 | 5966.81 | 17396.51 | A100 |
| 32768 x 128 | 256 | 12450.05 | 34711.14 | A100 |
Benchmark on H200
| num_tokens | experts | SGLang | Triton (NV) | GPU |
|---|---|---|---|---|
| 8192 | 256 | | | H200 |
| 16384 | 256 | | | H200 |
| 16384 x 128 | 256 | 4508.42 | 12361.15 | H200 |
| 32768 x 128 | 256 | 9023.48 | 24683.70 | H200 |
Benchmark on MI300X
| num_tokens | experts | SGLang | Triton (AMD) | GPU |
|---|---|---|---|---|
| 8192 | 256 | 88.16 | 281.64 | MI300X |
| 16384 | 256 | 134.02 | 448.88 | MI300X |
| 16384 x 128 | 256 | 6865.64 | 43266.09 | MI300X |
| 32768 x 128 | 256 | 13431.80 | 89788.58 | MI300X |
AMD Compute Profile
Setup
In ROCm 6.3.3, setting up rocprof-compute can be done easily in three steps; details can be found here: https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/tree/main
Profiling Results of Vector L1 Cache
The workload is 16384 tokens x (top 8 out of 256 experts) unless otherwise specified.
| kernel | VGPRs | SGPRs | active CUs | vector L1 cache hit rate | coalescing rate / utilization |
|---|---|---|---|---|---|
| old main moe_align_block_size_kernel (k1) | 20 | 48 | 3 | 0% | 25% / 7% |
| old main count_and_sort_expert_tokens_kernel (k2) | 8 | 32 | 39 | 27% | NaN |
| our moe_align_block_size_kernel | 52 | 48 | 66 | 61% | 36% / 18% |
We maximize the usage of VGPRs but reduce the total usage of SGPRs in our algorithm. The data also indicates zero VGPR/SGPR spills, meaning healthy register usage and no performance penalty for this kernel.
The vector L1 cache (vL1D) is a unit local to each CU; the hit rate records cache-line hit rates when data is requested from the L2 cache to the CU. 30% of L2 cache requests were coalesced by vL1D's texture addresser and a 61% hit rate was achieved, which can also be improved later if necessary.
When data is requested from the CU to vL1D's address processing unit (the texture addresser), there are four states in which the complex decides whether to accept or roll back the data request to the CU via the data processor unit in vL1D:
- Busy: the texture addresser is processing address
- Address Stall: the texture addresser is stalled from sending address to vL1D
- Data Sending Stall: the texture addresser is stalled from sending data to vL1D
- Data Waiting Stall: the texture addresser is stalled waiting to send data to data processor unit in vL1D
Details of this micro-architecture behavior can be found in the AMD CDNA3 ISA and the rocProfiler-compute docs.
Our vL1D addresser stall
We witnessed an 18.61% data-waiting stall rate from the vector L1 cache in this algorithm design.
The data R/W load is much better balanced: from 8 KB read ops and 27 B write ops to a combination of 109 B read ops, 468 B write ops, and 202 B atomic ops.
Profiling Results of L2 Cache
In the CDNA3 architecture, the L2 cache is shared by all CUs and is the main channel for sharing data among thread blocks distributed to different CUs.
With its multiple channels and address-interleaving design, requests to the L2 cache can be largely handled concurrently.
Moreover, with AMD-specific intrinsics such as __builtin_nontemporal_load, we can bypass the L2 cache for data we do not need to visit again.
The details of L2 cache study will be revealed in V4 release.
Conclusion
The new algorithm significantly accelerates MoE Align & Sort on both the CUDA and ROCm platforms, by up to 3x ~ 7x, by maximizing the usage of LDS and vector registers.
We also observed that a memory-bound op may perform worse on a multi-die chip than on a single-die chip; this indicates a new fine-tuning direction when programming device code for multi-die chips such as MI300X/MI300A and B200/B300.
However, details of the algorithm can still be polished to improve the cache hit rate and the main-memory coalescing rate.
Acknowledgement
Special thanks to Prof. Zhang Han (hanzhangqin8@gmail.com) and Dr. Wang YunHong (yunhongwang2000@gmail.com) from the NUS team for the collaboration in MI100/MI250 performance verification, Zev Rekhter (connect@evergrid.ai) for the collaboration in MI300X performance verification, Shuyi Fan (fsygd1996@163.com) for the collaboration in H200 verification, and BBuf (1182563586@qq.com) for discussion and review of the solution in SGLang.
Note that this is an independent work from the SGLang community.
I also express my deep thanks to Bingqing, Peng Sun, and ShawHai, who each spared time to review the article and give suggestions for revision.
References
1. W. Fedus, B. Zoph, and N. Shazeer. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. CoRR, abs/2101.03961, 2021. URL https://arxiv.org/abs/2101.03961.
2. A. Q. Jiang, A. Sablayrolles, A. Mensch, C. Bamford, D. S. Chaplot, D. d. l. Casas, F. Bressand, G. Lengyel, G. Lample, L. Saulnier, et al. Mistral 7B. arXiv preprint arXiv:2310.06825, 2023.
3. DeepSeek-AI. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. CoRR, abs/2405.04434, 2024. URL https://doi.org/10.48550/arXiv.2405.04434.
4. DeepSeek V3: https://arxiv.org/abs/2412.19437; retrieved on 2025-03-18.
5. DeepSeek R1: https://arxiv.org/pdf/2501.12948; retrieved on 2025-03-18.
6. TransformerEngine: https://github.com/NVIDIA/TransformerEngine; retrieved on 2025-03-18.
7. NV Group GEMM: https://github.com/yiakwy-xpu-ml-framework-team/NV_grouped_gemm; retrieved on 2025-03-18.
8. FasterTransformer: https://github.com/NVIDIA/FasterTransformer; retrieved on 2025-03-18.
9. CK Fused MoE V1: https://github.com/ROCm/composable_kernel/pull/1634
10. AMD 3X MOE: https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html
11. L. Wang, H. Gao, C. Zhao, X. Sun, and D. Dai. Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts, 2024. URL https://arxiv.org/abs/2408.15664.
12. PopART on-chip TensorRemap: https://github.com/graphcore/popart/tree/sdk-release-3.4
13. DeepSeek V3 optimization based on the AITER backend: https://github.com/sgl-project/sglang/pull/4344
Sponsor Sources
Also see evergrid.ai and huggingface sites.
