Created October 6, 2024 21:44
Monte Carlo estimation of pi in Triton, because why not.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
import torch | |
import time | |
import numpy as np | |
# Monte Carlo hit counter for the pi estimate.
#
# Each program instance runs BLOCK_SIZE lanes; each lane draws
# ITERATIONS_PER_THREAD (x, y) pairs with tl.rand and counts those with
# x*x + y*y <= 1.0. The per-program count is atomically added to total_ptr[0].
#
# Args:
#   total_ptr: pointer to the accumulator the caller zero-initialises
#       (estimate_pi passes a 1-element int64 tensor).
#   BLOCK_SIZE: lanes per program; tl.arange requires a power of two.
#   ITERATIONS_PER_THREAD: samples drawn per lane.
@triton.jit
def pi_kernel(
    total_ptr,
    BLOCK_SIZE: tl.constexpr,
    ITERATIONS_PER_THREAD: tl.constexpr,
):
    pid = tl.program_id(0)
    tid = tl.arange(0, BLOCK_SIZE)
    # Global lane index in int64 so the per-sample seed below cannot wrap
    # around for large NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD.
    lane = pid.to(tl.int64) * BLOCK_SIZE + tid
    counter = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    for i in range(ITERATIONS_PER_THREAD):
        # One seed per (lane, iteration) sample, unique across the grid.
        seed = lane * ITERATIONS_PER_THREAD + i
        # x and y share the seed but use disjoint offset ranges
        # ([0, BLOCK_SIZE) vs [BLOCK_SIZE, 2*BLOCK_SIZE)), so every draw uses a
        # distinct Philox input. Drawing y with seed + 1 would reproduce the
        # next iteration's x and correlate consecutive points.
        x = tl.rand(seed, tid)
        y = tl.rand(seed, tid + BLOCK_SIZE)
        inside_circle = (x * x + y * y <= 1.0)
        counter += tl.where(inside_circle, 1, 0)
    block_total = tl.sum(counter)
    tl.atomic_add(total_ptr, block_total)
def estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD):
    """Estimate pi on the GPU with the Triton Monte Carlo kernel.

    Args:
        NBLOCKS: number of Triton programs to launch (the grid size).
        BLOCK_SIZE: lanes per program; must be a positive power of two
            because the kernel uses ``tl.arange(0, BLOCK_SIZE)``.
        ITERATIONS_PER_THREAD: samples drawn per lane.

    Returns:
        ``(pi_estimate, total_points)``: the estimate and the number of
        samples it is based on.

    Raises:
        ValueError: if any argument is not positive, or BLOCK_SIZE is not a
            power of two.
    """
    if NBLOCKS <= 0 or BLOCK_SIZE <= 0 or ITERATIONS_PER_THREAD <= 0:
        raise ValueError(
            "NBLOCKS, BLOCK_SIZE and ITERATIONS_PER_THREAD must all be positive"
        )
    # Fail early with a clear message instead of inside Triton compilation.
    if BLOCK_SIZE & (BLOCK_SIZE - 1):
        raise ValueError(f"BLOCK_SIZE must be a power of two, got {BLOCK_SIZE}")

    total = torch.zeros(1, dtype=torch.int64, device='cuda')
    pi_kernel[(NBLOCKS,)](
        total,
        BLOCK_SIZE=BLOCK_SIZE,
        ITERATIONS_PER_THREAD=ITERATIONS_PER_THREAD,
    )
    # .item() copies the result to the host, which waits for the kernel.
    total_hits = total.item()
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD
    pi_estimate = 4.0 * total_hits / total_points
    return pi_estimate, total_points
def estimate_pi_python(total_points, chunk_size=2**22):
    """Estimate pi on the CPU by NumPy Monte Carlo sampling.

    Points are drawn from the global ``np.random`` state in chunks of at most
    ``chunk_size`` so peak memory stays bounded. Drawing all points at once
    would need two float64 arrays of ``total_points`` elements plus
    temporaries, which is over 16 GB for ~1e9 points. When
    ``total_points <= chunk_size`` the draws are identical to sampling all x
    values and then all y values in one go.

    Args:
        total_points: number of (x, y) samples; must be positive.
        chunk_size: maximum samples generated per NumPy call; must be positive.

    Returns:
        The pi estimate ``4 * hits / total_points``.

    Raises:
        ValueError: if ``total_points`` or ``chunk_size`` is not positive.
    """
    if total_points <= 0:
        raise ValueError(f"total_points must be positive, got {total_points}")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    inside_circle = 0
    remaining = total_points
    while remaining > 0:
        n = min(chunk_size, remaining)
        x = np.random.rand(n)
        y = np.random.rand(n)
        inside_circle += int(np.count_nonzero((x * x + y * y) <= 1.0))
        remaining -= n
    pi_estimate = 4.0 * inside_circle / total_points
    return pi_estimate
if __name__ == "__main__":
    NBLOCKS = 1024
    BLOCK_SIZE = 1024
    ITERATIONS_PER_THREAD = 1000
    PI_REF = 3.141592653589793
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD

    # Warm-up: the first launch JIT-compiles pi_kernel for these constexpr
    # values (BLOCK_SIZE, ITERATIONS_PER_THREAD). Without this, the timed run
    # below would include compilation. NBLOCKS only sets the launch grid, so
    # a single-program launch is enough to build the kernel.
    estimate_pi(1, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time_gpu = time.perf_counter()
    pi_estimate_gpu, _ = estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()
    end_time_gpu = time.perf_counter()

    start_time_python = time.perf_counter()
    pi_estimate_python = estimate_pi_python(total_points)
    end_time_python = time.perf_counter()

    print(f"Total of {total_points/10**9:.2f}G random tests")
    print(f"Triton Pi ~= {pi_estimate_gpu:.15f}")
    print(f"Numpy Pi ~= {pi_estimate_python:.15f}")
    print(f"Numpy error : {abs(PI_REF - pi_estimate_python):.15f}")
    print(f"Triton error : {abs(PI_REF - pi_estimate_gpu):.15f}")
    print(f"Triton time : {end_time_gpu - start_time_gpu:.6f} seconds")
    print(f"Numpy time: {end_time_python - start_time_python:.6f} seconds")
Author staghado commented on Oct 6, 2024.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.