Elementwise Add Kernel in CUDA
In this post, I will walk through the elementwise add kernels in LeetCUDA and implement them interactively in this file using org-executor.
Background
In CUDA programming, an elementwise kernel is a fundamental building block that applies a given operation independently to each element of an input array (or arrays), producing an output array of the same shape. This is highly parallelizable, as each thread can process a single element without dependencies on others.
Elementwise kernels are commonly used for operations such as vector addition, scaling, activation functions in neural networks, and more. Understanding how to implement an efficient elementwise kernel is essential before moving on to more complex patterns like reductions.
In the following sections, we will review how to write a basic elementwise kernel in CUDA, discuss its memory access patterns, and explore best practices for maximizing performance.
Environment setting
Following LeetCUDA's setup, we expose all the kernels to PyTorch and use its facilities for performance and precision evaluation.
PyTorch
The PyTorch version:
import torch
print(torch.__version__)
2.8.0a0+5228986c39.nv25.06
Hardware
NVIDIA H100 80GB HBM3, 81559 MiB, 81080 MiB, 0 MiB, 575.57.08, 23, 0 %
Kernel Launching Utils
Common C++ header content:
#pragma once
#include <algorithm>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/extension.h>
#include <torch/types.h>
#include <vector>
Here is some common code for launching kernels through the torch facilities:
#define CEIL(x, y) (((x) + (y) - 1) / (y))
#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
inline void check_torch_dtype(torch::Tensor tensor,
torch::ScalarType expected_dtype) {
if (tensor.dtype() != expected_dtype) {
throw std::runtime_error("Tensor dtype mismatch");
}
}
inline std::tuple<dim3, dim3>
get_launch_dimensions(int N, int elements_per_block = 256,
int elements_per_thread = 1) {
const int threads_per_block = elements_per_block / elements_per_thread;
dim3 block_size(threads_per_block);
dim3 grid_size(CEIL(N, elements_per_block));
return {grid_size, block_size};
}
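// Example: with N = 1'000'000, elements_per_block = 1024 and elements_per_thread = 4,
// threads_per_block = 1024 / 4 = 256 and grid_size = CEIL(1'000'000, 1024) = 977,
// i.e. the kernel is launched as <<<977, 256>>> with each thread covering 4 elements.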
#define TORCH_BINDING_COMMON_EXTENSION(func) m.def(#func, &func, #func);
#define TORCH_BINDING_ELEM_ADD(packed_type, th_type, element_type, \
elements_per_thread) \
__global__ void elementwise_add_##packed_type##_kernel( \
element_type* __restrict__ a, \
element_type* __restrict__ b, \
element_type* __restrict__ c, int N); \
\
void elementwise_add_##packed_type(torch::Tensor A, torch::Tensor B, \
torch::Tensor C, \
int elements_per_block) { \
check_torch_dtype(A, th_type); \
check_torch_dtype(B, th_type); \
check_torch_dtype(C, th_type); \
auto [grid, block] = get_launch_dimensions(A.numel(), elements_per_block, \
elements_per_thread); \
elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<element_type*>(A.data_ptr()), \
reinterpret_cast<element_type*>(B.data_ptr()), \
reinterpret_cast<element_type*>(C.data_ptr()), A.numel()); \
}
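To make the macro easier to follow, this is roughly what `TORCH_BINDING_ELEM_ADD(f32, torch::kFloat32, float, 1)` expands to (whitespace added for readability):
__global__ void elementwise_add_f32_kernel(float* __restrict__ a,
                                           float* __restrict__ b,
                                           float* __restrict__ c, int N);

void elementwise_add_f32(torch::Tensor A, torch::Tensor B, torch::Tensor C,
                         int elements_per_block) {
  check_torch_dtype(A, torch::kFloat32);
  check_torch_dtype(B, torch::kFloat32);
  check_torch_dtype(C, torch::kFloat32);
  // 1 element per thread for the scalar f32 kernel
  auto [grid, block] = get_launch_dimensions(A.numel(), elements_per_block, 1);
  elementwise_add_f32_kernel<<<grid, block>>>(
      reinterpret_cast<float*>(A.data_ptr()),
      reinterpret_cast<float*>(B.data_ptr()),
      reinterpret_cast<float*>(C.data_ptr()), A.numel());
}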
Kernels
Basic kernel
This kernel demonstrates a basic elementwise addition operation in CUDA, where each thread adds two corresponding elements from the input arrays:
#include "elementwise_add.cuh"
__global__ void elementwise_add_f32_kernel(float* __restrict__ a,
float* __restrict__ b,
float* __restrict__ c, int N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
c[tid] = a[tid] + b[tid];
}
}
Explanation
Each thread computes one output element from one float (4 bytes) per input. Consecutive threads touch consecutive addresses, so the global memory accesses are already fully coalesced; the drawback is that every element costs its own 32-bit load and store instruction, which the vectorized kernels below reduce.
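As an aside, a common alternative (not used in the benchmarks below) is a grid-stride loop, which lets a fixed-size grid cover an arbitrarily large N; a minimal sketch:
// Hypothetical grid-stride variant of the basic kernel, shown for comparison only.
__global__ void elementwise_add_f32_gridstride_kernel(const float* __restrict__ a,
                                                      const float* __restrict__ b,
                                                      float* __restrict__ c, int N) {
  int stride = gridDim.x * blockDim.x;
  // Each thread walks the array with a stride equal to the total thread count.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += stride) {
    c[i] = a[i] + b[i];
  }
}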
f32x4 vector load
This kernel introduces vectorized load and store operations using `float4`, which lets each thread process four floats at once. Loading 16 bytes (128 bits) per thread instead of 4 bytes cuts the number of load/store instructions by 4x and lets each warp request wide, fully utilized memory transactions, improving achieved bandwidth:
#include "elementwise_add.cuh"
__global__ void elementwise_add_f32x4_kernel(float* __restrict__ a,
float* __restrict__ b,
float* __restrict__ c, int N) {
int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 3 < N) {
float4 reg_a = FLOAT4(a[idx]);
float4 reg_b = FLOAT4(b[idx]);
float4 reg_c;
reg_c.x = reg_a.x + reg_b.x;
reg_c.y = reg_a.y + reg_b.y;
reg_c.z = reg_a.z + reg_b.z;
reg_c.w = reg_a.w + reg_b.w;
FLOAT4(c[idx]) = reg_c;
}
}
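Note that the `idx + 3 < N` guard silently skips the last `N % 4` elements when N is not a multiple of 4 (every shape benchmarked in this post is a multiple of 4, so this never triggers here). A sketch of one way to add a scalar tail, assuming the same FLOAT4 macro from the header (hypothetical, shown only to illustrate tail handling):
__global__ void elementwise_add_f32x4_tail_kernel(float* __restrict__ a,
                                                  float* __restrict__ b,
                                                  float* __restrict__ c, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 3 < N) {
    // Fast path: a full group of 4 floats, handled with 128-bit loads/stores.
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
  } else {
    // Tail path: the final partial group (0 to 3 elements), handled element by element.
    for (int i = idx; i < N; ++i) {
      c[i] = a[i] + b[i];
    }
  }
}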
fp16x2 kernel
This kernel leverages half-precision (fp16) data with vectorized operations using `half2`. Each thread processes 2 half-precision values at once through CUDA's native half2 intrinsics. Because fp16 halves the number of bytes moved per element and `__hadd2` performs both additions as a single packed operation, the kernel gains both memory bandwidth and arithmetic efficiency:
#include "elementwise_add.cuh"
__global__ void elementwise_add_f16x2_kernel(half* __restrict__ a, half* __restrict__ b, half* __restrict__ c, int N) {
int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx + 1 < N) {
half2 reg_a = HALF2(a[idx]);
half2 reg_b = HALF2(b[idx]);
half2 reg_c = __hadd2(reg_a, reg_b);
HALF2(c[idx]) = reg_c;
}
}
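For reference, `__hadd2` is a packed addition over both fp16 lanes; semantically it matches this lane-by-lane sketch built from the scalar intrinsics:
// Reference-only sketch of what __hadd2 computes, one lane at a time.
__device__ half2 hadd2_reference(half2 x, half2 y) {
  half lo = __hadd(__low2half(x), __low2half(y));   // lower lane
  half hi = __hadd(__high2half(x), __high2half(y)); // upper lane
  return __halves2half2(lo, hi);
}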
fp16x8 kernel
This kernel extends the vectorization approach to process 8 half-precision values per thread, using four `half2` packed operations. This maximizes memory throughput by loading 16 bytes (128 bits) per thread while maintaining efficient packed arithmetic. The kernel includes proper bounds checking for each half2 pair to handle cases where the array size is not perfectly divisible by 8:
#include "elementwise_add.cuh"
__global__ void elementwise_add_f16x8_kernel(half* __restrict__ a,
half* __restrict__ b,
half* __restrict__ c, int N) {
const int linearThreadId = blockIdx.x * blockDim.x + threadIdx.x;
const int idx = linearThreadId * 8;
const int remaining = N - idx;
if (remaining <= 0) {
return;
}
// Fast path: full 8 elements
if (remaining >= 8) {
// Single 128-bit loads for A and B
float4 vec_a = LDST128BITS(a[idx]);
float4 vec_b = LDST128BITS(b[idx]);
// Reinterpret as four half2 lanes, compute, then store as 128-bit
union Pack16 {
float4 f4;
half2 h2[4];
} pa, pb, pc;
pa.f4 = vec_a;
pb.f4 = vec_b;
pc.h2[0] = __hadd2(pa.h2[0], pb.h2[0]);
pc.h2[1] = __hadd2(pa.h2[1], pb.h2[1]);
pc.h2[2] = __hadd2(pa.h2[2], pb.h2[2]);
pc.h2[3] = __hadd2(pa.h2[3], pb.h2[3]);
// Single 128-bit store for C
LDST128BITS(c[idx]) = pc.f4;
return;
}
// Tail path: handle <8 remaining elements
int i = 0;
for (; i + 1 < remaining; i += 2) {
half2 ra = HALF2(a[idx + i]);
half2 rb = HALF2(b[idx + i]);
HALF2(c[idx + i]) = __hadd2(ra, rb);
}
if (i < remaining) {
c[idx + i] = __hadd(a[idx + i], b[idx + i]);
}
}
Register the kernels and benchmark
Register the kernels:
#include "elementwise_add.cuh"
TORCH_BINDING_ELEM_ADD(f32, torch::kFloat32, float, 1)
TORCH_BINDING_ELEM_ADD(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_ELEM_ADD(f16x2, torch::kFloat16, half, 2)
TORCH_BINDING_ELEM_ADD(f16x8, torch::kFloat16, half, 8)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32)
TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4)
TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2)
TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8)
}
Compile PyTorch module
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
source_files = [
"elementwise_add_basic.cu",
"elementwise_add_f32x4.cu",
"elementwise_add_f16x2.cu",
"elementwise_add_f16x8.cu",
"elementwise_add_lib.cu",
]
setup(
name='elementwise_lib', # The name of your module
ext_modules=[
# .cu sources must be compiled by nvcc; CUDAExtension (not CppExtension) sets this up
CUDAExtension(
'elementwise_lib',
source_files
),
],
cmdclass={
'build_ext': BuildExtension
}
)
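Assuming the standard setuptools flow, running `python setup.py build` produces the shared library under `build/lib.linux-x86_64-cpython-312/`, which is exactly the `lib_dir` that the benchmark script below appends to `sys.path`.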
Launching in PyTorch:
import time
from functools import partial
from typing import Optional
import torch
import os
import sys
workspace = os.environ["__WORKSPACE__"]
# The built torch lib is in the following path
lib_dir = f"{workspace}/build/lib.linux-x86_64-cpython-312"
print(f"lib: {lib_dir}")
sys.path.append(lib_dir)
import elementwise_lib as lib
torch.set_grad_enabled(False)
print(f"Compiling Torch kernel")
# Load the CUDA kernel as a python module
import hashlib
import os
def get_file_hash(filepath):
"""Get MD5 hash of file content"""
with open(filepath, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()[:8] # Use first 8 chars
print(f"running benchmark")
def run_benchmark(
perf_func: callable,
a: torch.Tensor,
b: torch.Tensor,
tag: str,
out: Optional[torch.Tensor] = None,
warmup: int = 10,
iters: int = 1000,
show_all: bool = False,
elements_per_block = 256,
):
if out is not None:
out.fill_(0)
# Warmup
for _ in range(warmup):
perf_func(a, b, out, elements_per_block)
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(iters):
perf_func(a, b, out, elements_per_block)
end_event.record()
torch.cuda.synchronize()
total_time = start_event.elapsed_time(end_event) # ms
mean_time = total_time / iters
out_info = f"out_{tag}"
out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
out_val = [round(v, 8) for v in out_val]
print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
if show_all:
print(out)
return out, mean_time
Run the benchmark:
shapes = [
(2096, 4096), (2048, 2048), (2048, 1024), (1024, 1024), (512, 512), (256, 256)]
for shape in shapes:
print(f"Running benchmark for shape: {shape}")
A = torch.randn(*shape, dtype=torch.float32, device="cuda").contiguous()
B = torch.randn(*shape, dtype=torch.float32, device="cuda").contiguous()
C = torch.zeros_like(A).contiguous()
# Create fp16 tensors for fp16 kernels
A_fp16 = A.half().contiguous()
B_fp16 = B.half().contiguous()
C_fp16 = torch.zeros_like(A_fp16).contiguous()
elements_per_block = 256
print(f"elements_per_block: {elements_per_block}")
# Scale elements_per_block with each kernel's elements-per-thread so every kernel launches with the same threads_per_block (256)
run_benchmark(lib.elementwise_add_f32, A, B, "basic", C, elements_per_block=elements_per_block)
run_benchmark(lib.elementwise_add_f32x4, A, B, "f32x4", C, elements_per_block=elements_per_block * 4)
run_benchmark(lib.elementwise_add_f16x2, A_fp16, B_fp16, "f16x2", C_fp16, elements_per_block=elements_per_block * 2)
run_benchmark(lib.elementwise_add_f16x8, A_fp16, B_fp16, "f16x8", C_fp16, elements_per_block=elements_per_block * 8)
print(f"--")
Results:
lib: /workspace/project/superjomn.github.io/content-org/_build/build/lib.linux-x86_64-cpython-312
Compiling Torch kernel
running benchmark
Running benchmark for shape: (2096, 4096)
elements_per_block: 256
out_basic: [0.62899578, -3.16506243], time:0.04013546ms
out_f32x4: [0.62899578, -3.16506243], time:0.03716669ms
out_f16x2: [0.62890625, -3.1640625], time:0.02376186ms
out_f16x8: [0.62890625, -3.1640625], time:0.02382634ms
--
Running benchmark for shape: (2048, 2048)
elements_per_block: 256
out_basic: [0.98603237, -2.21596098], time:0.02142691ms
out_f32x4: [0.98603237, -2.21596098], time:0.01876467ms
out_f16x2: [0.98632812, -2.21484375], time:0.01224410ms
out_f16x8: [0.98632812, -2.21484375], time:0.01231584ms
--
Running benchmark for shape: (2048, 1024)
elements_per_block: 256
out_basic: [-1.68364513, 0.07630849], time:0.00754973ms
out_f32x4: [-1.68364513, 0.07630849], time:0.00768909ms
out_f16x2: [-1.68359375, 0.07666016], time:0.00720714ms
out_f16x8: [-1.68359375, 0.07666016], time:0.00725242ms
--
Running benchmark for shape: (1024, 1024)
elements_per_block: 256
out_basic: [0.41730967, -2.56410193], time:0.00473962ms
out_f32x4: [0.41730967, -2.56410193], time:0.00490102ms
out_f16x2: [0.41723633, -2.56445312], time:0.00471936ms
out_f16x8: [0.41723633, -2.56445312], time:0.00484266ms
--
Running benchmark for shape: (512, 512)
elements_per_block: 256
out_basic: [-0.84098238, -0.51086581], time:0.00448192ms
out_f32x4: [-0.84098238, -0.51086581], time:0.00438221ms
out_f16x2: [-0.84130859, -0.51074219], time:0.00443571ms
out_f16x8: [-0.84130859, -0.51074219], time:0.00444208ms
--
Running benchmark for shape: (256, 256)
elements_per_block: 256
out_basic: [2.76621795, 1.71955645], time:0.00438445ms
out_f32x4: [2.76621795, 1.71955645], time:0.00440973ms
out_f16x2: [2.765625, 1.71972656], time:0.00444272ms
out_f16x8: [2.765625, 1.71972656], time:0.00445354ms
--