Compare commits
2 commits
8dec472929
...
d957a90573
Author | SHA1 | Date | |
---|---|---|---|
d957a90573 | |||
f3e701236e |
6 changed files with 84 additions and 274 deletions
|
@ -1,194 +1 @@
|
|||
#include "neighbor_list.cuh"
|
||||
|
||||
/**
|
||||
* Step 1: Assign particles to cells
|
||||
*/
|
||||
__global__ void assign_particles_to_cells_kernel(const float4 *positions,
|
||||
int *particle_cells,
|
||||
const CellList cell_list,
|
||||
size_t n_particles) {
|
||||
size_t i = get_thread_id();
|
||||
if (i >= n_particles)
|
||||
return;
|
||||
|
||||
float4 pos = positions[i];
|
||||
float3 pos3 = make_float3(pos.x, pos.y, pos.z);
|
||||
|
||||
int3 cell_coords = cell_list.get_cell_coords(pos3);
|
||||
|
||||
// Clamp to valid range (handle edge cases)
|
||||
cell_coords.x = max(0, min(cell_coords.x, cell_list.grid_size.x - 1));
|
||||
cell_coords.y = max(0, min(cell_coords.y, cell_list.grid_size.y - 1));
|
||||
cell_coords.z = max(0, min(cell_coords.z, cell_list.grid_size.z - 1));
|
||||
|
||||
particle_cells[i] =
|
||||
cell_list.get_cell_index(cell_coords.x, cell_coords.y, cell_coords.z);
|
||||
}
|
||||
|
||||
/**
|
||||
* Step 2: Find cell boundaries after sorting
|
||||
*/
|
||||
__global__ void find_cell_boundaries_kernel(const int *sorted_particle_cells,
|
||||
int *cell_starts, int *cell_ends,
|
||||
size_t n_particles,
|
||||
size_t total_cells) {
|
||||
size_t i = get_thread_id();
|
||||
if (i >= n_particles)
|
||||
return;
|
||||
|
||||
int cell = sorted_particle_cells[i];
|
||||
|
||||
// Check if this is the start of a new cell
|
||||
if (i == 0 || sorted_particle_cells[i - 1] != cell) {
|
||||
cell_starts[cell] = i;
|
||||
}
|
||||
|
||||
// Check if this is the end of a cell
|
||||
if (i == n_particles - 1 || sorted_particle_cells[i + 1] != cell) {
|
||||
cell_ends[cell] = i + 1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Step 3: Build actual neighbor lists using cell lists
|
||||
*/
|
||||
__global__ void
|
||||
build_neighbor_lists_kernel(const float4 *positions, const CellList cell_list,
|
||||
NeighborList neighbor_list, float cutoff_squared,
|
||||
const float3 *box_lengths, size_t n_particles) {
|
||||
size_t i = get_thread_id();
|
||||
if (i >= n_particles)
|
||||
return;
|
||||
|
||||
float4 pos_i = positions[i];
|
||||
float3 my_pos = make_float3(pos_i.x, pos_i.y, pos_i.z);
|
||||
|
||||
int3 my_cell = cell_list.get_cell_coords(my_pos);
|
||||
int neighbor_count = 0;
|
||||
int max_neighbors = neighbor_list.max_neighbors;
|
||||
int *my_neighbors = neighbor_list.get_neighbors(i);
|
||||
|
||||
// Search neighboring cells (3x3x3 = 27 cells including self)
|
||||
for (int dz = -1; dz <= 1; dz++) {
|
||||
for (int dy = -1; dy <= 1; dy++) {
|
||||
for (int dx = -1; dx <= 1; dx++) {
|
||||
|
||||
int3 neighbor_cell =
|
||||
make_int3(my_cell.x + dx, my_cell.y + dy, my_cell.z + dz);
|
||||
|
||||
// Check bounds
|
||||
if (neighbor_cell.x < 0 || neighbor_cell.x >= cell_list.grid_size.x ||
|
||||
neighbor_cell.y < 0 || neighbor_cell.y >= cell_list.grid_size.y ||
|
||||
neighbor_cell.z < 0 || neighbor_cell.z >= cell_list.grid_size.z) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cell_idx = cell_list.get_cell_index(
|
||||
neighbor_cell.x, neighbor_cell.y, neighbor_cell.z);
|
||||
int start = cell_list.cell_starts[cell_idx];
|
||||
int end = cell_list.cell_ends[cell_idx];
|
||||
|
||||
// Check all particles in this cell
|
||||
for (int idx = start; idx < end; idx++) {
|
||||
int j = cell_list.sorted_particles[idx];
|
||||
|
||||
if (i >= j)
|
||||
continue; // Avoid double counting and self-interaction
|
||||
|
||||
float4 pos_j = positions[j];
|
||||
float3 other_pos = make_float3(pos_j.x, pos_j.y, pos_j.z);
|
||||
|
||||
// Calculate distance with periodic boundary conditions
|
||||
float dx = my_pos.x - other_pos.x;
|
||||
float dy = my_pos.y - other_pos.y;
|
||||
float dz = my_pos.z - other_pos.z;
|
||||
|
||||
// Apply PBC
|
||||
dx -= box_lengths->x * roundf(dx / box_lengths->x);
|
||||
dy -= box_lengths->y * roundf(dy / box_lengths->y);
|
||||
dz -= box_lengths->z * roundf(dz / box_lengths->z);
|
||||
|
||||
float r_squared = dx * dx + dy * dy + dz * dz;
|
||||
|
||||
if (r_squared <= cutoff_squared && neighbor_count < max_neighbors) {
|
||||
my_neighbors[neighbor_count] = j;
|
||||
neighbor_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_list.neighbor_counts[i] = neighbor_count;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// HOST FUNCTIONS
|
||||
// =============================================================================
|
||||
|
||||
void build_cell_list(const float4 *positions, CellList &cell_list,
|
||||
const float3 *box_lengths) {
|
||||
|
||||
// Step 1: Assign particles to cells
|
||||
auto config = get_launch_config(cell_list.n_particles);
|
||||
|
||||
assign_particles_to_cells_kernel<<<config.blocks, config.threads>>>(
|
||||
positions, cell_list.particle_cells, cell_list, cell_list.n_particles);
|
||||
|
||||
// Step 2: Sort particles by cell index using Thrust
|
||||
thrust::device_ptr<int> particle_cells_ptr(cell_list.particle_cells);
|
||||
thrust::device_ptr<int> sorted_particles_ptr(cell_list.sorted_particles);
|
||||
|
||||
// Initialize particle indices 0, 1, 2, ...
|
||||
thrust::sequence(sorted_particles_ptr,
|
||||
sorted_particles_ptr + cell_list.n_particles);
|
||||
|
||||
// Sort particle indices by their cell assignments
|
||||
thrust::sort_by_key(particle_cells_ptr,
|
||||
particle_cells_ptr + cell_list.n_particles,
|
||||
sorted_particles_ptr);
|
||||
|
||||
// Step 3: Initialize cell boundaries to -1
|
||||
cudaMemset(cell_list.cell_starts, -1, cell_list.total_cells * sizeof(int));
|
||||
cudaMemset(cell_list.cell_ends, -1, cell_list.total_cells * sizeof(int));
|
||||
|
||||
// Step 4: Find cell boundaries
|
||||
config = get_launch_config(cell_list.n_particles);
|
||||
find_cell_boundaries_kernel<<<config.blocks, config.threads>>>(
|
||||
cell_list.particle_cells, cell_list.cell_starts, cell_list.cell_ends,
|
||||
cell_list.n_particles, cell_list.total_cells);
|
||||
}
|
||||
|
||||
void build_neighbor_list(const float4 *positions, const CellList &cell_list,
|
||||
NeighborList &neighbor_list, float cutoff_squared,
|
||||
const float3 *box_lengths) {
|
||||
|
||||
// Initialize neighbor counts to 0
|
||||
cudaMemset(neighbor_list.neighbor_counts, 0,
|
||||
neighbor_list.n_particles * sizeof(int));
|
||||
|
||||
auto config = get_launch_config(neighbor_list.n_particles);
|
||||
|
||||
build_neighbor_lists_kernel<<<config.blocks, config.threads>>>(
|
||||
positions, cell_list, neighbor_list, cutoff_squared, box_lengths,
|
||||
neighbor_list.n_particles);
|
||||
}
|
||||
|
||||
void build_neighbor_list_optimized(const float4 *positions, size_t n_particles,
|
||||
const float3 *box_lengths, float cutoff,
|
||||
NeighborList &neighbor_list) {
|
||||
|
||||
// Get box size for cell list construction
|
||||
float3 box_size;
|
||||
cudaMemcpy(&box_size, box_lengths, sizeof(float3), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Create cell list
|
||||
CellList cell_list(n_particles, box_size, cutoff);
|
||||
|
||||
// Build cell list
|
||||
build_cell_list(positions, cell_list, box_lengths);
|
||||
|
||||
// Build neighbor list using cell list
|
||||
build_neighbor_list(positions, cell_list, neighbor_list, cutoff * cutoff,
|
||||
box_lengths);
|
||||
}
|
||||
|
|
|
@ -1,68 +1,21 @@
|
|||
#ifndef NEIGHBOR_LIST_CUH
|
||||
#define NEIGHBOR_LIST_CUH
|
||||
|
||||
#include "kernel_config.cuh" // From previous artifact
|
||||
#include "box.hpp"
|
||||
#include "kernel_config.cuh"
|
||||
#include "utils.cuh"
|
||||
#include <cuda_runtime.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/pair.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
|
||||
/**
|
||||
* Simple kernel to initialize neighbor offsets
|
||||
*/
|
||||
__global__ void init_neighbor_offsets(int *offsets, size_t n_particles,
|
||||
int max_neighbors) {
|
||||
size_t i = get_thread_id();
|
||||
if (i <= n_particles) { // Note: <= because we need n_particles + 1 elements
|
||||
offsets[i] = i * max_neighbors;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Neighbor list data structure
|
||||
* Uses a compact format similar to CSR sparse matrices
|
||||
*/
|
||||
struct NeighborList {
|
||||
int *neighbor_offsets; // Size: n_particles + 1, offset into neighbor_indices
|
||||
int *neighbor_indices; // Size: total_neighbors, actual neighbor particle IDs
|
||||
int *neighbor_counts; // Size: n_particles, number of neighbors per particle
|
||||
size_t max_neighbors; // Maximum neighbors allocated per particle
|
||||
size_t n_particles;
|
||||
|
||||
// Constructor
|
||||
NeighborList(size_t n_particles, size_t max_neighbors_per_particle)
|
||||
: max_neighbors(max_neighbors_per_particle), n_particles(n_particles) {
|
||||
|
||||
cudaMalloc(&neighbor_offsets, (n_particles + 1) * sizeof(int));
|
||||
cudaMalloc(&neighbor_indices, n_particles * max_neighbors * sizeof(int));
|
||||
cudaMalloc(&neighbor_counts, n_particles * sizeof(int));
|
||||
|
||||
// Initialize offsets
|
||||
auto kernel_config = get_launch_config(n_particles);
|
||||
init_neighbor_offsets<<<kernel_config.blocks, kernel_config.threads>>>(
|
||||
neighbor_offsets, n_particles, max_neighbors);
|
||||
}
|
||||
|
||||
~NeighborList() {
|
||||
cudaFree(neighbor_offsets);
|
||||
cudaFree(neighbor_indices);
|
||||
cudaFree(neighbor_counts);
|
||||
}
|
||||
|
||||
// Get neighbors for particle i (device function)
|
||||
__device__ int *get_neighbors(int i) const {
|
||||
return &neighbor_indices[neighbor_offsets[i]];
|
||||
}
|
||||
|
||||
__device__ int get_neighbor_count(int i) const { return neighbor_counts[i]; }
|
||||
};
|
||||
|
||||
/**
|
||||
* Cell list structure for spatial hashing
|
||||
*/
|
||||
struct CellList {
|
||||
int *cell_starts; // Size: total_cells, start index in sorted_particles
|
||||
int *cell_ends; // Size: total_cells, end index in sorted_particles
|
||||
int *cell_count; // Size: total_cells, end index in sorted_particles
|
||||
int *sorted_particles; // Size: n_particles, particle IDs sorted by cell
|
||||
int *particle_cells; // Size: n_particles, which cell each particle belongs to
|
||||
|
||||
|
@ -72,40 +25,55 @@ struct CellList {
|
|||
size_t total_cells;
|
||||
size_t n_particles;
|
||||
|
||||
CellList(size_t n_particles, float3 box_size, float cutoff)
|
||||
CellList(size_t n_particles, Box &box, float r_cutoff)
|
||||
: n_particles(n_particles) {
|
||||
|
||||
// Calculate grid dimensions (cell size slightly larger than cutoff)
|
||||
float cell_margin = 1.001f; // Small safety margin
|
||||
cell_size = make_float3(cutoff * cell_margin, cutoff * cell_margin,
|
||||
cutoff * cell_margin);
|
||||
box_min.x = box.xlo;
|
||||
box_min.y = box.ylo;
|
||||
box_min.z = box.zlo;
|
||||
|
||||
grid_size.x = (int)ceilf(box_size.x / cell_size.x);
|
||||
grid_size.y = (int)ceilf(box_size.y / cell_size.y);
|
||||
grid_size.z = (int)ceilf(box_size.z / cell_size.z);
|
||||
auto [grid_size, cell_size] = calc_grid_and_cell_size(box, r_cutoff);
|
||||
this->grid_size = grid_size;
|
||||
this->cell_size = cell_size;
|
||||
|
||||
total_cells = (size_t)grid_size.x * grid_size.y * grid_size.z;
|
||||
box_min = make_float3(0.0f, 0.0f, 0.0f); // Assume box starts at origin
|
||||
total_cells = grid_size.x * grid_size.y * grid_size.z;
|
||||
|
||||
// Allocate memory
|
||||
cudaMalloc(&cell_starts, total_cells * sizeof(int));
|
||||
cudaMalloc(&cell_ends, total_cells * sizeof(int));
|
||||
cudaMalloc(&cell_count, total_cells * sizeof(int));
|
||||
cudaMalloc(&sorted_particles, n_particles * sizeof(int));
|
||||
cudaMalloc(&particle_cells, n_particles * sizeof(int));
|
||||
}
|
||||
|
||||
~CellList() {
|
||||
cudaFree(cell_starts);
|
||||
cudaFree(cell_ends);
|
||||
cudaFree(cell_count);
|
||||
cudaFree(sorted_particles);
|
||||
cudaFree(particle_cells);
|
||||
}
|
||||
|
||||
// Get cell index from 3D coordinates
|
||||
__device__ int get_cell_index(int x, int y, int z) const {
|
||||
// TODO; Maybe update this to use Morton Encodings in the future to improve
|
||||
// locality of particle indices. Unclear how much of a benefit this will add,
|
||||
// but would be cool to do
|
||||
__host__ __device__ int get_cell_index(int x, int y, int z) const {
|
||||
return z * grid_size.x * grid_size.y + y * grid_size.x + x;
|
||||
}
|
||||
|
||||
std::pair<int3, float3> calc_grid_and_cell_size(Box &box,
|
||||
float r_cutoff) const {
|
||||
int3 grid_size = {utils::max((int)(box.xhi - box.xlo) / r_cutoff, 1),
|
||||
utils::max((int)(box.yhi - box.ylo) / r_cutoff, 1),
|
||||
utils::max((int)(box.zhi - box.zlo) / r_cutoff, 1)};
|
||||
|
||||
float3 cell_size = {
|
||||
(box.xhi - box.xlo) / grid_size.x,
|
||||
(box.yhi - box.ylo) / grid_size.y,
|
||||
(box.zhi - box.zlo) / grid_size.z,
|
||||
};
|
||||
|
||||
return std::make_pair(grid_size, cell_size);
|
||||
}
|
||||
|
||||
// Get cell coordinates from position
|
||||
__device__ int3 get_cell_coords(float3 pos) const {
|
||||
return make_int3((int)((pos.x - box_min.x) / cell_size.x),
|
||||
|
@ -114,19 +82,4 @@ struct CellList {
|
|||
}
|
||||
};
|
||||
|
||||
// Forward declarations
|
||||
void build_cell_list(const float4 *positions, CellList &cell_list,
|
||||
const float3 *box_lengths);
|
||||
|
||||
void build_neighbor_list(const float4 *positions, const CellList &cell_list,
|
||||
NeighborList &neighbor_list, float cutoff_squared,
|
||||
const float3 *box_lengths);
|
||||
|
||||
/**
|
||||
* High-level interface - builds complete neighbor list
|
||||
*/
|
||||
void build_neighbor_list_optimized(const float4 *positions, size_t n_particles,
|
||||
const float3 *box_lengths, float cutoff,
|
||||
NeighborList &neighbor_list);
|
||||
|
||||
#endif
|
||||
|
|
8
kernels/utils.cuh
Normal file
8
kernels/utils.cuh
Normal file
|
@ -0,0 +1,8 @@
|
|||
#ifndef UTILS_CUH
|
||||
#define UTILS_CUH
|
||||
|
||||
namespace utils {
|
||||
__device__ __host__ inline int max(int a, int b) { return (a > b) ? a : b; }
|
||||
} // namespace utils
|
||||
|
||||
#endif
|
10
src/box.hpp
10
src/box.hpp
|
@ -18,6 +18,16 @@ struct Box {
|
|||
bool x_is_periodic;
|
||||
bool y_is_periodic;
|
||||
bool z_is_periodic;
|
||||
|
||||
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi,
|
||||
bool x_is_periodic, bool y_is_periodic, bool z_is_periodic)
|
||||
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||
x_is_periodic(x_is_periodic), y_is_periodic(y_is_periodic),
|
||||
z_is_periodic(z_is_periodic) {}
|
||||
|
||||
Box(real xlo, real xhi, real ylo, real yhi, real zlo, real zhi)
|
||||
: xlo(xlo), xhi(xhi), ylo(ylo), yhi(yhi), zlo(zlo), zhi(zhi),
|
||||
x_is_periodic(true), y_is_periodic(true), z_is_periodic(true) {}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
|
|
@ -4,6 +4,7 @@ add_executable(${NAME}_cuda_tests
|
|||
test_potential.cu
|
||||
test_forces.cu
|
||||
test_kernel_config.cu
|
||||
test_neighbor_list.cu
|
||||
)
|
||||
|
||||
target_link_libraries(${NAME}_cuda_tests gtest gtest_main)
|
||||
|
|
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
31
tests/cuda_unit_tests/test_neighbor_list.cu
Normal file
|
@ -0,0 +1,31 @@
|
|||
|
||||
#include "box.hpp"
|
||||
#include "neighbor_list.cuh"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(CellListTest, Constructor) {
|
||||
// Test case 1: Simple case
|
||||
Box box(0, 10, 0, 10, 0, 10);
|
||||
float cutoff = 2.5;
|
||||
CellList cell_list(100, box, cutoff);
|
||||
|
||||
EXPECT_EQ(cell_list.grid_size.x, 4);
|
||||
EXPECT_EQ(cell_list.grid_size.y, 4);
|
||||
EXPECT_EQ(cell_list.grid_size.z, 4);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.x, 2.5);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.y, 2.5);
|
||||
EXPECT_FLOAT_EQ(cell_list.cell_size.z, 2.5);
|
||||
EXPECT_EQ(cell_list.total_cells, 64);
|
||||
}
|
||||
|
||||
TEST(CellListTest, GetCellIndex) {
|
||||
Box box(0, 10, 0, 10, 0, 10);
|
||||
float cutoff = 2.5;
|
||||
CellList cell_list(100, box, cutoff);
|
||||
|
||||
int x = 1, y = 2, z = 3;
|
||||
int expected_index = z * cell_list.grid_size.x * cell_list.grid_size.y +
|
||||
y * cell_list.grid_size.x + x;
|
||||
|
||||
EXPECT_EQ(cell_list.get_cell_index(x, y, z), expected_index);
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue