#ifndef KERNEL_CONFIG_CUH
#define KERNEL_CONFIG_CUH

#include <cstddef>         // size_t
#include <cuda_runtime.h>  // dim3, uint3, __host__ / __device__ qualifiers

/**
 * Grid launch configuration: the grid and block dimensions to pass to
 * `kernel<<<cfg.blocks, cfg.threads>>>(...)`.
 */
struct KernelConfig {
    dim3 blocks;   // grid dimensions (number of blocks along x / y / z)
    dim3 threads;  // block dimensions (threads per block along x / y / z)

    // Convenience constructor.
    KernelConfig(dim3 b, dim3 t) : blocks(b), threads(t) {}

    // Total number of threads launched (defined out of line; presumably the
    // product of all grid and block extents -- TODO confirm in the .cu file).
    size_t total_threads() const;

    // Print the configuration for debugging (defined out of line).
    void print() const;
};

/**
 * Calculate a CUDA launch configuration for a 1D problem.
 *
 * The returned grid may presumably be spread over several grid dimensions
 * (get_thread_id() linearises x, y and z) so that no dimension exceeds
 * max_blocks_per_dim. The total number of threads launched can exceed
 * n_elements, so kernels MUST bounds-check the index they compute.
 *
 * @param n_elements         Number of elements to process
 * @param threads_per_block  Desired threads per block (default: 256)
 * @param max_blocks_per_dim Maximum blocks per grid dimension (default: 65535,
 *                           the hardware limit for gridDim.y and gridDim.z)
 * @return KernelConfig with the grid and block dimensions to launch with
 */
KernelConfig get_launch_config(size_t n_elements,
                               int threads_per_block = 256,
                               int max_blocks_per_dim = 65535);

/**
 * Flatten a (grid, block, block index, thread index) tuple into a unique
 * linear global thread index.
 *
 * Blocks are numbered x-fastest across the grid and threads x-fastest within
 * a block. Each block owns a contiguous range of block.x * block.y * block.z
 * indices, so IDs stay unique for 2D/3D blocks too. For 1D blocks
 * (block.y == block.z == 1) this reduces to
 *     ((bz * grid.y + by) * grid.x + bx) * block.x + tx.
 *
 * All arithmetic is done in size_t, so grids with more than 2^32 threads do
 * not overflow. The function does not read CUDA built-ins, so it is callable
 * from host code as well (e.g. for tests or index precomputation).
 */
__host__ __device__ inline size_t get_thread_id(dim3 grid, dim3 block,
                                                dim3 block_idx, dim3 thread_idx)
{
    const size_t block_linear =
        ((size_t)block_idx.z * grid.y + block_idx.y) * grid.x + block_idx.x;
    const size_t thread_linear =
        ((size_t)thread_idx.z * block.y + thread_idx.y) * block.x + thread_idx.x;
    const size_t threads_per_block = (size_t)block.x * block.y * block.z;
    return block_linear * threads_per_block + thread_linear;
}

/**
 * Linear global thread index of the calling thread, for kernels launched with
 * get_launch_config(). Use this inside your CUDA kernels and bounds-check the
 * result against the element count.
 */
__device__ inline size_t get_thread_id()
{
    // blockIdx / threadIdx are uint3; dim3 has a converting constructor.
    return get_thread_id(gridDim, blockDim, dim3(blockIdx), dim3(threadIdx));
}

/**
 * Return a block size suited to the given device (defined out of line;
 * presumably derived from the device's properties -- TODO confirm).
 */
int get_optimal_block_size(int device_id = 0);

/**
 * Variant of get_launch_config() that takes the properties of device_id into
 * account (defined out of line).
 */
KernelConfig get_launch_config_advanced(size_t n_elements, int device_id = 0);

// Example usage in your kernel:
/*
template <typename real, typename PotentialType>
__global__ void calc_forces_and_energies(float4 *pos, float4 *force_energies,
                                         size_t n_particles, real *box_len,
                                         PotentialType potential)
{
    size_t i = get_thread_id();
    if (i < n_particles) {  // the launch may contain more threads than particles
        // Your existing force calculation code here...
        float4 my_pos = pos[i];
        // ... rest of kernel unchanged
    }
}

// Host side:
//   KernelConfig cfg = get_launch_config(n_particles);
//   calc_forces_and_energies<<<cfg.blocks, cfg.threads>>>(...);
*/

#endif  // KERNEL_CONFIG_CUH