Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix async return handling #92

Open
wants to merge 3 commits into
base: pdg
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 93 additions & 41 deletions crates/flowistry/src/pdg/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ use flowistry_pdg::{CallString, GlobalLocation, RichLocation};
use itertools::Itertools;
use log::{debug, trace};
use petgraph::graph::DiGraph;
use rustc_abi::FieldIdx;
use rustc_abi::{FieldIdx, VariantIdx};
use rustc_borrowck::consumers::{
places_conflict, BodyWithBorrowckFacts, PlaceConflictBias,
};
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::{
mir::{
visit::Visitor, AggregateKind, BasicBlock, Body, Location, Operand, Place, PlaceElem,
Rvalue, Statement, StatementKind, Terminator, TerminatorEdges, TerminatorKind,
RETURN_PLACE,
visit::Visitor, AggregateKind, BasicBlock, Body, HasLocalDecls, Location, Operand,
Place, PlaceElem, Rvalue, Statement, StatementKind, Terminator, TerminatorEdges,
TerminatorKind, RETURN_PLACE,
},
ty::{GenericArg, GenericArgsRef, List, ParamEnv, TyCtxt, TyKind},
};
Expand Down Expand Up @@ -201,6 +201,25 @@ struct CallingContext<'tcx> {
call_stack: Vec<DefId>,
}

/// Stores ids that are needed to construct projections around async functions.
struct AsyncInfo {
poll_ready_variant_idx: VariantIdx,
poll_ready_field_idx: FieldIdx,
}

impl AsyncInfo {
fn make(tcx: TyCtxt) -> Option<Rc<Self>> {
let lang_items = tcx.lang_items();
let poll_def = tcx.adt_def(lang_items.poll()?);
let ready_vid = lang_items.poll_ready_variant()?;
assert_eq!(poll_def.variant_with_id(ready_vid).fields.len(), 1);
Some(Rc::new(Self {
poll_ready_variant_idx: poll_def.variant_index_with_id(ready_vid),
poll_ready_field_idx: 0_u32.into(),
}))
}
}

pub struct GraphConstructor<'tcx> {
tcx: TyCtxt<'tcx>,
params: PdgParams<'tcx>,
Expand All @@ -212,6 +231,7 @@ pub struct GraphConstructor<'tcx> {
body_assignments: utils::BodyAssignments,
calling_context: Option<CallingContext<'tcx>>,
start_loc: FxHashSet<RichLocation>,
async_info: Rc<AsyncInfo>,
}

macro_rules! trylet {
Expand All @@ -226,11 +246,20 @@ macro_rules! trylet {
impl<'tcx> GraphConstructor<'tcx> {
/// Creates a [`GraphConstructor`] at the root of the PDG.
pub fn root(params: PdgParams<'tcx>) -> Self {
GraphConstructor::new(params, None)
let tcx = params.tcx;
GraphConstructor::new(
params,
None,
AsyncInfo::make(tcx).expect("async functions are not defined"),
)
}

/// Creates [`GraphConstructor`] for a function resolved as `fn_resolution` in a given `calling_context`.
fn new(params: PdgParams<'tcx>, calling_context: Option<CallingContext<'tcx>>) -> Self {
fn new(
params: PdgParams<'tcx>,
calling_context: Option<CallingContext<'tcx>>,
async_info: Rc<AsyncInfo>,
) -> Self {
let tcx = params.tcx;
let def_id = params.root.def_id().expect_local();
let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, def_id);
Expand Down Expand Up @@ -261,6 +290,7 @@ impl<'tcx> GraphConstructor<'tcx> {
def_id,
calling_context,
body_assignments,
async_info,
}
}

Expand Down Expand Up @@ -662,41 +692,57 @@ impl<'tcx> GraphConstructor<'tcx> {
let parent_body = &self.body;
let translate_to_parent = |child: Place<'tcx>| -> Option<Place<'tcx>> {
trace!(" Translating child place: {child:?}");
let (parent_place, child_projection) = if child.local == RETURN_PLACE {
(destination, &child.projection[..])
} else {
match call_kind {
// Map arguments to the argument array
CallKind::Direct => (
args[child.local.as_usize() - 1].place()?,
let (parent_place, child_projection) = match call_kind {
// Async return must be handled special, because it gets wrapped in `Poll::Ready`
CallKind::AsyncPoll if child.local == RETURN_PLACE => {
let async_info = self.async_info.as_ref();
let in_poll = destination.project_deeper(
&[PlaceElem::Downcast(None, async_info.poll_ready_variant_idx)],
tcx,
);
let field_idx = async_info.poll_ready_field_idx;
let child_inner_return_type = in_poll
.ty(parent_body.local_decls(), tcx)
.field_ty(tcx, field_idx);
(
in_poll.project_deeper(
&[PlaceElem::Field(field_idx, child_inner_return_type)],
tcx,
),
&child.projection[..],
),
// Map arguments to projections of the future, the poll's first argument
CallKind::AsyncPoll => {
if child.local.as_usize() == 1 {
let PlaceElem::Field(idx, _) = child.projection[0] else {
panic!("Unexpected non-projection of async context")
};
(args[idx.as_usize()].place()?, &child.projection[1 ..])
} else {
return None;
}
)
}
_ if child.local == RETURN_PLACE => (destination, &child.projection[..]),
// Map arguments to the argument array
CallKind::Direct => (
args[child.local.as_usize() - 1].place()?,
&child.projection[..],
),
// Map arguments to projections of the future, the poll's first argument
CallKind::AsyncPoll => {
if child.local.as_usize() == 1 {
let PlaceElem::Field(idx, _) = child.projection[0] else {
panic!("Unexpected non-projection of async context")
};
(args[idx.as_usize()].place()?, &child.projection[1 ..])
} else {
return None;
}
// Map closure captures to the first argument.
// Map formal parameters to the second argument.
CallKind::Indirect => {
if child.local.as_usize() == 1 {
(args[0].place()?, &child.projection[..])
} else {
let tuple_arg = args[1].place()?;
let _projection = child.projection.to_vec();
let field = FieldIdx::from_usize(child.local.as_usize() - 2);
let field_ty = tuple_arg.ty(parent_body.as_ref(), tcx).field_ty(tcx, field);
(
tuple_arg.project_deeper(&[PlaceElem::Field(field, field_ty)], tcx),
&child.projection[..],
)
}
}
// Map closure captures to the first argument.
// Map formal parameters to the second argument.
CallKind::Indirect => {
if child.local.as_usize() == 1 {
(args[0].place()?, &child.projection[..])
} else {
let tuple_arg = args[1].place()?;
let _projection = child.projection.to_vec();
let field = FieldIdx::from_usize(child.local.as_usize() - 2);
let field_ty = tuple_arg.ty(parent_body.as_ref(), tcx).field_ty(tcx, field);
(
tuple_arg.project_deeper(&[PlaceElem::Field(field, field_ty)], tcx),
&child.projection[..],
)
}
}
};
Expand Down Expand Up @@ -730,7 +776,8 @@ impl<'tcx> GraphConstructor<'tcx> {
param_env,
call_stack,
};
let child_constructor = GraphConstructor::new(params, Some(calling_context));
let child_constructor =
GraphConstructor::new(params, Some(calling_context), self.async_info.clone());

if let Some(callback) = &self.params.call_change_callback {
let info = CallInfo {
Expand Down Expand Up @@ -941,7 +988,12 @@ impl<'tcx> GraphConstructor<'tcx> {
call_string,
call_stack,
};
return GraphConstructor::new(params, Some(calling_context)).construct_partial();
return GraphConstructor::new(
params,
Some(calling_context),
self.async_info.clone(),
)
.construct_partial();
}

let mut analysis = DfAnalysis(self)
Expand Down
Loading