diff --git a/kernels/neighbor_list.cuh b/kernels/neighbor_list.cuh index 14d1201..6c02378 100644 --- a/kernels/neighbor_list.cuh +++ b/kernels/neighbor_list.cuh @@ -76,12 +76,21 @@ struct CellList { return std::make_pair(grid_size, cell_size); } - // Get cell coordinates from position - __device__ int3 get_cell_coords(float3 pos) const { + __device__ int3 get_cell_coords_from_position(float3 pos) const { return make_int3((int)((pos.x - box_min.x) / cell_size.x), (int)((pos.y - box_min.y) / cell_size.y), (int)((pos.z - box_min.z) / cell_size.z)); } + + __device__ int get_cell_index_from_position(float3 pos) const { + return get_cell_index_from_cell_coords(get_cell_coords_from_position(pos)); + } + + __device__ void assign_particles_to_cells(float3 *positions) { + for (int i = 0; i < this->n_particles; i++) { + particle_cells[i] = get_cell_index_from_position(positions[i]); + } + } }; #endif