Cleanup code slightly and implement tests for CellList
Some checks failed
Build and Test / build-and-test (push) Failing after 5m4s

This commit is contained in:
Alex Selimov 2025-09-19 23:46:21 -04:00
parent f3e701236e
commit d957a90573
Signed by: aselimov
GPG key ID: 3DDB9C3E023F1F31
4 changed files with 55 additions and 7 deletions

View file

@ -25,14 +25,16 @@ struct CellList {
size_t total_cells; size_t total_cells;
size_t n_particles; size_t n_particles;
CellList(size_t n_particles, Box &box, float cutoff) CellList(size_t n_particles, Box &box, float r_cutoff)
: n_particles(n_particles) { : n_particles(n_particles) {
box_min.x = box.xlo; box_min.x = box.xlo;
box_min.y = box.ylo; box_min.y = box.ylo;
box_min.z = box.zlo; box_min.z = box.zlo;
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, cutoff); auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff);
this->grid_size = grid_size;
this->cell_size = cell_size;
total_cells = grid_size.x * grid_size.y * grid_size.z; total_cells = grid_size.x * grid_size.y * grid_size.z;
@ -50,14 +52,18 @@ struct CellList {
} }
// Get cell index from 3D coordinates // Get cell index from 3D coordinates
__device__ int get_cell_index(int x, int y, int z) const { // TODO; Maybe update this to use Morton Encodings in the future to improve
// locality of particle indices. Unclear how much of a benefit this will add,
// but would be cool to do
__host__ __device__ int get_cell_index(int x, int y, int z) const {
return z * grid_size.x * grid_size.y + y * grid_size.x + x; return z * grid_size.x * grid_size.y + y * grid_size.x + x;
} }
std::pair<int3, float3> calc_grid_and_cell_size(Box &box, float cutoff) { std::pair<int3, float3> calc_grid_and_cell_size(Box &box,
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / cutoff, 1), float r_cutoff) const {
utils::max((int)(box.yhi - box.ylo) / cutoff, 1), int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / r_cutoff, 1),
utils::max((int)(box.zhi - box.zlo) / cutoff, 1)}; utils::max((int)(box.yhi - box.ylo) / r_cutoff, 1),
utils::max((int)(box.zhi - box.zlo) / r_cutoff, 1)};
float3 cell_size = { float3 cell_size = {
(box.xhi - box.xlo) / grid_size.x, (box.xhi - box.xlo) / grid_size.x,

View file

@ -18,6 +18,16 @@ struct Box {
bool x_is_periodic; bool x_is_periodic;
bool y_is_periodic; bool y_is_periodic;
bool z_is_periodic; bool z_is_periodic;
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi,
bool x_is_periodic, bool y_is_periodic, bool z_is_periodic)
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic),
z_is_periodic(z_is_periodic) {}
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi)
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {}
}; };
#endif #endif

View file

@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests
test_potential.cu test_potential.cu
test_forces.cu test_forces.cu
test_kernel_config.cu test_kernel_config.cu
test_neighbor_list.cu
) )
target_link_libraries(${NAME}_cuda_tests gtest gtest_main) target_link_libraries(${NAME}_cuda_tests gtest gtest_main)

View file

@ -0,0 +1,31 @@
#include "box.hpp"
#include "neighbor_list.cuh"
#include "gtest/gtest.h"
TEST(CellListTest, Constructor) {
// Test case 1: Simple case
Box box(0, 10, 0, 10, 0, 10);
float cutoff = 2.5;
CellList cell_list(100, box, cutoff);
EXPECT_EQ(cell_list.grid_size.x, 4);
EXPECT_EQ(cell_list.grid_size.y, 4);
EXPECT_EQ(cell_list.grid_size.z, 4);
EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5);
EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5);
EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5);
EXPECT_EQ(cell_list.total_cells, 64);
}
TEST(CellListTest, GetCellIndex) {
Box box(0, 10, 0, 10, 0, 10);
float cutoff = 2.5;
CellList cell_list(100, box, cutoff);
int x = 1, y = 2, z = 3;
int expected_index = z * cell_list.grid_size.x * cell_list.grid_size.y +
y * cell_list.grid_size.x + x;
EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index);
}