diff --git a/kernels/neighbor_list.cu b/kernels/neighbor_list.cu index beadde8..c7016c4 100644 --- a/kernels/neighbor_list.cu +++ b/kernels/neighbor_list.cu @@ -1,194 +1 @@ #include "neighbor_list.cuh" - -/** - * Step 1: Assign particles to cells - */ -__global__ void assign_particles_to_cells_kernel(const float4 *positions, - int *particle_cells, - const CellList cell_list, - size_t n_particles) { - size_t i = get_thread_id(); - if (i >= n_particles) - return; - - float4 pos = positions[i]; - float3 pos3 = make_float3(pos.x, pos.y, pos.z); - - int3 cell_coords = cell_list.get_cell_coords(pos3); - - // Clamp to valid range (handle edge cases) - cell_coords.x = max(0, min(cell_coords.x, cell_list.grid_size.x - 1)); - cell_coords.y = max(0, min(cell_coords.y, cell_list.grid_size.y - 1)); - cell_coords.z = max(0, min(cell_coords.z, cell_list.grid_size.z - 1)); - - particle_cells[i] = - cell_list.get_cell_index(cell_coords.x, cell_coords.y, cell_coords.z); -} - -/** - * Step 2: Find cell boundaries after sorting - */ -__global__ void find_cell_boundaries_kernel(const int *sorted_particle_cells, - int *cell_starts, int *cell_ends, - size_t n_particles, - size_t total_cells) { - size_t i = get_thread_id(); - if (i >= n_particles) - return; - - int cell = sorted_particle_cells[i]; - - // Check if this is the start of a new cell - if (i == 0 || sorted_particle_cells[i - 1] != cell) { - cell_starts[cell] = i; - } - - // Check if this is the end of a cell - if (i == n_particles - 1 || sorted_particle_cells[i + 1] != cell) { - cell_ends[cell] = i + 1; - } -} - -/** - * Step 3: Build actual neighbor lists using cell lists - */ -__global__ void -build_neighbor_lists_kernel(const float4 *positions, const CellList cell_list, - NeighborList neighbor_list, float cutoff_squared, - const float3 *box_lengths, size_t n_particles) { - size_t i = get_thread_id(); - if (i >= n_particles) - return; - - float4 pos_i = positions[i]; - float3 my_pos = make_float3(pos_i.x, pos_i.y, pos_i.z); - 
- int3 my_cell = cell_list.get_cell_coords(my_pos); - int neighbor_count = 0; - int max_neighbors = neighbor_list.max_neighbors; - int *my_neighbors = neighbor_list.get_neighbors(i); - - // Search neighboring cells (3x3x3 = 27 cells including self) - for (int dz = -1; dz <= 1; dz++) { - for (int dy = -1; dy <= 1; dy++) { - for (int dx = -1; dx <= 1; dx++) { - - int3 neighbor_cell = - make_int3(my_cell.x + dx, my_cell.y + dy, my_cell.z + dz); - - // Check bounds - if (neighbor_cell.x < 0 || neighbor_cell.x >= cell_list.grid_size.x || - neighbor_cell.y < 0 || neighbor_cell.y >= cell_list.grid_size.y || - neighbor_cell.z < 0 || neighbor_cell.z >= cell_list.grid_size.z) { - continue; - } - - int cell_idx = cell_list.get_cell_index( - neighbor_cell.x, neighbor_cell.y, neighbor_cell.z); - int start = cell_list.cell_starts[cell_idx]; - int end = cell_list.cell_ends[cell_idx]; - - // Check all particles in this cell - for (int idx = start; idx < end; idx++) { - int j = cell_list.sorted_particles[idx]; - - if (i >= j) - continue; // Avoid double counting and self-interaction - - float4 pos_j = positions[j]; - float3 other_pos = make_float3(pos_j.x, pos_j.y, pos_j.z); - - // Calculate distance with periodic boundary conditions - float dx = my_pos.x - other_pos.x; - float dy = my_pos.y - other_pos.y; - float dz = my_pos.z - other_pos.z; - - // Apply PBC - dx -= box_lengths->x * roundf(dx / box_lengths->x); - dy -= box_lengths->y * roundf(dy / box_lengths->y); - dz -= box_lengths->z * roundf(dz / box_lengths->z); - - float r_squared = dx * dx + dy * dy + dz * dz; - - if (r_squared <= cutoff_squared && neighbor_count < max_neighbors) { - my_neighbors[neighbor_count] = j; - neighbor_count++; - } - } - } - } - } - - neighbor_list.neighbor_counts[i] = neighbor_count; -} - -// ============================================================================= -// HOST FUNCTIONS -// ============================================================================= - -void build_cell_list(const 
float4 *positions, CellList &cell_list, - const float3 *box_lengths) { - - // Step 1: Assign particles to cells - auto config = get_launch_config(cell_list.n_particles); - - assign_particles_to_cells_kernel<<>>( - positions, cell_list.particle_cells, cell_list, cell_list.n_particles); - - // Step 2: Sort particles by cell index using Thrust - thrust::device_ptr particle_cells_ptr(cell_list.particle_cells); - thrust::device_ptr sorted_particles_ptr(cell_list.sorted_particles); - - // Initialize particle indices 0, 1, 2, ... - thrust::sequence(sorted_particles_ptr, - sorted_particles_ptr + cell_list.n_particles); - - // Sort particle indices by their cell assignments - thrust::sort_by_key(particle_cells_ptr, - particle_cells_ptr + cell_list.n_particles, - sorted_particles_ptr); - - // Step 3: Initialize cell boundaries to -1 - cudaMemset(cell_list.cell_starts, -1, cell_list.total_cells * sizeof(int)); - cudaMemset(cell_list.cell_ends, -1, cell_list.total_cells * sizeof(int)); - - // Step 4: Find cell boundaries - config = get_launch_config(cell_list.n_particles); - find_cell_boundaries_kernel<<>>( - cell_list.particle_cells, cell_list.cell_starts, cell_list.cell_ends, - cell_list.n_particles, cell_list.total_cells); -} - -void build_neighbor_list(const float4 *positions, const CellList &cell_list, - NeighborList &neighbor_list, float cutoff_squared, - const float3 *box_lengths) { - - // Initialize neighbor counts to 0 - cudaMemset(neighbor_list.neighbor_counts, 0, - neighbor_list.n_particles * sizeof(int)); - - auto config = get_launch_config(neighbor_list.n_particles); - - build_neighbor_lists_kernel<<>>( - positions, cell_list, neighbor_list, cutoff_squared, box_lengths, - neighbor_list.n_particles); -} - -void build_neighbor_list_optimized(const float4 *positions, size_t n_particles, - const float3 *box_lengths, float cutoff, - NeighborList &neighbor_list) { - - // Get box size for cell list construction - float3 box_size; - cudaMemcpy(&box_size, box_lengths, 
sizeof(float3), cudaMemcpyDeviceToHost); - - // Create cell list - CellList cell_list(n_particles, box_size, cutoff); - - // Build cell list - build_cell_list(positions, cell_list, box_lengths); - - // Build neighbor list using cell list - build_neighbor_list(positions, cell_list, neighbor_list, cutoff * cutoff, - box_lengths); -} diff --git a/kernels/neighbor_list.cuh b/kernels/neighbor_list.cuh index 82afaf5..73039be 100644 --- a/kernels/neighbor_list.cuh +++ b/kernels/neighbor_list.cuh @@ -1,68 +1,21 @@ #ifndef NEIGHBOR_LIST_CUH #define NEIGHBOR_LIST_CUH -#include "kernel_config.cuh" // From previous artifact +#include "box.hpp" +#include "kernel_config.cuh" +#include "utils.cuh" #include #include +#include #include #include -/** - * Simple kernel to initialize neighbor offsets - */ -__global__ void init_neighbor_offsets(int *offsets, size_t n_particles, - int max_neighbors) { - size_t i = get_thread_id(); - if (i <= n_particles) { // Note: <= because we need n_particles + 1 elements - offsets[i] = i * max_neighbors; - } -} - -/** - * Neighbor list data structure - * Uses a compact format similar to CSR sparse matrices - */ -struct NeighborList { - int *neighbor_offsets; // Size: n_particles + 1, offset into neighbor_indices - int *neighbor_indices; // Size: total_neighbors, actual neighbor particle IDs - int *neighbor_counts; // Size: n_particles, number of neighbors per particle - size_t max_neighbors; // Maximum neighbors allocated per particle - size_t n_particles; - - // Constructor - NeighborList(size_t n_particles, size_t max_neighbors_per_particle) - : max_neighbors(max_neighbors_per_particle), n_particles(n_particles) { - - cudaMalloc(&neighbor_offsets, (n_particles + 1) * sizeof(int)); - cudaMalloc(&neighbor_indices, n_particles * max_neighbors * sizeof(int)); - cudaMalloc(&neighbor_counts, n_particles * sizeof(int)); - - // Initialize offsets - auto kernel_config = get_launch_config(n_particles); - init_neighbor_offsets<<>>( - neighbor_offsets, 
n_particles, max_neighbors); - } - - ~NeighborList() { - cudaFree(neighbor_offsets); - cudaFree(neighbor_indices); - cudaFree(neighbor_counts); - } - - // Get neighbors for particle i (device function) - __device__ int *get_neighbors(int i) const { - return &neighbor_indices[neighbor_offsets[i]]; - } - - __device__ int get_neighbor_count(int i) const { return neighbor_counts[i]; } -}; - /** * Cell list structure for spatial hashing */ struct CellList { int *cell_starts; // Size: total_cells, start index in sorted_particles - int *cell_ends; // Size: total_cells, end index in sorted_particles + int *cell_count; // Size: total_cells, presumably particles per cell (TODO confirm) int *sorted_particles; // Size: n_particles, particle IDs sorted by cell int *particle_cells; // Size: n_particles, which cell each particle belongs to @@ -72,40 +25,58 @@ struct CellList { size_t total_cells; size_t n_particles; - CellList(size_t n_particles, float3 box_size, float cutoff) + CellList(size_t n_particles, const Box &box, float r_cutoff) : n_particles(n_particles) { - // Calculate grid dimensions (cell size slightly larger than cutoff) - float cell_margin = 1.001f; // Small safety margin - cell_size = make_float3(cutoff * cell_margin, cutoff * cell_margin, - cutoff * cell_margin); + box_min.x = box.xlo; + box_min.y = box.ylo; + box_min.z = box.zlo; - grid_size.x = (int)ceilf(box_size.x / cell_size.x); - grid_size.y = (int)ceilf(box_size.y / cell_size.y); - grid_size.z = (int)ceilf(box_size.z / cell_size.z); + auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff); + this->grid_size = grid_size; + this->cell_size = cell_size; - total_cells = (size_t)grid_size.x * grid_size.y * grid_size.z; - box_min = make_float3(0.0f, 0.0f, 0.0f); // Assume box starts at origin + total_cells = (size_t)grid_size.x * grid_size.y * grid_size.z; - // Allocate memory cudaMalloc(&cell_starts, total_cells * sizeof(int)); - cudaMalloc(&cell_ends, total_cells * sizeof(int)); + cudaMalloc(&cell_count, total_cells * 
sizeof(int)); cudaMalloc(&sorted_particles, n_particles * sizeof(int)); cudaMalloc(&particle_cells, n_particles * sizeof(int)); } ~CellList() { cudaFree(cell_starts); - cudaFree(cell_ends); + cudaFree(cell_count); cudaFree(sorted_particles); cudaFree(particle_cells); } // Get cell index from 3D coordinates - __device__ int get_cell_index(int x, int y, int z) const { + // TODO: Maybe update this to use Morton Encodings in the future to improve + // locality of particle indices. Unclear how much of a benefit this will add, + // but would be cool to do + __host__ __device__ int get_cell_index(int x, int y, int z) const { return z * grid_size.x * grid_size.y + y * grid_size.x + x; } + // Split each box axis into floor(extent / r_cutoff) cells, at least one, and + // size the cells so they tile the box exactly. The cast wraps the whole + // quotient so the extent is not truncated before the division. + std::pair<int3, float3> calc_grid_and_cell_size(const Box &box, + float r_cutoff) const { + int3 grid_size = {utils::max((int)((box.xhi - box.xlo) / r_cutoff), 1), + utils::max((int)((box.yhi - box.ylo) / r_cutoff), 1), + utils::max((int)((box.zhi - box.zlo) / r_cutoff), 1)}; + + float3 cell_size = { + (box.xhi - box.xlo) / grid_size.x, + (box.yhi - box.ylo) / grid_size.y, + (box.zhi - box.zlo) / grid_size.z, + }; + + return std::make_pair(grid_size, cell_size); + } + // Get cell coordinates from position __device__ int3 get_cell_coords(float3 pos) const { return make_int3((int)((pos.x - box_min.x) / cell_size.x), @@ -114,19 +85,4 @@ struct CellList { } }; -// Forward declarations -void build_cell_list(const float4 *positions, CellList &cell_list, - const float3 *box_lengths); - -void build_neighbor_list(const float4 *positions, const CellList &cell_list, - NeighborList &neighbor_list, float cutoff_squared, - const float3 *box_lengths); - -/** - * High-level interface - builds complete neighbor list - */ -void build_neighbor_list_optimized(const float4 *positions, size_t n_particles, - const float3 *box_lengths, float cutoff, - NeighborList &neighbor_list); - #endif diff --git a/kernels/utils.cuh b/kernels/utils.cuh new file mode 100644 index 0000000..6f8fbd8 --- /dev/null +++ b/kernels/utils.cuh @@ -0,0 +1,8 @@ 
+#ifndef UTILS_CUH +#define UTILS_CUH + +namespace utils { +__device__ __host__ inline int max(int a, int b) { return (a > b) ? a : b; } +} // namespace utils + +#endif diff --git a/src/box.hpp b/src/box.hpp index b588c49..ab4a4ec 100644 --- a/src/box.hpp +++ b/src/box.hpp @@ -18,6 +18,16 @@ struct Box { bool x_is_periodic; bool y_is_periodic; bool z_is_periodic; + + Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi, + bool x_is_periodic, bool y_is_periodic, bool z_is_periodic) + : xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi), + x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic), + z_is_periodic(z_is_periodic) {} + + Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi) + : xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi), + x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {} }; #endif diff --git a/tests/cuda_unit_tests/CMakeLists.txt b/tests/cuda_unit_tests/CMakeLists.txt index 4ead02b..571d8eb 100644 --- a/tests/cuda_unit_tests/CMakeLists.txt +++ b/tests/cuda_unit_tests/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests test_potential.cu test_forces.cu test_kernel_config.cu + test_neighbor_list.cu ) target_link_libraries(${NAME}_cuda_tests gtest gtest_main) diff --git a/tests/cuda_unit_tests/test_neighbor_list.cu b/tests/cuda_unit_tests/test_neighbor_list.cu new file mode 100644 index 0000000..9818075 --- /dev/null +++ b/tests/cuda_unit_tests/test_neighbor_list.cu @@ -0,0 +1,31 @@ + +#include "box.hpp" +#include "neighbor_list.cuh" +#include "gtest/gtest.h" + +TEST(CellListTest, Constructor) { + // Test case 1: Simple case + Box box(0, 10, 0, 10, 0, 10); + float cutoff = 2.5f; + CellList cell_list(100, box, cutoff); + + EXPECT_EQ(cell_list.grid_size.x, 4); + EXPECT_EQ(cell_list.grid_size.y, 4); + EXPECT_EQ(cell_list.grid_size.z, 4); + EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5f); + EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5f); + EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5f); + 
EXPECT_EQ(cell_list.total_cells, 64u); +} + +TEST(CellListTest, GetCellIndex) { + Box box(0, 10, 0, 10, 0, 10); + float cutoff = 2.5f; + CellList cell_list(100, box, cutoff); + + int x = 1, y = 2, z = 3; + // Pinned literal for the 4x4x4 grid (z*16 plus y*4 plus x), not the formula. + int expected_index = 57; + + EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index); +}