// cudaCAC/kernels/neighbor_list.cuh
//
// 79 lines
// 2.4 KiB
// Text

#ifndef NEIGHBOR_LIST_CUH
#define NEIGHBOR_LIST_CUH
#include "box.hpp"
#include "kernel_config.cuh"
#include "utils.cuh"
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
/**
 * Cell list structure for spatial hashing.
 *
 * The constructor derives a uniform grid from a Box and a cutoff distance and
 * allocates four device buffers with cudaMalloc; the destructor releases them
 * with cudaFree. Nothing in this struct fills the buffers.
 *
 * NOTE(review): the struct owns raw device pointers but is still copyable
 * (copying is left enabled so existing call sites keep compiling). A copy
 * shares its pointers with the source object and its destructor frees them,
 * so prefer passing a CellList by reference or pointer instead of copying it.
 */
struct CellList {
    int *cell_starts = nullptr;      // Size: total_cells, start index in sorted_particles
    // Size: total_cells. NOTE(review): named as a per-cell particle count but
    // previously documented as an end index in sorted_particles -- confirm
    // against the code that fills it.
    int *cell_count = nullptr;
    int *sorted_particles = nullptr; // Size: n_particles, particle IDs sorted by cell
    int *particle_cells = nullptr;   // Size: n_particles, which cell each particle belongs to
    int3 grid_size;                  // Number of cells in each dimension
    float3 cell_size;                // Size of each cell
    float3 box_min;                  // Minimum corner of simulation box
    size_t total_cells = 0;          // grid_size.x * grid_size.y * grid_size.z
    size_t n_particles;

    /**
     * Build the grid description and allocate the device buffers.
     *
     * @param n_particles Number of particles the per-particle buffers hold.
     * @param box         Simulation box; its xlo/xhi, ylo/yhi, zlo/zhi are read.
     * @param cutoff      Distance used to size the cells; must be positive.
     * @throws std::invalid_argument if cutoff is not positive.
     * @throws std::runtime_error    if any cudaMalloc call fails; buffers
     *                               allocated before the failure are freed.
     */
    CellList(size_t n_particles, Box &box, float cutoff)
        : n_particles(n_particles) {
        box_min.x = box.xlo;
        box_min.y = box.ylo;
        box_min.z = box.zlo;

        // Assign the members explicitly. A structured binding named
        // grid_size/cell_size would declare new locals that shadow the
        // members and leave them unset for the device-side helpers.
        const std::pair<int3, float3> layout =
            calc_grid_and_cell_size(box, cutoff);
        grid_size = layout.first;
        cell_size = layout.second;

        // Widen before multiplying so the product cannot overflow int.
        total_cells = static_cast<size_t>(grid_size.x) *
                      static_cast<size_t>(grid_size.y) *
                      static_cast<size_t>(grid_size.z);

        try {
            allocate_ints(cell_starts, total_cells, "cell_starts");
            allocate_ints(cell_count, total_cells, "cell_count");
            allocate_ints(sorted_particles, n_particles, "sorted_particles");
            allocate_ints(particle_cells, n_particles, "particle_cells");
        } catch (...) {
            // The destructor does not run when a constructor throws.
            release_buffers();
            throw;
        }
    }

    ~CellList() { release_buffers(); }

    // Get the flattened cell index from 3D cell coordinates (x varies fastest,
    // then y, then z). The coordinates are not range-checked.
    __device__ int get_cell_index(int x, int y, int z) const {
        return z * grid_size.x * grid_size.y + y * grid_size.x + x;
    }

    /**
     * Compute the number of cells per axis and the resulting cell edge lengths.
     *
     * Each axis gets max(floor(extent / cutoff), 1) cells, so a cell edge is at
     * least `cutoff` long whenever the box extent on that axis is at least
     * `cutoff`. The cell size is the extent divided by that cell count.
     *
     * @param box    Simulation box; only its lo/hi bounds are read.
     * @param cutoff Target minimum cell edge length; must be positive.
     * @return {cells per axis, cell edge length per axis}.
     * @throws std::invalid_argument if cutoff is not positive (this also
     *         rejects NaN).
     */
    std::pair<int3, float3> calc_grid_and_cell_size(Box &box, float cutoff) {
        if (!(cutoff > 0.0f)) {
            throw std::invalid_argument("CellList: cutoff must be positive");
        }

        const float len_x = static_cast<float>(box.xhi - box.xlo);
        const float len_y = static_cast<float>(box.yhi - box.ylo);
        const float len_z = static_cast<float>(box.zhi - box.zlo);

        // Truncate to int only after dividing by the cutoff; casting the
        // extent first would round the box length instead of the cell count.
        const int3 cells =
            make_int3(utils::max(static_cast<int>(len_x / cutoff), 1),
                      utils::max(static_cast<int>(len_y / cutoff), 1),
                      utils::max(static_cast<int>(len_z / cutoff), 1));

        const float3 edge = make_float3(len_x / cells.x,
                                        len_y / cells.y,
                                        len_z / cells.z);
        return std::make_pair(cells, edge);
    }

    // Get cell coordinates from a position. The offset from box_min is divided
    // by cell_size and truncated toward zero. No clamping or wrapping is done,
    // so a position on or beyond the upper box face yields a coordinate
    // >= grid_size on that axis.
    __device__ int3 get_cell_coords(float3 pos) const {
        return make_int3((int)((pos.x - box_min.x) / cell_size.x),
                         (int)((pos.y - box_min.y) / cell_size.y),
                         (int)((pos.z - box_min.z) / cell_size.z));
    }

private:
    // Allocate `count` ints on the device into `buffer`; throws on failure.
    static void allocate_ints(int *&buffer, size_t count, const char *label) {
        const cudaError_t status = cudaMalloc(&buffer, count * sizeof(int));
        if (status != cudaSuccess) {
            buffer = nullptr;
            throw std::runtime_error(
                std::string("CellList: cudaMalloc failed for ") + label +
                ": " + cudaGetErrorString(status));
        }
    }

    // Free every device buffer and reset the pointers. cudaFree accepts a
    // null pointer, so this is safe on partially constructed objects.
    void release_buffers() {
        cudaFree(cell_starts);
        cudaFree(cell_count);
        cudaFree(sorted_particles);
        cudaFree(particle_cells);
        cell_starts = nullptr;
        cell_count = nullptr;
        sorted_particles = nullptr;
        particle_cells = nullptr;
    }
};
#endif