Skip to content

Commit

Permalink
Automatically generate MainThreadMarker argument in methods
Browse files Browse the repository at this point in the history
  • Loading branch information
madsmtm committed Sep 3, 2023
1 parent ee5cd91 commit 69cce30
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 52 deletions.
89 changes: 86 additions & 3 deletions crates/header-translator/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::mem;

use crate::config::Config;
Expand All @@ -7,16 +7,38 @@ use crate::id::ItemIdentifier;
use crate::method::Method;
use crate::output::Output;
use crate::stmt::Stmt;
use crate::Mutability;

/// A helper struct for doing global analysis on the output.
#[derive(Debug, PartialEq, Clone)]
pub struct Cache<'a> {
config: &'a Config,
mainthreadonly_classes: BTreeSet<ItemIdentifier>,
}

impl<'a> Cache<'a> {
pub fn new(_output: &Output, config: &'a Config) -> Self {
Self { config }
pub fn new(output: &Output, config: &'a Config) -> Self {
let mut mainthreadonly_classes = BTreeSet::new();

for library in output.libraries.values() {
for file in library.files.values() {
for stmt in file.stmts.iter() {
if let Stmt::ClassDecl {
id,
mutability: Mutability::MainThreadOnly,
..
} = stmt
{
mainthreadonly_classes.insert(id.clone());
}
}
}
}

Self {
config,
mainthreadonly_classes,
}
}

pub fn update(&self, output: &mut Output) {
Expand Down Expand Up @@ -68,6 +90,67 @@ impl<'a> Cache<'a> {
}
}

// Add `mainthreadonly` to relevant methods
for stmt in file.stmts.iter_mut() {
match stmt {
Stmt::Methods {
cls: id, methods, ..
}
| Stmt::ProtocolDecl { id, methods, .. } => {
for method in methods.iter_mut() {
let mut result_type_contains_mainthreadonly: bool = false;
method.result_type.visit_required_types(&mut |id| {
if self.mainthreadonly_classes.contains(id) {
result_type_contains_mainthreadonly = true;
}
});

match (method.is_class, self.mainthreadonly_classes.contains(id)) {
// MainThreadOnly class with static method
(true, true) => {
// Assume the method needs main thread
result_type_contains_mainthreadonly = true;
}
// Class with static method
(true, false) => {
// Continue with the normal check
}
// MainThreadOnly class with non-static method
(false, true) => {
// Method is already required to run on main
// thread, so no need to add MainThreadMarker
continue;
}
// Class with non-static method
(false, false) => {
// Continue with the normal check
}
}

if result_type_contains_mainthreadonly {
let mut any_argument_contains_mainthreadonly: bool = false;
for (_, argument) in method.arguments.iter() {
// Important: We only visit the top-level types, to not
// include e.g. `Option<&NSView>` or `&NSArray<NSView>`.
argument.visit_toplevel_types(&mut |id| {
if self.mainthreadonly_classes.contains(id) {
any_argument_contains_mainthreadonly = true;
}
});
}

// Apply main thread only, unless a (required)
// argument was main thread only.
if !any_argument_contains_mainthreadonly {
method.mainthreadonly = true;
}
}
}
}
_ => {}
}
}

// Fix up a few typedef + enum declarations
let mut iter = mem::take(&mut file.stmts).into_iter().peekable();
while let Some(stmt) = iter.next() {
Expand Down
8 changes: 8 additions & 0 deletions crates/header-translator/src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ impl ItemIdentifier {
self.library == "Foundation" && self.name == "NSComparator"
}

pub fn main_thread_marker() -> Self {
Self {
name: "NSThread".to_string(),
library: "Foundation".to_string(),
file_name: Some("NSThread".to_string()),
}
}

pub fn feature(&self) -> Option<impl fmt::Display + '_> {
struct ItemIdentifierFeature<'a>(&'a ItemIdentifier);

Expand Down
19 changes: 16 additions & 3 deletions crates/header-translator/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ pub struct Method {
pub is_class: bool,
is_optional_protocol: bool,
memory_management: MemoryManagement,
arguments: Vec<(String, Ty)>,
pub(crate) arguments: Vec<(String, Ty)>,
pub result_type: Ty,
safe: bool,
mutating: bool,
is_protocol: bool,
// Thread-safe, even on main-thread only (@MainActor/@UIActor) classes
non_isolated: bool,
mainthreadonly: bool,
pub(crate) mainthreadonly: bool,
}

impl Method {
Expand Down Expand Up @@ -349,6 +349,10 @@ impl Method {
}

self.result_type.visit_required_types(&mut f);

if self.mainthreadonly {
f(&ItemIdentifier::main_thread_marker())
}
}
}

Expand Down Expand Up @@ -636,6 +640,11 @@ impl fmt::Display for Method {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let _span = debug_span!("method", self.fn_name).entered();

// TODO: Use this somehow?
// if self.non_isolated {
// writeln!(f, "// non_isolated")?;
// }

//
// Attributes
//
Expand Down Expand Up @@ -689,7 +698,11 @@ impl fmt::Display for Method {
// Arguments
for (param, arg_ty) in &self.arguments {
let param = handle_reserved(&crate::to_snake_case(param));
write!(f, "{param}: {arg_ty},")?;
write!(f, "{param}: {arg_ty}, ")?;
}
// FIXME: Skipping main thread only on protocols for now
if self.mainthreadonly && !self.is_protocol {
write!(f, "mtm: MainThreadMarker")?;
}
write!(f, ")")?;

Expand Down
33 changes: 33 additions & 0 deletions crates/header-translator/src/rust_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ impl IdType {
}
}
}

fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
if let Some(id) = self._id() {
f(id);
}
}
}

impl fmt::Display for IdType {
Expand Down Expand Up @@ -1036,6 +1042,25 @@ impl Inner {
_ => {}
}
}

pub fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
match self {
Self::Id { ty, .. } => {
ty.visit_toplevel_types(f);
}
Self::Pointer {
// Only visit non-null types
nullability: Nullability::NonNull,
is_const: _,
pointee,
} => {
pointee.visit_toplevel_types(f);
}
// TODO
Self::TypeDef { id } => f(id),
_ => {}
}
}
}

/// This is sound to output in (almost, c_void is not a valid return type) any
Expand Down Expand Up @@ -1492,6 +1517,14 @@ impl Ty {

self.ty.visit_required_types(f);
}

pub fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
if let TyKind::MethodReturn { with_error: true } = &self.kind {
f(&ItemIdentifier::nserror());
}

self.ty.visit_toplevel_types(f);
}
}

impl Ty {
Expand Down
7 changes: 5 additions & 2 deletions crates/icrate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ unstable-example-basic_usage = [
unstable-example-delegate = [
"apple",
"Foundation",
"Foundation_NSString",
"Foundation_NSNotification",
"Foundation_NSString",
"Foundation_NSThread",
"AppKit",
"AppKit_NSApplication",
]
Expand Down Expand Up @@ -165,6 +166,7 @@ unstable-example-browser = [
"Foundation",
"Foundation_NSNotification",
"Foundation_NSString",
"Foundation_NSThread",
"Foundation_NSURL",
"Foundation_NSURLRequest",
"WebKit",
Expand All @@ -177,10 +179,11 @@ unstable-example-metal = [
"AppKit_NSWindow",
"Foundation",
"Foundation_NSCoder",
"Foundation_NSDate",
"Foundation_NSError",
"Foundation_NSNotification",
"Foundation_NSString",
"Foundation_NSDate",
"Foundation_NSThread",
"Metal",
"Metal_MTLCompileOptions",
"Metal_MTLRenderPassDescriptor",
Expand Down
Loading

0 comments on commit 69cce30

Please sign in to comment.