diff --git a/kernels/neighbor_list.cuh b/kernels/neighbor_list.cuh
index 6c02378..35f4bb8 100644
--- a/kernels/neighbor_list.cuh
+++ b/kernels/neighbor_list.cuh
@@ -6,6 +6,12 @@
 #include "utils.cuh"
 #include
 #include
+#include
+#include
+#include <thrust/binary_search.h>
+#include <thrust/functional.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/transform.h>
 #include
 #include
 #include
@@ -76,20 +82,64 @@ struct CellList {
     return std::make_pair(grid_size, cell_size);
   }
 
-  __device__ int3 get_cell_coords_from_position(float3 pos) const {
+  // Converts a position into integer cell coordinates: each axis is offset by
+  // box_min, divided by cell_size and truncated toward zero by the int cast.
+  __host__ __device__ int3 get_cell_coords_from_position(float3 pos) const {
     return make_int3((int)((pos.x - box_min.x) / cell_size.x),
                      (int)((pos.y - box_min.y) / cell_size.y),
                      (int)((pos.z - box_min.z) / cell_size.z));
   }
 
-  __device__ int get_cell_index_from_position(float3 pos) const {
+  // Returns the cell index for a position by chaining
+  // get_cell_coords_from_position and get_cell_index_from_cell_coords.
+  __host__ __device__ int get_cell_index_from_position(float3 pos) const {
     return get_cell_index_from_cell_coords(get_cell_coords_from_position(pos));
   }
 
-  __device__ void assign_particles_to_cells(float3 *positions) {
-    for (int i = 0; i < this->n_particles; i++) {
+  // Bins every particle into its grid cell and rebuilds the per-cell tables.
+  //
+  // After the call:
+  //   particle_cells   - cell index per particle, sorted ascending
+  //   sorted_particles - particle ids permuted into the same order
+  //   cell_starts[c]   - offset of cell c's first entry in sorted_particles
+  //   cell_count[c]    - number of particles in cell c (0 for empty cells)
+  //
+  // Host-only: the four member buffers are handed to Thrust as device
+  // pointers.
+  void assign_particles_to_cells(float3 *positions) {
+    // device_pointer_cast deduces the element type from each raw pointer.
+    auto particle_cells_ptr = thrust::device_pointer_cast(particle_cells);
+    auto sorted_particles_ptr = thrust::device_pointer_cast(sorted_particles);
+    auto cell_starts_ptr = thrust::device_pointer_cast(cell_starts);
+    auto cell_count_ptr = thrust::device_pointer_cast(cell_count);
+
+    // Start from the identity permutation 0 .. n_particles - 1.
+    thrust::sequence(sorted_particles_ptr, sorted_particles_ptr + n_particles);
+
+    // NOTE(review): this loop touches positions[] and particle_cells[] from the
+    // host while Thrust below treats particle_cells as device memory. That only
+    // works if both buffers are host- and device-accessible (e.g. managed
+    // memory) -- confirm at the allocation site.
+    for (size_t i = 0; i < n_particles; i++) {
       particle_cells[i] = get_cell_index_from_position(positions[i]);
     }
+
+    // Group particle ids by cell.
+    thrust::sort_by_key(particle_cells_ptr, particle_cells_ptr + n_particles,
+                        sorted_particles_ptr);
+
+    // Look up each cell index in the sorted keys: lower_bound gives the cell's
+    // first slot and upper_bound the slot past its last. reduce_by_key packs
+    // one count per non-empty cell, so its output stops lining up with cell
+    // indices as soon as one cell is empty.
+    thrust::counting_iterator<int> first_cell(0);
+    thrust::lower_bound(particle_cells_ptr, particle_cells_ptr + n_particles,
+                        first_cell, first_cell + total_cells, cell_starts_ptr);
+    thrust::upper_bound(particle_cells_ptr, particle_cells_ptr + n_particles,
+                        first_cell, first_cell + total_cells, cell_count_ptr);
+    // cell_count holds end offsets here; subtracting starts yields counts.
+    thrust::transform(cell_count_ptr, cell_count_ptr + total_cells,
+                      cell_starts_ptr, cell_count_ptr, thrust::minus<int>());
   }
 };
 