Cleanup code slightly and implement tests for CellList
Some checks failed
Build and Test / build-and-test (push) Failing after 5m4s
Some checks failed
Build and Test / build-and-test (push) Failing after 5m4s
This commit is contained in:
parent
f3e701236e
commit
d957a90573
4 changed files with 55 additions and 7 deletions
|
@ -25,14 +25,16 @@ struct CellList {
|
|||
size_t total_cells;
|
||||
size_t n_particles;
|
||||
|
||||
CellList(size_t n_particles, Box &box, float cutoff)
|
||||
CellList(size_t n_particles, Box &box, float r_cutoff)
|
||||
: n_particles(n_particles) {
|
||||
|
||||
box_min.x = box.xlo;
|
||||
box_min.y = box.ylo;
|
||||
box_min.z = box.zlo;
|
||||
|
||||
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, cutoff);
|
||||
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff);
|
||||
this->grid_size = grid_size;
|
||||
this->cell_size = cell_size;
|
||||
|
||||
total_cells = grid_size.x * grid_size.y * grid_size.z;
|
||||
|
||||
|
@ -50,14 +52,18 @@ struct CellList {
|
|||
}
|
||||
|
||||
// Get cell index from 3D coordinates
|
||||
__device__ int get_cell_index(int x, int y, int z) const {
|
||||
// TODO; Maybe update this to use Morton Encodings in the future to improve
|
||||
// locality of particle indices. Unclear how much of a benefit this will add,
|
||||
// but would be cool to do
|
||||
__host__ __device__ int get_cell_index(int x, int y, int z) const {
|
||||
return z * grid_size.x * grid_size.y + y * grid_size.x + x;
|
||||
}
|
||||
|
||||
std::pair<int3, float3> calc_grid_and_cell_size(Box &box, float cutoff) {
|
||||
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / cutoff, 1),
|
||||
utils::max((int)(box.yhi - box.ylo) / cutoff, 1),
|
||||
utils::max((int)(box.zhi - box.zlo) / cutoff, 1)};
|
||||
std::pair<int3, float3> calc_grid_and_cell_size(Box &box,
|
||||
float r_cutoff) const {
|
||||
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / r_cutoff, 1),
|
||||
utils::max((int)(box.yhi - box.ylo) / r_cutoff, 1),
|
||||
utils::max((int)(box.zhi - box.zlo) / r_cutoff, 1)};
|
||||
|
||||
float3 cell_size = {
|
||||
(box.xhi - box.xlo) / grid_size.x,
|
||||
|
|
10
src/box.hpp
10
src/box.hpp
|
@ -18,6 +18,16 @@ struct Box {
|
|||
bool x_is_periodic;
|
||||
bool y_is_periodic;
|
||||
bool z_is_periodic;
|
||||
|
||||
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi,
|
||||
bool x_is_periodic, bool y_is_periodic, bool z_is_periodic)
|
||||
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||
x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic),
|
||||
z_is_periodic(z_is_periodic) {}
|
||||
|
||||
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi)
|
||||
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||
x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
|
|
@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests
|
|||
test_potential.cu
|
||||
test_forces.cu
|
||||
test_kernel_config.cu
|
||||
test_neighbor_list.cu
|
||||
)
|
||||
|
||||
target_link_libraries(${NAME}_cuda_tests gtest gtest_main)
|
||||
|
|
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
|
@ -0,0 +1,31 @@
|
|||
|
||||
#include "box.hpp"
|
||||
#include "neighbor_list.cuh"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(CellListTest, Constructor) {
|
||||
// Test case 1: Simple case
|
||||
Box box(0, 10, 0, 10, 0, 10);
|
||||
float cutoff = 2.5;
|
||||
CellList cell_list(100, box, cutoff);
|
||||
|
||||
EXPECT_EQ(cell_list.grid_size.x, 4);
|
||||
EXPECT_EQ(cell_list.grid_size.y, 4);
|
||||
EXPECT_EQ(cell_list.grid_size.z, 4);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5);
|
||||
EXPECT_EQ(cell_list.total_cells, 64);
|
||||
}
|
||||
|
||||
TEST(CellListTest, GetCellIndex) {
|
||||
Box box(0, 10, 0, 10, 0, 10);
|
||||
float cutoff = 2.5;
|
||||
CellList cell_list(100, box, cutoff);
|
||||
|
||||
int x = 1, y = 2, z = 3;
|
||||
int expected_index = z * cell_list.grid_size.x * cell_list.grid_size.y +
|
||||
y * cell_list.grid_size.x + x;
|
||||
|
||||
EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue