From 0f7466c34e774e547d21c579b58b60168c4ee6bc Mon Sep 17 00:00:00 2001 From: Kenny Kerr Date: Mon, 23 Sep 2024 18:01:22 -0500 Subject: [PATCH] Add `Deref` implementation for `HSTRING` (#3291) --- crates/libs/result/src/bstr.rs | 32 ++++----- crates/libs/result/src/error.rs | 2 +- crates/libs/strings/src/bstr.rs | 64 ++++++++---------- crates/libs/strings/src/hstring.rs | 65 +++++++------------ crates/tests/misc/literals/tests/win.rs | 4 +- crates/tests/misc/string_param/tests/pwstr.rs | 18 +---- crates/tests/misc/strings/tests/bstr.rs | 35 +++++++++- crates/tests/misc/strings/tests/hstring.rs | 31 ++++++++- .../misc/strings/tests/hstring_builder.rs | 5 +- .../tests/misc/win32_arrays/tests/xmllite.rs | 11 ++-- 10 files changed, 134 insertions(+), 133 deletions(-) diff --git a/crates/libs/result/src/bstr.rs b/crates/libs/result/src/bstr.rs index 700e4d7cb8..86b8c2caaf 100644 --- a/crates/libs/result/src/bstr.rs +++ b/crates/libs/result/src/bstr.rs @@ -1,36 +1,26 @@ use super::*; +use core::ops::Deref; #[repr(transparent)] pub struct BasicString(*const u16); -impl BasicString { - pub fn is_empty(&self) -> bool { - self.len() == 0 - } +impl Deref for BasicString { + type Target = [u16]; - pub fn len(&self) -> usize { - if self.0.is_null() { + fn deref(&self) -> &[u16] { + let len = if self.0.is_null() { 0 } else { unsafe { SysStringLen(self.0) as usize } - } - } - - pub fn as_wide(&self) -> &[u16] { - let len = self.len(); - if len != 0 { - unsafe { core::slice::from_raw_parts(self.as_ptr(), len) } - } else { - &[] - } - } + }; - pub fn as_ptr(&self) -> *const u16 { - if !self.is_empty() { - self.0 + if len > 0 { + unsafe { core::slice::from_raw_parts(self.0, len) } } else { + // This ensures that if `as_ptr` is called on the slice that the resulting pointer + // will still refer to a null-terminated string. const EMPTY: [u16; 1] = [0]; - EMPTY.as_ptr() + &EMPTY[..0] } } } diff --git a/crates/libs/result/src/error.rs b/crates/libs/result/src/error.rs index 9c448bc6e8..768b3bd32d 100644 --- a/crates/libs/result/src/error.rs +++ b/crates/libs/result/src/error.rs @@ -343,7 +343,7 @@ mod error_info { } } - Some(String::from_utf16_lossy(wide_trim_end(message.as_wide()))) + Some(String::from_utf16_lossy(wide_trim_end(&message))) } pub(crate) fn as_ptr(&self) -> *mut core::ffi::c_void { diff --git a/crates/libs/strings/src/bstr.rs b/crates/libs/strings/src/bstr.rs index 79fd493b95..53cf5da87e 100644 --- a/crates/libs/strings/src/bstr.rs +++ b/crates/libs/strings/src/bstr.rs @@ -1,4 +1,5 @@ use super::*; +use core::ops::Deref; /// A BSTR string ([BSTR](https://learn.microsoft.com/en-us/previous-versions/windows/desktop/automat/string-manipulation-functions)) /// is a length-prefixed wide string. @@ -13,35 +14,6 @@ impl BSTR { Self(core::ptr::null_mut()) } - /// Returns `true` if the string is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns the length of the string. - pub fn len(&self) -> usize { - if self.0.is_null() { - 0 - } else { - unsafe { bindings::SysStringLen(self.0) as usize } - } - } - - /// Get the string as 16-bit wide characters (wchars). - pub fn as_wide(&self) -> &[u16] { - unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) } - } - - /// Returns a raw pointer to the `BSTR` buffer. - pub fn as_ptr(&self) -> *const u16 { - if !self.is_empty() { - self.0 - } else { - const EMPTY: [u16; 1] = [0]; - EMPTY.as_ptr() - } - } - /// Create a `BSTR` from a slice of 16 bit characters (wchars). pub fn from_wide(value: &[u16]) -> Self { if value.is_empty() { @@ -75,9 +47,30 @@ impl BSTR { } } +impl Deref for BSTR { + type Target = [u16]; + + fn deref(&self) -> &[u16] { + let len = if self.0.is_null() { + 0 + } else { + unsafe { bindings::SysStringLen(self.0) as usize } + }; + + if len > 0 { + unsafe { core::slice::from_raw_parts(self.0, len) } + } else { + // This ensures that if `as_ptr` is called on the slice that the resulting pointer + // will still refer to a null-terminated string. + const EMPTY: [u16; 1] = [0]; + &EMPTY[..0] + } + } +} + impl Clone for BSTR { fn clone(&self) -> Self { - Self::from_wide(self.as_wide()) + Self::from_wide(self) } } @@ -104,7 +97,7 @@ impl<'a> TryFrom<&'a BSTR> for String { type Error = alloc::string::FromUtf16Error; fn try_from(value: &BSTR) -> core::result::Result { - String::from_utf16(value.as_wide()) + String::from_utf16(value) } } @@ -127,7 +120,7 @@ impl core::fmt::Display for BSTR { core::write!( f, "{}", - Decode(|| core::char::decode_utf16(self.as_wide().iter().cloned())) + Decode(|| core::char::decode_utf16(self.iter().cloned())) ) } } @@ -140,7 +133,7 @@ impl core::fmt::Debug for BSTR { impl PartialEq for BSTR { fn eq(&self, other: &Self) -> bool { - self.as_wide() == other.as_wide() + self.deref() == other.deref() } } @@ -160,10 +153,7 @@ impl PartialEq for String { impl + ?Sized> PartialEq for BSTR { fn eq(&self, other: &T) -> bool { - self.as_wide() - .iter() - .copied() - .eq(other.as_ref().encode_utf16()) + self.iter().copied().eq(other.as_ref().encode_utf16()) } } diff --git a/crates/libs/strings/src/hstring.rs b/crates/libs/strings/src/hstring.rs index 11a3f5db49..a72a8dd11a 100644 --- a/crates/libs/strings/src/hstring.rs +++ b/crates/libs/strings/src/hstring.rs @@ -1,4 +1,5 @@ use super::*; +use core::ops::Deref; /// An ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring)) /// is a reference-counted and immutable UTF-16 string type. @@ -13,36 +14,6 @@ impl HSTRING { Self(core::ptr::null_mut()) } - /// Returns `true` if the string is empty. - pub fn is_empty(&self) -> bool { - // An empty HSTRING is represented by a null pointer. - self.0.is_null() - } - - /// Returns the length of the string. The length is measured in `u16`s (UTF-16 code units), not including the terminating null character. - pub fn len(&self) -> usize { - if let Some(header) = self.as_header() { - header.len as usize - } else { - 0 - } - } - - /// Get the string as 16-bit wide characters (wchars). - pub fn as_wide(&self) -> &[u16] { - unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) } - } - - /// Returns a raw pointer to the `HSTRING` buffer. - pub fn as_ptr(&self) -> *const u16 { - if let Some(header) = self.as_header() { - header.data - } else { - const EMPTY: [u16; 1] = [0]; - EMPTY.as_ptr() - } - } - /// Create a `HSTRING` from a slice of 16 bit characters (wchars). pub fn from_wide(value: &[u16]) -> Self { unsafe { Self::from_wide_iter(value.iter().copied(), value.len()) } @@ -50,13 +21,13 @@ impl HSTRING { /// Get the contents of this `HSTRING` as a String lossily. pub fn to_string_lossy(&self) -> String { - String::from_utf16_lossy(self.as_wide()) + String::from_utf16_lossy(self) } /// Get the contents of this `HSTRING` as a OsString. #[cfg(feature = "std")] pub fn to_os_string(&self) -> std::ffi::OsString { - std::os::windows::ffi::OsStringExt::from_wide(self.as_wide()) + std::os::windows::ffi::OsStringExt::from_wide(self) } /// # Safety @@ -87,6 +58,21 @@ impl HSTRING { } } +impl Deref for HSTRING { + type Target = [u16]; + + fn deref(&self) -> &[u16] { + if let Some(header) = self.as_header() { + unsafe { core::slice::from_raw_parts(header.data, header.len as usize) } + } else { + // This ensures that if `as_ptr` is called on the slice that the resulting pointer + // will still refer to a null-terminated string. + const EMPTY: [u16; 1] = [0]; + &EMPTY[..0] + } + } +} + impl Default for HSTRING { fn default() -> Self { Self::new() @@ -125,7 +111,7 @@ impl core::fmt::Display for HSTRING { write!( f, "{}", - Decode(|| core::char::decode_utf16(self.as_wide().iter().cloned())) + Decode(|| core::char::decode_utf16(self.iter().cloned())) ) } } @@ -191,13 +177,13 @@ impl Eq for HSTRING {} impl Ord for HSTRING { fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.as_wide().cmp(other.as_wide()) + self.deref().cmp(other) } } impl core::hash::Hash for HSTRING { fn hash(&self, hasher: &mut H) { - self.as_wide().hash(hasher) + self.deref().hash(hasher) } } @@ -209,7 +195,7 @@ impl PartialOrd for HSTRING { impl PartialEq for HSTRING { fn eq(&self, other: &Self) -> bool { - *self.as_wide() == *other.as_wide() + self.deref() == other.deref() } } @@ -233,7 +219,7 @@ impl PartialEq<&String> for HSTRING { impl PartialEq for HSTRING { fn eq(&self, other: &str) -> bool { - self.as_wide().iter().copied().eq(other.encode_utf16()) + self.iter().copied().eq(other.encode_utf16()) } } @@ -309,8 +295,7 @@ impl PartialEq<&std::ffi::OsString> for HSTRING { #[cfg(feature = "std")] impl PartialEq for HSTRING { fn eq(&self, other: &std::ffi::OsStr) -> bool { - self.as_wide() - .iter() + self.iter() .copied() .eq(std::os::windows::ffi::OsStrExt::encode_wide(other)) } @@ -376,7 +361,7 @@ impl<'a> TryFrom<&'a HSTRING> for String { type Error = alloc::string::FromUtf16Error; fn try_from(hstring: &HSTRING) -> core::result::Result { - String::from_utf16(hstring.as_wide()) + String::from_utf16(hstring) } } diff --git a/crates/tests/misc/literals/tests/win.rs b/crates/tests/misc/literals/tests/win.rs index 8cf9281bb5..50dc6ed382 100644 --- a/crates/tests/misc/literals/tests/win.rs +++ b/crates/tests/misc/literals/tests/win.rs @@ -47,8 +47,6 @@ fn test() { fn into() { let a = h!(""); assert!(a.is_empty()); - assert!(!a.as_ptr().is_null()); - assert!(a.as_wide().is_empty()); let b = PCWSTR(a.as_ptr()); // Even though an empty HSTRING is internally represented by a null pointer, the PCWSTR // will still be a non-null pointer to a null terminated empty string. @@ -80,7 +78,7 @@ fn assert_hstring(left: &HSTRING, right: &[u16]) { unsafe { wcslen(PCWSTR::from_raw(left.as_ptr())) }, right.len() - 1 ); - let left = unsafe { std::slice::from_raw_parts(left.as_wide().as_ptr(), right.len()) }; + let left = unsafe { std::slice::from_raw_parts(left.as_ptr(), right.len()) }; assert_eq!(left, right); } diff --git a/crates/tests/misc/string_param/tests/pwstr.rs b/crates/tests/misc/string_param/tests/pwstr.rs index 2510c9a889..750965b10e 100644 --- a/crates/tests/misc/string_param/tests/pwstr.rs +++ b/crates/tests/misc/string_param/tests/pwstr.rs @@ -1,20 +1,4 @@ -use windows::{core::*, Win32::Foundation::*, Win32::UI::Shell::*}; - -#[test] -fn error() { - unsafe { - SetLastError(ERROR_BUSY_DRIVE); - - let utf8 = "test\0".as_bytes(); - let utf16 = HSTRING::from("test\0"); - let utf16 = utf16.as_wide(); - let len = 5; - assert_eq!(utf8.len(), len); - assert_eq!(utf16.len(), len); - - assert_eq!(GetLastError(), ERROR_BUSY_DRIVE); - } -} +use windows::{core::*, Win32::UI::Shell::*}; #[test] fn convert() { diff --git a/crates/tests/misc/strings/tests/bstr.rs b/crates/tests/misc/strings/tests/bstr.rs index 5df54f5a51..73c8bc466f 100644 --- a/crates/tests/misc/strings/tests/bstr.rs +++ b/crates/tests/misc/strings/tests/bstr.rs @@ -28,14 +28,14 @@ fn clone() { assert!(a.is_empty()); assert!(a.len() == 0); assert_eq!(a.len(), 0); - assert_eq!(a.as_wide().len(), 0); + assert_eq!(a.len(), 0); let wide = &[0x68, 0x65, 0x6c, 0x6c, 0x6f]; let a = BSTR::from_wide(wide); assert!(!a.is_empty()); assert!(a.len() == 5); - assert_eq!(a.as_wide().len(), 5); - assert_eq!(a.as_wide(), wide); + assert_eq!(a.len(), 5); + assert_eq!(*a, *wide); assert_eq!(a, "hello"); let a: BSTR = "".into(); @@ -60,3 +60,32 @@ fn interop() -> Result<()> { Ok(()) } } + +#[test] +fn deref_as_slice() { + let deref = BSTR::from("0123456789"); + assert!(!deref.is_empty()); + assert_eq!(deref.len(), 10); + assert_eq!(BSTR::from_wide(&deref[..=3]), "0123"); + assert!(deref.ends_with(&deref[7..=9])); + assert_eq!(deref.get(5), Some(b'5' as u16).as_ref()); + let ptr = PCWSTR(deref.as_ptr()); + assert_eq!(deref.cmp(&deref), std::cmp::Ordering::Equal); + + unsafe { + assert_eq!(*ptr.as_wide(), *deref); + } + + let empty = BSTR::new(); + assert!(empty.is_empty()); + assert_eq!(empty.len(), 0); + assert_eq!(*empty, []); + + unsafe { + assert_eq!(wcslen(empty.as_ptr()), 0); + } +} + +extern "C" { + pub fn wcslen(s: *const u16) -> usize; +} diff --git a/crates/tests/misc/strings/tests/hstring.rs b/crates/tests/misc/strings/tests/hstring.rs index 86cf6e0815..30d3dbc8aa 100644 --- a/crates/tests/misc/strings/tests/hstring.rs +++ b/crates/tests/misc/strings/tests/hstring.rs @@ -232,7 +232,7 @@ fn hstring_compat() -> Result<()> { let result = WindowsConcatString(&hey, &world)?; assert_eq!(result, "HeyWorld"); - let result = WindowsCreateString(Some(&hey.as_wide()))?; + let result = WindowsCreateString(Some(&hey))?; assert_eq!(result, "Hey"); let result = WindowsDuplicateString(&hey)?; @@ -292,6 +292,35 @@ fn hstring_compat() -> Result<()> { } } +#[test] +fn deref_as_slice() { + let deref = HSTRING::from("0123456789"); + assert!(!deref.is_empty()); + assert_eq!(deref.len(), 10); + assert_eq!(HSTRING::from_wide(&deref[..=3]), "0123"); + assert!(deref.ends_with(&deref[7..=9])); + assert_eq!(deref.get(5), Some(b'5' as u16).as_ref()); + let ptr = PCWSTR(deref.as_ptr()); + assert_eq!(deref.cmp(&deref), std::cmp::Ordering::Equal); + + unsafe { + assert_eq!(*ptr.as_wide(), *deref); + } + + let empty = HSTRING::new(); + assert!(empty.is_empty()); + assert_eq!(empty.len(), 0); + assert_eq!(*empty, []); + + unsafe { + assert_eq!(wcslen(empty.as_ptr()), 0); + } +} + +extern "C" { + pub fn wcslen(s: *const u16) -> usize; +} + mod sys { windows_targets::link!("api-ms-win-core-winrt-string-l1-1-0.dll" "system" fn WindowsCreateStringReference(sourcestring: PCWSTR, length: u32, hstringheader: *mut HSTRING_HEADER, string: *mut HSTRING) -> HRESULT); windows_targets::link!("api-ms-win-core-winrt-string-l1-1-0.dll" "system" fn WindowsDeleteString(string: HSTRING) -> HRESULT); diff --git a/crates/tests/misc/strings/tests/hstring_builder.rs b/crates/tests/misc/strings/tests/hstring_builder.rs index 3f1ca2aa0a..e0fc4f580e 100644 --- a/crates/tests/misc/strings/tests/hstring_builder.rs +++ b/crates/tests/misc/strings/tests/hstring_builder.rs @@ -4,7 +4,6 @@ use windows_strings::*; fn hstring() { let s = HSTRING::from("hello"); assert_eq!(s.len(), 5); - assert_eq!(s.as_wide().len(), 5); } #[test] @@ -36,7 +35,7 @@ fn hstring_builder() { b.copy_from_slice(&HELLO00); let h: HSTRING = b.into(); assert_eq!(h.len(), 7); - assert_eq!(h.as_wide(), HELLO00); + assert_eq!(*h, HELLO00); // But trim_end can avoid that. let mut b = HStringBuilder::new(7); @@ -44,7 +43,7 @@ fn hstring_builder() { b.trim_end(); let h: HSTRING = b.into(); assert_eq!(h.len(), 5); - assert_eq!(h.as_wide(), HELLO); + assert_eq!(*h, HELLO); // HStringBuilder will initialize memory to zero. let b = HStringBuilder::new(5); diff --git a/crates/tests/misc/win32_arrays/tests/xmllite.rs b/crates/tests/misc/win32_arrays/tests/xmllite.rs index d1949a842f..0f640d7e62 100644 --- a/crates/tests/misc/win32_arrays/tests/xmllite.rs +++ b/crates/tests/misc/win32_arrays/tests/xmllite.rs @@ -117,13 +117,10 @@ fn lite() -> Result<()> { let writer = writer.unwrap(); writer.SetOutput(&stream)?; - writer.WriteStartElement(HSTRING::from("html").as_wide())?; - writer.WriteAttributeString(HSTRING::from("no-value").as_wide(), None)?; - writer.WriteAttributeString( - HSTRING::from("with-value").as_wide(), - Some(HSTRING::from("value").as_wide()), - )?; - writer.WriteEndElement(HSTRING::from("html").as_wide())?; + writer.WriteStartElement(&HSTRING::from("html"))?; + writer.WriteAttributeString(&HSTRING::from("no-value"), None)?; + writer.WriteAttributeString(&HSTRING::from("with-value"), Some(&HSTRING::from("value")))?; + writer.WriteEndElement(&HSTRING::from("html"))?; writer.Flush()?; let mut pos = 0;