Skip to content

Commit

Permalink
Automatically generate MainThreadMarker argument in methods
Browse files Browse the repository at this point in the history
  • Loading branch information
madsmtm committed Sep 1, 2023
1 parent 33c3199 commit 3e6d97f
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 32 deletions.
79 changes: 76 additions & 3 deletions crates/header-translator/src/cache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::mem;

use crate::config::Config;
Expand All @@ -7,16 +7,38 @@ use crate::id::ItemIdentifier;
use crate::method::Method;
use crate::output::Output;
use crate::stmt::Stmt;
use crate::Mutability;

/// A helper struct for doing global analysis on the output.
#[derive(Debug, PartialEq, Clone)]
pub struct Cache<'a> {
config: &'a Config,
mainthreadonly_classes: BTreeSet<ItemIdentifier>,
}

impl<'a> Cache<'a> {
pub fn new(_output: &Output, config: &'a Config) -> Self {
Self { config }
pub fn new(output: &Output, config: &'a Config) -> Self {
let mut mainthreadonly_classes = BTreeSet::new();

for library in output.libraries.values() {
for file in library.files.values() {
for stmt in file.stmts.iter() {
if let Stmt::ClassDecl {
id,
mutability: Mutability::MainThreadOnly,
..
} = stmt
{
mainthreadonly_classes.insert(id.clone());
}
}
}
}

Self {
config,
mainthreadonly_classes,
}
}

pub fn update(&self, output: &mut Output) {
Expand Down Expand Up @@ -68,6 +90,57 @@ impl<'a> Cache<'a> {
}
}

// Add `mainthreadonly` to relevant methods
for stmt in file.stmts.iter_mut() {
match stmt {
Stmt::Methods {
cls: id, methods, ..
}
| Stmt::ProtocolDecl { id, methods, .. } => {
for method in methods.iter_mut() {
let mut result_type_contains_mainthreadonly: bool = false;
method.result_type.visit_required_types(&mut |id| {
if self.mainthreadonly_classes.contains(id) {
result_type_contains_mainthreadonly = true;
}
});

match (method.is_class, self.mainthreadonly_classes.contains(id)) {
// MainThreadOnly class with static method; assume the method is main thread only
(true, true) => {
result_type_contains_mainthreadonly = true;
}
// Class with non-static method; do the normal check
(true, false) => {}
// MainThreadOnly class with static method; skip the check
(false, true) => {
continue;
}
// Class with non-static method; do the normal check
(false, false) => {}
}

if !result_type_contains_mainthreadonly {
continue;
}

let mut any_argument_contains_mainthreadonly: bool = false;
for (_, argument) in method.arguments.iter() {
argument.visit_toplevel_types(&mut |id| {
if self.mainthreadonly_classes.contains(id) {
any_argument_contains_mainthreadonly = true;
}
});
}
if !any_argument_contains_mainthreadonly {
method.mainthreadonly = true;
}
}
}
_ => {}
}
}

// Fix up a few typedef + enum declarations
let mut iter = mem::take(&mut file.stmts).into_iter().peekable();
while let Some(stmt) = iter.next() {
Expand Down
15 changes: 12 additions & 3 deletions crates/header-translator/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ pub struct Method {
pub is_class: bool,
is_optional_protocol: bool,
memory_management: MemoryManagement,
arguments: Vec<(String, Ty)>,
pub(crate) arguments: Vec<(String, Ty)>,
pub result_type: Ty,
safe: bool,
mutating: bool,
is_protocol: bool,
// Thread-safe, even on main-thread only (@MainActor/@UIActor) classes
non_isolated: bool,
mainthreadonly: bool,
pub(crate) mainthreadonly: bool,
}

impl Method {
Expand Down Expand Up @@ -636,6 +636,11 @@ impl fmt::Display for Method {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let _span = debug_span!("method", self.fn_name).entered();

// TODO: Use this somehow?
// if self.non_isolated {
// writeln!(f, "// non_isolated")?;
// }

//
// Attributes
//
Expand Down Expand Up @@ -689,7 +694,11 @@ impl fmt::Display for Method {
// Arguments
for (param, arg_ty) in &self.arguments {
let param = handle_reserved(&crate::to_snake_case(param));
write!(f, "{param}: {arg_ty},")?;
write!(f, "{param}: {arg_ty}, ")?;
}
// FIXME: Skipping main thread only on protocols for now
if self.mainthreadonly && !self.is_protocol {
write!(f, "mtm: MainThreadMarker")?;
}
write!(f, ")")?;

Expand Down
33 changes: 33 additions & 0 deletions crates/header-translator/src/rust_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ impl IdType {
}
}
}

fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
if let Some(id) = self._id() {
f(id);
}
}
}

impl fmt::Display for IdType {
Expand Down Expand Up @@ -1036,6 +1042,25 @@ impl Inner {
_ => {}
}
}

pub fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
match self {
Self::Id { ty, .. } => {
ty.visit_toplevel_types(f);
}
Self::Pointer {
// Only visit non-null types
nullability: Nullability::NonNull,
is_const: _,
pointee,
} => {
pointee.visit_toplevel_types(f);
}
// TODO
Self::TypeDef { id } => f(id),
_ => {}
}
}
}

/// This is sound to output in (almost, c_void is not a valid return type) any
Expand Down Expand Up @@ -1492,6 +1517,14 @@ impl Ty {

self.ty.visit_required_types(f);
}

pub fn visit_toplevel_types(&self, f: &mut impl FnMut(&ItemIdentifier)) {
if let TyKind::MethodReturn { with_error: true } = &self.kind {
f(&ItemIdentifier::nserror());
}

self.ty.visit_toplevel_types(f);
}
}

impl Ty {
Expand Down
26 changes: 12 additions & 14 deletions crates/icrate/examples/browser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use icrate::{
NSWindowStyleMaskResizable, NSWindowStyleMaskTitled,
},
Foundation::{
ns_string, NSNotification, NSObject, NSObjectProtocol, NSPoint, NSRect, NSSize,
NSURLRequest, NSURL,
ns_string, MainThreadMarker, NSNotification, NSObject, NSObjectProtocol, NSPoint, NSRect,
NSSize, NSURLRequest, NSURL,
},
WebKit::{WKNavigation, WKNavigationDelegate, WKWebView},
};
Expand Down Expand Up @@ -68,9 +68,9 @@ declare_class!(
#[method(applicationDidFinishLaunching:)]
#[allow(non_snake_case)]
unsafe fn applicationDidFinishLaunching(&self, _notification: &NSNotification) {
let mtm = MainThreadMarker::from(self);
// create the app window
let window = {
let this = NSWindow::alloc();
let content_rect = NSRect::new(NSPoint::new(0., 0.), NSSize::new(1024., 768.));
let style = NSWindowStyleMaskClosable
| NSWindowStyleMaskResizable
Expand All @@ -79,7 +79,7 @@ declare_class!(
let flag = false;
unsafe {
NSWindow::initWithContentRect_styleMask_backing_defer(
this,
mtm.alloc(),
content_rect,
style,
backing_store_type,
Expand All @@ -90,16 +90,14 @@ declare_class!(

// create the web view
let web_view = {
let this = WKWebView::alloc();
let frame_rect = NSRect::ZERO;
unsafe { WKWebView::initWithFrame(this, frame_rect) }
unsafe { WKWebView::initWithFrame(mtm.alloc(), frame_rect) }
};

// create the nav bar view
let nav_bar = {
let frame_rect = NSRect::ZERO;
let this = NSStackView::alloc();
let this = unsafe { NSStackView::initWithFrame(this, frame_rect) };
let this = unsafe { NSStackView::initWithFrame(mtm.alloc(), frame_rect) };
unsafe {
this.setOrientation(NSUserInterfaceLayoutOrientationHorizontal);
this.setAlignment(NSLayoutAttributeHeight);
Expand All @@ -112,8 +110,7 @@ declare_class!(
// create the nav buttons view
let nav_buttons = {
let frame_rect = NSRect::ZERO;
let this = NSStackView::alloc();
let this = unsafe { NSStackView::initWithFrame(this, frame_rect) };
let this = unsafe { NSStackView::initWithFrame(mtm.alloc(), frame_rect) };
unsafe {
this.setOrientation(NSUserInterfaceLayoutOrientationHorizontal);
this.setAlignment(NSLayoutAttributeHeight);
Expand All @@ -130,7 +127,7 @@ declare_class!(
let target = Some::<&AnyObject>(&web_view);
let action = Some(sel!(goBack));
let this =
unsafe { NSButton::buttonWithTitle_target_action(title, target, action) };
unsafe { NSButton::buttonWithTitle_target_action(title, target, action, mtm) };
unsafe { this.setBezelStyle(NSBezelStyleShadowlessSquare) };
this
};
Expand All @@ -142,7 +139,7 @@ declare_class!(
let target = Some::<&AnyObject>(&web_view);
let action = Some(sel!(goForward));
let this =
unsafe { NSButton::buttonWithTitle_target_action(title, target, action) };
unsafe { NSButton::buttonWithTitle_target_action(title, target, action, mtm) };
unsafe { this.setBezelStyle(NSBezelStyleShadowlessSquare) };
this
};
Expand Down Expand Up @@ -217,7 +214,7 @@ declare_class!(
menu_app_item.setSubmenu(Some(&menu_app_menu));
menu.addItem(&menu_app_item);

let app = NSApplication::sharedApplication();
let app = NSApplication::sharedApplication(mtm);
app.setMainMenu(Some(&menu));
}

Expand Down Expand Up @@ -294,7 +291,8 @@ impl Delegate {
unsafe impl NSObjectProtocol for Delegate {}

fn main() {
let app = unsafe { NSApplication::sharedApplication() };
let mtm = MainThreadMarker::new().unwrap();
let app = unsafe { NSApplication::sharedApplication(mtm) };
unsafe { app.setActivationPolicy(NSApplicationActivationPolicyRegular) };

// initialize the delegate
Expand Down
12 changes: 7 additions & 5 deletions crates/icrate/examples/delegate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ptr::NonNull;

use icrate::AppKit::{NSApplication, NSApplicationActivationPolicyRegular, NSApplicationDelegate};
use icrate::Foundation::{
ns_string, NSCopying, NSNotification, NSObject, NSObjectProtocol, NSString,
ns_string, MainThreadMarker, NSCopying, NSNotification, NSObject, NSObjectProtocol, NSString,
};
use objc2::declare::{Ivar, IvarBool, IvarDrop, IvarEncode};
use objc2::rc::Id;
Expand Down Expand Up @@ -68,17 +68,19 @@ declare_class!(
unsafe impl NSObjectProtocol for AppDelegate {}

impl AppDelegate {
pub fn new(ivar: u8, another_ivar: bool) -> Id<Self> {
unsafe { msg_send_id![Self::alloc(), initWith: ivar, another: another_ivar] }
pub fn new(ivar: u8, another_ivar: bool, mtm: MainThreadMarker) -> Id<Self> {
unsafe { msg_send_id![mtm.alloc(), initWith: ivar, another: another_ivar] }
}
}

fn main() {
let app = unsafe { NSApplication::sharedApplication() };
let mtm: MainThreadMarker = MainThreadMarker::new().unwrap();

let app = unsafe { NSApplication::sharedApplication(mtm) };
unsafe { app.setActivationPolicy(NSApplicationActivationPolicyRegular) };

// initialize the delegate
let delegate = AppDelegate::new(42, true);
let delegate = AppDelegate::new(42, true, mtm);

println!("{delegate:?}");

Expand Down
6 changes: 4 additions & 2 deletions crates/icrate/examples/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use icrate::{
NSWindowStyleMaskTitled,
},
Foundation::{
ns_string, NSDate, NSNotification, NSObject, NSObjectProtocol, NSPoint, NSRect, NSSize,
ns_string, MainThreadMarker, NSDate, NSNotification, NSObject, NSObjectProtocol, NSPoint,
NSRect, NSSize,
},
Metal::{
MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLCreateSystemDefaultDevice,
Expand Down Expand Up @@ -358,8 +359,9 @@ impl Delegate {
}

fn main() {
let mtm = MainThreadMarker::new().unwrap();
// configure the app
let app = unsafe { NSApplication::sharedApplication() };
let app = unsafe { NSApplication::sharedApplication(mtm) };
unsafe { app.setActivationPolicy(NSApplicationActivationPolicyRegular) };

// initialize the delegate
Expand Down
2 changes: 1 addition & 1 deletion crates/icrate/src/generated
5 changes: 3 additions & 2 deletions crates/objc2/src/macros/__method_msg_send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,13 @@ macro_rules! __method_msg_send_id {

()
()
()
// Possible to hit via. the MainThreadMarker branch
($($already_parsed_retain_semantics:ident)?)
) => {
$crate::__msg_send_id_helper! {
@(send_message_id)
@($receiver)
@($($retain_semantics)?)
@($($retain_semantics)? $($already_parsed_retain_semantics)?)
@($sel)
@()
}
Expand Down
4 changes: 2 additions & 2 deletions crates/objc2/tests/macros_mainthreadmarker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct MainThreadMarker(bool);
extern_methods!(
unsafe impl Cls {
#[method_id(new)]
fn new() -> Id<Self>;
fn new(mtm: MainThreadMarker) -> Id<Self>;

#[method(myMethod:)]
fn method(mtm: MainThreadMarker, arg: i32, mtm2: MainThreadMarker) -> i32;
Expand All @@ -63,8 +63,8 @@ extern_methods!(

#[test]
fn call() {
let obj1 = Cls::new();
let mtm = MainThreadMarker(true);
let obj1 = Cls::new(mtm);

let res = Cls::method(mtm, 2, mtm);
assert_eq!(res, 3);
Expand Down

0 comments on commit 3e6d97f

Please sign in to comment.