Skip to content

Instantly share code, notes, and snippets.

@staghado
Created October 6, 2024 21:44
Monte Carlo estimation of pi in Triton, because why not
import triton
import triton.language as tl
import torch
import time
import numpy as np
@triton.jit
def pi_kernel(
    total_ptr,
    BLOCK_SIZE: tl.constexpr,
    ITERATIONS_PER_THREAD: tl.constexpr,
):
    """Count Monte Carlo samples that fall inside the unit quarter circle.

    Each program owns BLOCK_SIZE lanes. Every lane draws
    ITERATIONS_PER_THREAD points (x, y) uniformly in the unit square and
    counts those with x*x + y*y <= 1. The program's total is added
    atomically to ``*total_ptr``.

    ``tl.rand(seed, offsets)`` is Philox counter-based, so each random number
    is determined by its (seed, offset) pair. Here the seed is the iteration
    index (a scalar, as ``tl.rand`` expects). Each lane with global id g owns
    offset 2*g for x and 2*g + 1 for y. Every coordinate of every point
    therefore comes from a distinct (seed, offset) pair.
    """
    pid = tl.program_id(0)
    tid = tl.arange(0, BLOCK_SIZE)
    # Loop-invariant per-lane offsets: one for x, one for y.
    lane = pid * BLOCK_SIZE + tid
    x_offsets = 2 * lane
    y_offsets = x_offsets + 1
    counter = tl.zeros((BLOCK_SIZE,), dtype=tl.int64)
    for i in range(ITERATIONS_PER_THREAD):
        # Why not per-lane seeds: the previous scheme used
        # s = lane * ITERATIONS + i and drew y from seed s + 1. That is
        # exactly the (seed, offset) of the next iteration's x, so each
        # point's y was recycled as the next point's x. The scheme also
        # overflowed int32 for large launches.
        x = tl.rand(i, x_offsets)
        y = tl.rand(i, y_offsets)
        inside_circle = x * x + y * y <= 1.0
        counter += tl.where(inside_circle, 1, 0)
    block_total = tl.sum(counter)
    tl.atomic_add(total_ptr, block_total)
def estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD):
    """Estimate pi on the GPU by Monte Carlo sampling with ``pi_kernel``.

    Launches a 1-D grid of NBLOCKS programs, each with BLOCK_SIZE lanes.
    Each lane tests ITERATIONS_PER_THREAD random points.

    Args:
        NBLOCKS: Number of Triton programs (grid size); must be >= 1.
        BLOCK_SIZE: Lanes per program. Must be a positive power of two,
            because the kernel builds its lanes with ``tl.arange``, which
            requires a power-of-two range.
        ITERATIONS_PER_THREAD: Points sampled per lane; must be >= 1.

    Returns:
        ``(pi_estimate, total_points)``: the estimate as a float and the
        number of points sampled.

    Raises:
        ValueError: If a count is non-positive or BLOCK_SIZE is not a power
            of two. These are checked before anything is launched, instead
            of surfacing as a ZeroDivisionError or a kernel compile error.
    """
    if NBLOCKS < 1 or ITERATIONS_PER_THREAD < 1:
        raise ValueError(
            "NBLOCKS and ITERATIONS_PER_THREAD must be >= 1, got "
            f"{NBLOCKS} and {ITERATIONS_PER_THREAD}"
        )
    if BLOCK_SIZE < 1 or (BLOCK_SIZE & (BLOCK_SIZE - 1)) != 0:
        raise ValueError(
            f"BLOCK_SIZE must be a positive power of two, got {BLOCK_SIZE}"
        )

    total = torch.zeros(1, dtype=torch.int64, device='cuda')
    pi_kernel[(NBLOCKS,)](
        total,
        BLOCK_SIZE=BLOCK_SIZE,
        ITERATIONS_PER_THREAD=ITERATIONS_PER_THREAD,
    )
    # .item() copies the counter to the host, which waits for the kernel.
    total_hits = total.item()
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD
    pi_estimate = 4.0 * total_hits / total_points
    return pi_estimate, total_points
def estimate_pi_python(total_points, *, chunk_size=1 << 22, rng=None):
    """Estimate pi on the CPU with NumPy by Monte Carlo sampling.

    Draws ``total_points`` points uniformly from [0, 1)^2. Returns 4 times
    the fraction with x*x + y*y <= 1.

    Points are generated in chunks of at most ``chunk_size``, so peak memory
    is bounded by the chunk rather than by ``total_points``. Drawing both
    coordinate arrays at once needs at least 16 bytes per point, which is
    roughly 168 GB for 1.05e10 points.

    Args:
        total_points: Number of points to sample; must be >= 1.
        chunk_size: Maximum number of points drawn per chunk; must be >= 1.
        rng: Optional ``numpy.random.Generator``. If None, the legacy global
            ``np.random`` state is used, so ``np.random.seed`` still controls
            the result. With ``chunk_size >= total_points`` the draws are
            identical to ``np.random.rand(total_points)`` twice.

    Returns:
        The pi estimate as a float.

    Raises:
        ValueError: If ``total_points`` or ``chunk_size`` is less than 1.
    """
    if total_points < 1:
        raise ValueError(f"total_points must be >= 1, got {total_points}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Both the np.random module and a Generator expose random(size).
    source = np.random if rng is None else rng
    inside_circle = 0
    remaining = total_points
    while remaining > 0:
        n = min(chunk_size, remaining)
        x = source.random(n)
        y = source.random(n)
        inside_circle += int(np.count_nonzero(x * x + y * y <= 1.0))
        remaining -= n
    return 4.0 * inside_circle / total_points
if __name__ == "__main__":
    NBLOCKS = 1024
    BLOCK_SIZE = 1024
    ITERATIONS_PER_THREAD = 1000
    PI_REF = np.pi
    total_points = NBLOCKS * BLOCK_SIZE * ITERATIONS_PER_THREAD

    # Warm-up: the first launch of a @triton.jit kernel compiles it for this
    # (BLOCK_SIZE, ITERATIONS_PER_THREAD) specialization. NBLOCKS only sets
    # the grid, so a single-program launch triggers the same compilation.
    # Compile time is then excluded from the timed region below.
    estimate_pi(1, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()

    # perf_counter is the monotonic high-resolution clock meant for
    # interval timing; time.time can jump with wall-clock adjustments.
    start_time_gpu = time.perf_counter()
    pi_estimate_gpu, _ = estimate_pi(NBLOCKS, BLOCK_SIZE, ITERATIONS_PER_THREAD)
    torch.cuda.synchronize()
    end_time_gpu = time.perf_counter()

    start_time_python = time.perf_counter()
    pi_estimate_python = estimate_pi_python(total_points)
    end_time_python = time.perf_counter()

    print(f"Total of {total_points/10**9:.2f}G random tests")
    print(f"Triton Pi ~= {pi_estimate_gpu:.15f}")
    print(f"Numpy Pi ~= {pi_estimate_python:.15f}")
    print(f"Numpy error : {abs(PI_REF - pi_estimate_python):.15f}")
    print(f"Triton error : {abs(PI_REF - pi_estimate_gpu):.15f}")
    print(f"Triton time : {end_time_gpu - start_time_gpu:.6f} seconds")
    print(f"Numpy time: {end_time_python - start_time_python:.6f} seconds")
@staghado
Copy link
Author

staghado commented Oct 6, 2024

Total of 10.49G random tests
Triton Pi ~= 3.141591793823242
Numpy Pi ~= 3.141590239334107
Numpy error : 0.000002414255686
Triton error : 0.000000859766551
Triton time : 3.579296 seconds
Numpy time: 195.734966 seconds

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment