From d957a9057383f71ac86f5ee0f53adf70bd4724fb Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Fri, 19 Sep 2025 23:46:21 -0400 Subject: [PATCH] Cleanup code slightly and implement tests for CellList --- kernels/neighbor_list.cuh | 20 ++++++++----- src/box.hpp | 10 +++++++ tests/cuda_unit_tests/CMakeLists.txt | 1 + tests/cuda_unit_tests/test_neighbor_list.cu | 31 +++++++++++++++++++++ 4 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 tests/cuda_unit_tests/test_neighbor_list.cu diff --git a/kernels/neighbor_list.cuh b/kernels/neighbor_list.cuh index fe4b67b..73039be 100644 --- a/kernels/neighbor_list.cuh +++ b/kernels/neighbor_list.cuh @@ -25,14 +25,16 @@ struct CellList { size_t total_cells; size_t n_particles; - CellList(size_t n_particles, Box &box, float cutoff) + CellList(size_t n_particles, Box &box, float r_cutoff) : n_particles(n_particles) { box_min.x = box.xlo; box_min.y = box.ylo; box_min.z = box.zlo; - auto [grid_size, cell_size] = calc_grid_and_cell_size(box, cutoff); + auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff); + this->grid_size = grid_size; + this->cell_size = cell_size; total_cells = grid_size.x * grid_size.y * grid_size.z; @@ -50,14 +52,18 @@ struct CellList { } // Get cell index from 3D coordinates - __device__ int get_cell_index(int x, int y, int z) const { + // TODO; Maybe update this to use Morton Encodings in the future to improve + // locality of particle indices. Unclear how much of a benefit this will add, + // but would be cool to do + __host__ __device__ int get_cell_index(int x, int y, int z) const { return z * grid_size.x * grid_size.y + y * grid_size.x + x; } - std::pair calc_grid_and_cell_size(Box &box, float cutoff) { - int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / cutoff, 1), - utils::max((int)(box.yhi - box.ylo) / cutoff, 1), - utils::max((int)(box.zhi - box.zlo) / cutoff, 1)}; + std::pair calc_grid_and_cell_size(Box &box, + float r_cutoff) const { + int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / r_cutoff, 1), + utils::max((int)(box.yhi - box.ylo) / r_cutoff, 1), + utils::max((int)(box.zhi - box.zlo) / r_cutoff, 1)}; float3 cell_size = { (box.xhi - box.xlo) / grid_size.x, diff --git a/src/box.hpp b/src/box.hpp index b588c49..ab4a4ec 100644 --- a/src/box.hpp +++ b/src/box.hpp @@ -18,6 +18,16 @@ struct Box { bool x_is_periodic; bool y_is_periodic; bool z_is_periodic; + + Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi, + bool x_is_periodic, bool y_is_periodic, bool z_is_periodic) + : xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi), + x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic), + z_is_periodic(z_is_periodic) {} + + Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi) + : xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi), + x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {} }; #endif diff --git a/tests/cuda_unit_tests/CMakeLists.txt b/tests/cuda_unit_tests/CMakeLists.txt index 4ead02b..571d8eb 100644 --- a/tests/cuda_unit_tests/CMakeLists.txt +++ b/tests/cuda_unit_tests/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests test_potential.cu test_forces.cu test_kernel_config.cu + test_neighbor_list.cu ) target_link_libraries(${NAME}_cuda_tests gtest gtest_main) diff --git a/tests/cuda_unit_tests/test_neighbor_list.cu b/tests/cuda_unit_tests/test_neighbor_list.cu new file mode 100644 index 0000000..9818075 --- /dev/null +++ b/tests/cuda_unit_tests/test_neighbor_list.cu @@ -0,0 +1,31 @@ + +#include "box.hpp" +#include "neighbor_list.cuh" +#include "gtest/gtest.h" + +TEST(CellListTest, Constructor) { + // Test case 1: Simple case + Box box(0, 10, 0, 10, 0, 10); + float cutoff = 2.5; + CellList cell_list(100, box, cutoff); + + EXPECT_EQ(cell_list.grid_size.x, 4); + EXPECT_EQ(cell_list.grid_size.y, 4); + EXPECT_EQ(cell_list.grid_size.z, 4); + EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5); + EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5); + EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5); + EXPECT_EQ(cell_list.total_cells, 64); +} + +TEST(CellListTest, GetCellIndex) { + Box box(0, 10, 0, 10, 0, 10); + float cutoff = 2.5; + CellList cell_list(100, box, cutoff); + + int x = 1, y = 2, z = 3; + int expected_index = z * cell_list.grid_size.x * cell_list.grid_size.y + + y * cell_list.grid_size.x + x; + + EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index); +}