Cleanup code slightly and implement tests for CellList
Some checks failed
Build and Test / build-and-test (push) Failing after 5m4s
Some checks failed
Build and Test / build-and-test (push) Failing after 5m4s
This commit is contained in:
parent
f3e701236e
commit
d957a90573
4 changed files with 55 additions and 7 deletions
|
@ -25,14 +25,16 @@ struct CellList {
|
||||||
size_t total_cells;
|
size_t total_cells;
|
||||||
size_t n_particles;
|
size_t n_particles;
|
||||||
|
|
||||||
CellList(size_t n_particles, Box &box, float cutoff)
|
CellList(size_t n_particles, Box &box, float r_cutoff)
|
||||||
: n_particles(n_particles) {
|
: n_particles(n_particles) {
|
||||||
|
|
||||||
box_min.x = box.xlo;
|
box_min.x = box.xlo;
|
||||||
box_min.y = box.ylo;
|
box_min.y = box.ylo;
|
||||||
box_min.z = box.zlo;
|
box_min.z = box.zlo;
|
||||||
|
|
||||||
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, cutoff);
|
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff);
|
||||||
|
this->grid_size = grid_size;
|
||||||
|
this->cell_size = cell_size;
|
||||||
|
|
||||||
total_cells = grid_size.x * grid_size.y * grid_size.z;
|
total_cells = grid_size.x * grid_size.y * grid_size.z;
|
||||||
|
|
||||||
|
@ -50,14 +52,18 @@ struct CellList {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cell index from 3D coordinates
|
// Get cell index from 3D coordinates
|
||||||
__device__ int get_cell_index(int x, int y, int z) const {
|
// TODO; Maybe update this to use Morton Encodings in the future to improve
|
||||||
|
// locality of particle indices. Unclear how much of a benefit this will add,
|
||||||
|
// but would be cool to do
|
||||||
|
__host__ __device__ int get_cell_index(int x, int y, int z) const {
|
||||||
return z * grid_size.x * grid_size.y + y * grid_size.x + x;
|
return z * grid_size.x * grid_size.y + y * grid_size.x + x;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<int3, float3> calc_grid_and_cell_size(Box &box, float cutoff) {
|
std::pair<int3, float3> calc_grid_and_cell_size(Box &box,
|
||||||
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / cutoff, 1),
|
float r_cutoff) const {
|
||||||
utils::max((int)(box.yhi - box.ylo) / cutoff, 1),
|
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / r_cutoff, 1),
|
||||||
utils::max((int)(box.zhi - box.zlo) / cutoff, 1)};
|
utils::max((int)(box.yhi - box.ylo) / r_cutoff, 1),
|
||||||
|
utils::max((int)(box.zhi - box.zlo) / r_cutoff, 1)};
|
||||||
|
|
||||||
float3 cell_size = {
|
float3 cell_size = {
|
||||||
(box.xhi - box.xlo) / grid_size.x,
|
(box.xhi - box.xlo) / grid_size.x,
|
||||||
|
|
10
src/box.hpp
10
src/box.hpp
|
@ -18,6 +18,16 @@ struct Box {
|
||||||
bool x_is_periodic;
|
bool x_is_periodic;
|
||||||
bool y_is_periodic;
|
bool y_is_periodic;
|
||||||
bool z_is_periodic;
|
bool z_is_periodic;
|
||||||
|
|
||||||
|
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi,
|
||||||
|
bool x_is_periodic, bool y_is_periodic, bool z_is_periodic)
|
||||||
|
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||||
|
x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic),
|
||||||
|
z_is_periodic(z_is_periodic) {}
|
||||||
|
|
||||||
|
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi)
|
||||||
|
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||||
|
x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests
|
||||||
test_potential.cu
|
test_potential.cu
|
||||||
test_forces.cu
|
test_forces.cu
|
||||||
test_kernel_config.cu
|
test_kernel_config.cu
|
||||||
|
test_neighbor_list.cu
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(${NAME}_cuda_tests gtest gtest_main)
|
target_link_libraries(${NAME}_cuda_tests gtest gtest_main)
|
||||||
|
|
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
|
||||||
|
#include "box.hpp"
|
||||||
|
#include "neighbor_list.cuh"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
TEST(CellListTest, Constructor) {
|
||||||
|
// Test case 1: Simple case
|
||||||
|
Box box(0, 10, 0, 10, 0, 10);
|
||||||
|
float cutoff = 2.5;
|
||||||
|
CellList cell_list(100, box, cutoff);
|
||||||
|
|
||||||
|
EXPECT_EQ(cell_list.grid_size.x, 4);
|
||||||
|
EXPECT_EQ(cell_list.grid_size.y, 4);
|
||||||
|
EXPECT_EQ(cell_list.grid_size.z, 4);
|
||||||
|
EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5);
|
||||||
|
EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5);
|
||||||
|
EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5);
|
||||||
|
EXPECT_EQ(cell_list.total_cells, 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CellListTest, GetCellIndex) {
|
||||||
|
Box box(0, 10, 0, 10, 0, 10);
|
||||||
|
float cutoff = 2.5;
|
||||||
|
CellList cell_list(100, box, cutoff);
|
||||||
|
|
||||||
|
int x = 1, y = 2, z = 3;
|
||||||
|
int expected_index = z * cell_list.grid_size.x * cell_list.grid_size.y +
|
||||||
|
y * cell_list.grid_size.x + x;
|
||||||
|
|
||||||
|
EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index);
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue