From b3e6c201d1cddc3fc097b795f0b21deaf5633476 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isa=C3=AFe?= Date: Tue, 29 Oct 2024 18:16:25 +0100 Subject: [PATCH] refactor(kernels::grisubal): rewrite step 1 to enable parallelization (#210) * add intersection metadata prealloc * update tests & cleanup * remove redundant computations * build segment HashMap by collecting an iterator * better format & remove redundant realloc * address feedback * cleaner prefixsum computation * remove an allocation --- honeycomb-kernels/src/grisubal/grid.rs | 2 +- honeycomb-kernels/src/grisubal/kernel.rs | 158 +++++++++++++---------- honeycomb-kernels/src/grisubal/tests.rs | 13 +- 3 files changed, 105 insertions(+), 68 deletions(-) diff --git a/honeycomb-kernels/src/grisubal/grid.rs b/honeycomb-kernels/src/grisubal/grid.rs index a39861b2a..b9ff73003 100644 --- a/honeycomb-kernels/src/grisubal/grid.rs +++ b/honeycomb-kernels/src/grisubal/grid.rs @@ -10,7 +10,7 @@ /// /// Cells `(X, Y)` take value in range `(0, 0)` to `(N, M)`, /// from left to right (X), from bottom to top (Y). -#[derive(PartialEq)] +#[derive(PartialEq, Clone, Copy)] pub struct GridCellId(pub usize, pub usize); impl GridCellId { diff --git a/honeycomb-kernels/src/grisubal/kernel.rs b/honeycomb-kernels/src/grisubal/kernel.rs index d9cb67c28..fbcbbe480 100644 --- a/honeycomb-kernels/src/grisubal/kernel.rs +++ b/honeycomb-kernels/src/grisubal/kernel.rs @@ -193,33 +193,61 @@ pub(super) fn generate_intersection_data( [cx, cy]: [T; 2], origin: Vertex2, ) -> (Segments, Vec<(DartIdentifier, T)>) { - let mut intersection_metadata = Vec::new(); - let mut new_segments = HashMap::with_capacity(geometry.poi.len() * 2); // that *2 has no basis - geometry.segments.iter().for_each(|&(v1_id, v2_id)| { - // fetch vertices of the segment - let Vertex2(ox, oy) = origin; - let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]); - // compute their position in the grid - // we assume that the origin of the grid is at (0., 0.) 
- let (c1, c2) = ( - GridCellId( - ((v1.x() - ox) / cx).floor().to_usize().unwrap(), - ((v1.y() - oy) / cy).floor().to_usize().unwrap(), - ), - GridCellId( - ((v2.x() - ox) / cx).floor().to_usize().unwrap(), - ((v2.y() - oy) / cy).floor().to_usize().unwrap(), - ), - ); + let tmp: Vec<_> = geometry + .segments + .iter() + .map(|&(v1_id, v2_id)| { + // fetch vertices of the segment + let Vertex2(ox, oy) = origin; + let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]); + // compute their position in the grid + // we assume that the origin of the grid is at (0., 0.) + let (c1, c2) = ( + GridCellId( + ((v1.x() - ox) / cx).floor().to_usize().unwrap(), + ((v1.y() - oy) / cy).floor().to_usize().unwrap(), + ), + GridCellId( + ((v2.x() - ox) / cx).floor().to_usize().unwrap(), + ((v2.y() - oy) / cy).floor().to_usize().unwrap(), + ), + ); + ( + GridCellId::man_dist(&c1, &c2), + GridCellId::diff(&c1, &c2), + v1, + v2, + v1_id, + v2_id, + c1, + ) + }) + .collect(); + // total number of intersection + let n_intersec: usize = tmp.iter().map(|(dist, _, _, _, _, _, _)| dist).sum(); + // we're using the prefix sum to compute an offset from the start. 
that's why we need a 0 at the front + // we'll cut off the last element later + let prefix_sum = tmp + .iter() + .map(|(dist, _, _, _, _, _, _)| dist) + .scan(0, |state, &dist| { + *state += dist; + Some(*state - dist) // we want an offset, not the actual sum + }); + // preallocate the intersection vector + let mut intersection_metadata = vec![(NULL_DART_ID, T::nan()); n_intersec]; + + let new_segments: Segments = tmp.iter().zip(prefix_sum).flat_map(|(&(dist, diff, v1, v2, v1_id, v2_id, c1), start)| { + let transform = Box::new(|seg: &[GeometryVertex]| { + assert_eq!(seg.len(), 2); + (seg[0].clone(), seg[1].clone()) + }); // check neighbor status - match GridCellId::man_dist(&c1, &c2) { + match dist { // trivial case: // v1 & v2 belong to the same cell 0 => { - new_segments.insert( - make_geometry_vertex!(geometry, v1_id), - make_geometry_vertex!(geometry, v2_id), - ); + vec![(make_geometry_vertex!(geometry, v1_id), make_geometry_vertex!(geometry, v2_id))] } // ok case: // v1 & v2 belong to neighboring cells @@ -227,8 +255,6 @@ pub(super) fn generate_intersection_data( // fetch base dart of the cell of v1 #[allow(clippy::cast_possible_truncation)] let d_base = (1 + 4 * c1.0 + nx * 4 * c1.1) as DartIdentifier; - // which edge of the cell are we intersecting? - let diff = GridCellId::diff(&c1, &c2); // which dart does this correspond to? 
#[rustfmt::skip] let dart_id = match diff { @@ -253,27 +279,20 @@ pub(super) fn generate_intersection_data( _ => unreachable!(), }; - // FIXME: these two lines should be atomic - let id = intersection_metadata.len(); - intersection_metadata.push((dart_id, t)); - - new_segments.insert( - make_geometry_vertex!(geometry, v1_id), - GeometryVertex::Intersec(id), - ); - new_segments.insert( - GeometryVertex::Intersec(id), - make_geometry_vertex!(geometry, v2_id), - ); + let id = start; + intersection_metadata[id] = (dart_id, t); + + vec![ + (make_geometry_vertex!(geometry, v1_id), GeometryVertex::Intersec(id)), + (GeometryVertex::Intersec(id), make_geometry_vertex!(geometry, v2_id)), + ] } // highly annoying case: // v1 & v2 do not belong to neighboring cell _ => { - // because we're using strait segments (not curves), the manhattan distance gives us - // the number of cell we're going through to reach v2 from v1 - let diff = GridCellId::diff(&c1, &c2); // pure vertical / horizontal traversal are treated separately because it ensures we're not trying // to compute intersections of parallel segments (which results at best in a division by 0) + let i_ids = start..start+dist; match diff { (i, 0) => { // we can solve the intersection equation @@ -284,7 +303,7 @@ pub(super) fn generate_intersection_data( // i > 0: i_base..i_base + i // or // i < 0: i_base + 1 + i..i_base + 1 - (min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).map(|x| { + (min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).zip(i_ids).map(|(x, id)| { // cell base dart let d_base = (1 + 4 * x + (nx * 4 * c1.1) as isize) as DartIdentifier; @@ -304,12 +323,11 @@ pub(super) fn generate_intersection_data( left_intersec!(v1, v2, v_dart, cy) }; - // FIXME: these two lines should be atomic - let id = intersection_metadata.len(); - intersection_metadata.push((dart_id, t)); + intersection_metadata[id] = (dart_id, t); GeometryVertex::Intersec(id) }); + // because of how the range is written, we need to 
reverse the iterator in one case // to keep intersection ordered from v1 to v2 (i.e. ensure the segments we build are correct) let mut vs: VecDeque<GeometryVertex> = if i > 0 { @@ -317,11 +335,15 @@ pub(super) fn generate_intersection_data( } else { tmp.rev().collect() }; + + // complete the vertex list vs.push_front(make_geometry_vertex!(geometry, v1_id)); vs.push_back(make_geometry_vertex!(geometry, v2_id)); - vs.make_contiguous().windows(2).for_each(|seg| { - new_segments.insert(seg[0].clone(), seg[1].clone()); - }); + + vs.make_contiguous() + .windows(2) + .map(transform) + .collect::<Vec<_>>() } (0, j) => { // we can solve the intersection equation @@ -332,7 +354,7 @@ pub(super) fn generate_intersection_data( // j > 0: j_base..j_base + j // or // j < 0: j_base + 1 + j..j_base + 1 - (min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).map(|y| { + (min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).zip(i_ids).map(|(y, id)| { // cell base dart let d_base = (1 + 4 * c1.0 + nx * 4 * y as usize) as DartIdentifier; // intersected dart @@ -347,12 +369,11 @@ pub(super) fn generate_intersection_data( down_intersec!(v1, v2, v_dart, cx) }; - // FIXME: these two lines should be atomic - let id = intersection_metadata.len(); - intersection_metadata.push((dart_id, t)); + intersection_metadata[id] = (dart_id, t); GeometryVertex::Intersec(id) }); + // because of how the range is written, we need to reverse the iterator in one case // to keep intersection ordered from v1 to v2 (i.e. 
ensure the segments we build are correct) let mut vs: VecDeque<GeometryVertex> = if j > 0 { @@ -360,13 +381,15 @@ pub(super) fn generate_intersection_data( } else { tmp.rev().collect() }; + // complete the vertex list vs.push_front(make_geometry_vertex!(geometry, v1_id)); vs.push_back(make_geometry_vertex!(geometry, v2_id)); - // insert new segments - vs.make_contiguous().windows(2).for_each(|seg| { - new_segments.insert(seg[0].clone(), seg[1].clone()); - }); + + vs.make_contiguous() + .windows(2) + .map(transform) + .collect::<Vec<_>>() } (i, j) => { // in order to process this, we'll consider a "sub-grid" & use the direction of the segment to @@ -454,6 +477,7 @@ pub(super) fn generate_intersection_data( None }) .collect(); + // sort intersections from v1 to v2 intersec_data.retain(|(s, _, _)| (T::zero() <= *s) && (*s <= T::one())); // panic unreachable because of the retain above; there's no s s.t. s == NaN @@ -462,31 +486,34 @@ pub(super) fn generate_intersection_data( // collect geometry vertices let mut vs = vec![make_geometry_vertex!(geometry, v1_id)]; - vs.extend(intersec_data.iter_mut().map(|(_, t, dart_id)| { + vs.extend(intersec_data.iter_mut().zip(i_ids).map(|((_, t, dart_id), id)| { if t.is_zero() { // we assume that the segment fully goes through the corner and does not land exactly // on it, this allows us to compute directly the dart from which the next segment // should start: the one incident to the vertex in the opposite quadrant + + // in that case, the preallocated intersection metadata slot will stay as (0, Nan) + // this is ok, we can simply ignore the entry when processing the data later + let dart_in = *dart_id; GeometryVertex::IntersecCorner(dart_in) } else { - // FIXME: these two lines should be atomic - let id = intersection_metadata.len(); - intersection_metadata.push((*dart_id, *t)); + intersection_metadata[id] = (*dart_id, *t); GeometryVertex::Intersec(id) } })); + vs.push(make_geometry_vertex!(geometry, v2_id)); - // insert segments - 
vs.windows(2).for_each(|seg| { - new_segments.insert(seg[0].clone(), seg[1].clone()); - }); + + vs.windows(2) + .map(transform) + .collect::<Vec<_>>() } } } - }; - }); + } + }).collect(); (new_segments, intersection_metadata) } @@ -499,6 +526,7 @@ pub(super) fn group_intersections_per_edge( HashMap::new(); intersection_metadata .into_iter() + .filter(|(_, t)| !t.is_nan()) .enumerate() .for_each(|(idx, (dart_id, mut t))| { // classify intersections per edge_id & adjust t if needed diff --git a/honeycomb-kernels/src/grisubal/tests.rs b/honeycomb-kernels/src/grisubal/tests.rs index ad0d82347..519b44da8 100644 --- a/honeycomb-kernels/src/grisubal/tests.rs +++ b/honeycomb-kernels/src/grisubal/tests.rs @@ -140,7 +140,6 @@ fn regular_intersections() { generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default()); assert_eq!(intersection_metadata.len(), 4); - // FIXME: INDEX ACCESSES WON'T WORK IN PARALLEL assert_eq!(intersection_metadata[0], (2, 0.5)); assert_eq!(intersection_metadata[1], (7, 0.5)); assert_eq!(intersection_metadata[2], (16, 0.5)); @@ -261,6 +260,8 @@ fn regular_intersections() { #[test] fn corner_intersection() { + use num_traits::Float; + let mut cmap = CMapBuilder::from( GridDescriptor::default() .len_per_cell([1.0; 3]) @@ -280,7 +281,15 @@ fn corner_intersection() { let (segments, intersection_metadata) = generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default()); - assert_eq!(intersection_metadata.len(), 2); + // because we intersec a corner, some entries were preallocated but not needed. + // entries were initialized with (0, Nan), so they're easy to filter + assert_eq!( + intersection_metadata + .iter() + .filter(|(_, t)| !t.is_nan()) + .count(), + 2 + ); assert_eq!(intersection_metadata[0], (2, 0.5)); assert_eq!(intersection_metadata[1], (7, 0.5));