Skip to content

Commit

Permalink
refactor(kernels::grisubal): rewrite step 1 to enable parallelization (
Browse files Browse the repository at this point in the history
…#210)

* add intersection metadata prealloc

* update tests & cleanup

* remove redundant computations

* build segment HashMap by collecting an iterator

* better format & remove redundant realloc

* address feedback

* cleaner prefixsum computation

* remove an allocation
  • Loading branch information
imrn99 authored Oct 29, 2024
1 parent 4c16750 commit b3e6c20
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 68 deletions.
2 changes: 1 addition & 1 deletion honeycomb-kernels/src/grisubal/grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
///
/// Cells `(X, Y)` take value in range `(0, 0)` to `(N, M)`,
/// from left to right (X), from bottom to top (Y).
// `Clone` and `Copy` allow cell identifiers to be stored by value in intermediate tuples
// and copied out of borrowed data (e.g. when destructuring `&(.., c1)` patterns).
#[derive(PartialEq, Clone, Copy)]
pub struct GridCellId(pub usize, pub usize);

impl GridCellId {
Expand Down
158 changes: 93 additions & 65 deletions honeycomb-kernels/src/grisubal/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,42 +193,68 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
[cx, cy]: [T; 2],
origin: Vertex2<T>,
) -> (Segments, Vec<(DartIdentifier, T)>) {
let mut intersection_metadata = Vec::new();
let mut new_segments = HashMap::with_capacity(geometry.poi.len() * 2); // that *2 has no basis
geometry.segments.iter().for_each(|&(v1_id, v2_id)| {
// fetch vertices of the segment
let Vertex2(ox, oy) = origin;
let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]);
// compute their position in the grid
// we assume that the origin of the grid is at (0., 0.)
let (c1, c2) = (
GridCellId(
((v1.x() - ox) / cx).floor().to_usize().unwrap(),
((v1.y() - oy) / cy).floor().to_usize().unwrap(),
),
GridCellId(
((v2.x() - ox) / cx).floor().to_usize().unwrap(),
((v2.y() - oy) / cy).floor().to_usize().unwrap(),
),
);
let tmp: Vec<_> = geometry
.segments
.iter()
.map(|&(v1_id, v2_id)| {
// fetch vertices of the segment
let Vertex2(ox, oy) = origin;
let (v1, v2) = (&geometry.vertices[v1_id], &geometry.vertices[v2_id]);
// compute their position in the grid
// we assume that the origin of the grid is at (0., 0.)
let (c1, c2) = (
GridCellId(
((v1.x() - ox) / cx).floor().to_usize().unwrap(),
((v1.y() - oy) / cy).floor().to_usize().unwrap(),
),
GridCellId(
((v2.x() - ox) / cx).floor().to_usize().unwrap(),
((v2.y() - oy) / cy).floor().to_usize().unwrap(),
),
);
(
GridCellId::man_dist(&c1, &c2),
GridCellId::diff(&c1, &c2),
v1,
v2,
v1_id,
v2_id,
c1,
)
})
.collect();
// total number of intersections
let n_intersec: usize = tmp.iter().map(|(dist, _, _, _, _, _, _)| dist).sum();
// we use an exclusive prefix sum to compute, for each segment, the offset of its first slot from the
// start of the intersection vector; the first offset is 0, and each value excludes the current distance
let prefix_sum = tmp
.iter()
.map(|(dist, _, _, _, _, _, _)| dist)
.scan(0, |state, &dist| {
*state += dist;
Some(*state - dist) // we want an offset, not the actual sum
});
// preallocate the intersection vector
let mut intersection_metadata = vec![(NULL_DART_ID, T::nan()); n_intersec];

let new_segments: Segments = tmp.iter().zip(prefix_sum).flat_map(|(&(dist, diff, v1, v2, v1_id, v2_id, c1), start)| {
let transform = Box::new(|seg: &[GeometryVertex]| {
assert_eq!(seg.len(), 2);
(seg[0].clone(), seg[1].clone())
});
// check neighbor status
match GridCellId::man_dist(&c1, &c2) {
match dist {
// trivial case:
// v1 & v2 belong to the same cell
0 => {
new_segments.insert(
make_geometry_vertex!(geometry, v1_id),
make_geometry_vertex!(geometry, v2_id),
);
vec![(make_geometry_vertex!(geometry, v1_id), make_geometry_vertex!(geometry, v2_id))]
}
// ok case:
// v1 & v2 belong to neighboring cells
1 => {
// fetch base dart of the cell of v1
#[allow(clippy::cast_possible_truncation)]
let d_base = (1 + 4 * c1.0 + nx * 4 * c1.1) as DartIdentifier;
// which edge of the cell are we intersecting?
let diff = GridCellId::diff(&c1, &c2);
// which dart does this correspond to?
#[rustfmt::skip]
let dart_id = match diff {
Expand All @@ -253,27 +279,20 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
_ => unreachable!(),
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));

new_segments.insert(
make_geometry_vertex!(geometry, v1_id),
GeometryVertex::Intersec(id),
);
new_segments.insert(
GeometryVertex::Intersec(id),
make_geometry_vertex!(geometry, v2_id),
);
let id = start;
intersection_metadata[id] = (dart_id, t);

vec![
(make_geometry_vertex!(geometry, v1_id), GeometryVertex::Intersec(id)),
(GeometryVertex::Intersec(id), make_geometry_vertex!(geometry, v2_id)),
]
}
// highly annoying case:
// v1 & v2 do not belong to neighboring cells
_ => {
// because we're using straight segments (not curves), the manhattan distance gives us
// the number of cells we're going through to reach v2 from v1
let diff = GridCellId::diff(&c1, &c2);
// pure vertical / horizontal traversals are treated separately because it ensures we're not trying
// to compute intersections of parallel segments (which results at best in a division by 0)
let i_ids = start..start+dist;
match diff {
(i, 0) => {
// we can solve the intersection equation
Expand All @@ -284,7 +303,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
// i > 0: i_base..i_base + i
// or
// i < 0: i_base + 1 + i..i_base + 1
(min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).map(|x| {
(min(i_base, i_base + 1 + i)..max(i_base + i, i_base + 1)).zip(i_ids).map(|(x, id)| {
// cell base dart
let d_base =
(1 + 4 * x + (nx * 4 * c1.1) as isize) as DartIdentifier;
Expand All @@ -304,24 +323,27 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
left_intersec!(v1, v2, v_dart, cy)
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));
intersection_metadata[id] = (dart_id, t);

GeometryVertex::Intersec(id)
});

// because of how the range is written, we need to reverse the iterator in one case
// to keep intersections ordered from v1 to v2 (i.e. ensure the segments we build are correct)
let mut vs: VecDeque<GeometryVertex> = if i > 0 {
tmp.collect()
} else {
tmp.rev().collect()
};

// complete the vertex list
vs.push_front(make_geometry_vertex!(geometry, v1_id));
vs.push_back(make_geometry_vertex!(geometry, v2_id));
vs.make_contiguous().windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.make_contiguous()
.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
(0, j) => {
// we can solve the intersection equation
Expand All @@ -332,7 +354,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
// j > 0: j_base..j_base + j
// or
// j < 0: j_base + 1 + j..j_base + 1
(min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).map(|y| {
(min(j_base, j_base + 1 + j)..max(j_base + j, j_base + 1)).zip(i_ids).map(|(y, id)| {
// cell base dart
let d_base = (1 + 4 * c1.0 + nx * 4 * y as usize) as DartIdentifier;
// intersected dart
Expand All @@ -347,26 +369,27 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
down_intersec!(v1, v2, v_dart, cx)
};

// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((dart_id, t));
intersection_metadata[id] = (dart_id, t);

GeometryVertex::Intersec(id)
});

// because of how the range is written, we need to reverse the iterator in one case
// to keep intersections ordered from v1 to v2 (i.e. ensure the segments we build are correct)
let mut vs: VecDeque<GeometryVertex> = if j > 0 {
tmp.collect()
} else {
tmp.rev().collect()
};

// complete the vertex list
vs.push_front(make_geometry_vertex!(geometry, v1_id));
vs.push_back(make_geometry_vertex!(geometry, v2_id));
// insert new segments
vs.make_contiguous().windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.make_contiguous()
.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
(i, j) => {
// in order to process this, we'll consider a "sub-grid" & use the direction of the segment to
Expand Down Expand Up @@ -454,6 +477,7 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(
None
})
.collect();

// sort intersections from v1 to v2
intersec_data.retain(|(s, _, _)| (T::zero() <= *s) && (*s <= T::one()));
// panic unreachable because of the retain above; there's no s s.t. s == NaN
Expand All @@ -462,31 +486,34 @@ pub(super) fn generate_intersection_data<T: CoordsFloat>(

// collect geometry vertices
let mut vs = vec![make_geometry_vertex!(geometry, v1_id)];
vs.extend(intersec_data.iter_mut().map(|(_, t, dart_id)| {
vs.extend(intersec_data.iter_mut().zip(i_ids).map(|((_, t, dart_id), id)| {
if t.is_zero() {
// we assume that the segment fully goes through the corner and does not land exactly
// on it, this allows us to compute directly the dart from which the next segment
// should start: the one incident to the vertex in the opposite quadrant

// in that case, the preallocated intersection metadata slot will stay as (0, NaN);
// this is ok, we can simply ignore the entry when processing the data later

let dart_in = *dart_id;
GeometryVertex::IntersecCorner(dart_in)
} else {
// FIXME: these two lines should be atomic
let id = intersection_metadata.len();
intersection_metadata.push((*dart_id, *t));
intersection_metadata[id] = (*dart_id, *t);

GeometryVertex::Intersec(id)
}
}));

vs.push(make_geometry_vertex!(geometry, v2_id));
// insert segments
vs.windows(2).for_each(|seg| {
new_segments.insert(seg[0].clone(), seg[1].clone());
});

vs.windows(2)
.map(transform)
.collect::<Vec<_>>()
}
}
}
};
});
}
}).collect();
(new_segments, intersection_metadata)
}

Expand All @@ -499,6 +526,7 @@ pub(super) fn group_intersections_per_edge<T: CoordsFloat>(
HashMap::new();
intersection_metadata
.into_iter()
.filter(|(_, t)| !t.is_nan())
.enumerate()
.for_each(|(idx, (dart_id, mut t))| {
// classify intersections per edge_id & adjust t if needed
Expand Down
13 changes: 11 additions & 2 deletions honeycomb-kernels/src/grisubal/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ fn regular_intersections() {
generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default());

assert_eq!(intersection_metadata.len(), 4);
// FIXME: INDEX ACCESSES WON'T WORK IN PARALLEL
assert_eq!(intersection_metadata[0], (2, 0.5));
assert_eq!(intersection_metadata[1], (7, 0.5));
assert_eq!(intersection_metadata[2], (16, 0.5));
Expand Down Expand Up @@ -261,6 +260,8 @@ fn regular_intersections() {

#[test]
fn corner_intersection() {
use num_traits::Float;

let mut cmap = CMapBuilder::from(
GridDescriptor::default()
.len_per_cell([1.0; 3])
Expand All @@ -280,7 +281,15 @@ fn corner_intersection() {
let (segments, intersection_metadata) =
generate_intersection_data(&cmap, &geometry, [2, 2], [1.0, 1.0], Vertex2::default());

assert_eq!(intersection_metadata.len(), 2);
// because we intersect a corner, some entries were preallocated but not needed.
// entries were initialized with (0, NaN), so they're easy to filter
assert_eq!(
intersection_metadata
.iter()
.filter(|(_, t)| !t.is_nan())
.count(),
2
);
assert_eq!(intersection_metadata[0], (2, 0.5));
assert_eq!(intersection_metadata[1], (7, 0.5));

Expand Down

0 comments on commit b3e6c20

Please sign in to comment.