From 1142cf105b6fce1cb4dff2b5a45de6cbb7503865 Mon Sep 17 00:00:00 2001 From: rustdesk Date: Fri, 1 Dec 2023 11:32:07 +0800 Subject: [PATCH] Fix #324 to remove unsafe --- src/relay_server.rs | 132 ++++++++++++++++++--------------------- src/rendezvous_server.rs | 28 ++++----- 2 files changed, 75 insertions(+), 85 deletions(-) diff --git a/src/relay_server.rs b/src/relay_server.rs index 4522c894..e042c235 100644 --- a/src/relay_server.rs +++ b/src/relay_server.rs @@ -25,6 +25,7 @@ use std::{ io::prelude::*, io::Error, net::SocketAddr, + sync::atomic::{AtomicUsize, Ordering}, }; type Usage = (usize, usize, usize, usize); @@ -36,11 +37,11 @@ lazy_static::lazy_static! { static ref BLOCKLIST: RwLock> = Default::default(); } -static mut DOWNGRADE_THRESHOLD: f64 = 0.66; -static mut DOWNGRADE_START_CHECK: usize = 1_800_000; // in ms -static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s -static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s -static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s +static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66 +static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms +static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s +static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s +static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s const BLACKLIST_FILE: &str = "blacklist.txt"; const BLOCKLIST_FILE: &str = "blocklist.txt"; @@ -99,57 +100,53 @@ fn check_params() { .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { - unsafe { - DOWNGRADE_THRESHOLD = tmp; - } + DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst); } - unsafe { log::info!("DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD) }; + log::info!( + "DOWNGRADE_THRESHOLD: {}", + DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. + ); let tmp = std::env::var("DOWNGRADE_START_CHECK") .map(|x| x.parse::().unwrap_or(0)) .unwrap_or(0); if tmp > 0 { - unsafe { - DOWNGRADE_START_CHECK = tmp * 1000; - } + DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst); } - unsafe { log::info!("DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK / 1000) }; + log::info!( + "DOWNGRADE_START_CHECK: {}s", + DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000 + ); let tmp = std::env::var("LIMIT_SPEED") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { - unsafe { - LIMIT_SPEED = (tmp * 1024. * 1024.) as usize; - } + LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } - unsafe { log::info!("LIMIT_SPEED: {}Mb/s", LIMIT_SPEED as f64 / 1024. / 1024.) }; + log::info!( + "LIMIT_SPEED: {}Mb/s", + LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ); let tmp = std::env::var("TOTAL_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { - unsafe { - TOTAL_BANDWIDTH = (tmp * 1024. * 1024.) as usize; - } + TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } - unsafe { - log::info!( - "TOTAL_BANDWIDTH: {}Mb/s", - TOTAL_BANDWIDTH as f64 / 1024. / 1024. - ) - }; + + log::info!( + "TOTAL_BANDWIDTH: {}Mb/s", + TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ); let tmp = std::env::var("SINGLE_BANDWIDTH") .map(|x| x.parse::().unwrap_or(0.)) .unwrap_or(0.); if tmp > 0. { - unsafe { - SINGLE_BANDWIDTH = (tmp * 1024. * 1024.) as usize; - } + SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst); } - unsafe { - log::info!( - "SINGLE_BANDWIDTH: {}Mb/s", - SINGLE_BANDWIDTH as f64 / 1024. / 1024. - ) - }; + log::info!( + "SINGLE_BANDWIDTH: {}Mb/s", + SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ) } async fn check_cmd(cmd: &str, limiter: Limiter) -> String { @@ -233,76 +230,68 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { - unsafe { - DOWNGRADE_THRESHOLD = v; - } + DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst); } } } else { - unsafe { - res = format!("{DOWNGRADE_THRESHOLD}\n"); - } + res = format!( + "{}\n", + DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. + ); } } Some("downgrade-start-check" | "t") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0 { - unsafe { - DOWNGRADE_START_CHECK = v * 1000; - } + DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst); } } } else { - unsafe { - res = format!("{}s\n", DOWNGRADE_START_CHECK / 1000); - } + res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000); } } Some("limit-speed" | "ls") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { - unsafe { - LIMIT_SPEED = (v * 1024. * 1024.) as _; - } + LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst); } } } else { - unsafe { - res = format!("{}Mb/s\n", LIMIT_SPEED as f64 / 1024. / 1024.); - } + res = format!( + "{}Mb/s\n", + LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ); } } Some("total-bandwidth" | "tb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { - unsafe { - TOTAL_BANDWIDTH = (v * 1024. * 1024.) as _; - limiter.set_speed_limit(TOTAL_BANDWIDTH as _); - } + TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst); + limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _); } } } else { - unsafe { - res = format!("{}Mb/s\n", TOTAL_BANDWIDTH as f64 / 1024. / 1024.); - } + res = format!( + "{}Mb/s\n", + TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ); } } Some("single-bandwidth" | "sb") => { if let Some(v) = fds.next() { if let Ok(v) = v.parse::() { if v > 0. { - unsafe { - SINGLE_BANDWIDTH = (v * 1024. * 1024.) as _; - } + SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst); } } } else { - unsafe { - res = format!("{}Mb/s\n", SINGLE_BANDWIDTH as f64 / 1024. / 1024.); - } + res = format!( + "{}Mb/s\n", + SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024. + ); } } Some("usage" | "u") => { @@ -336,7 +325,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String { async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) { check_params(); - let limiter = ::new(unsafe { TOTAL_BANDWIDTH as _ }); + let limiter = ::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _); loop { tokio::select! { res = listener.accept() => { @@ -475,10 +464,11 @@ async fn relay( let mut highest_s = 0; let mut downgrade: bool = false; let mut blacked: bool = false; - let limiter = ::new(unsafe { SINGLE_BANDWIDTH as _ }); - let blacklist_limiter = ::new(unsafe { LIMIT_SPEED as _ }); + let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64; + let limiter = ::new(sb); + let blacklist_limiter = ::new(LIMIT_SPEED.load(Ordering::SeqCst) as _); let downgrade_threshold = - (unsafe { SINGLE_BANDWIDTH as f64 * DOWNGRADE_THRESHOLD } / 1000.) as usize; // in bit/ms + (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms let mut timer = interval(Duration::from_secs(3)); let mut last_recv_time = std::time::Instant::now(); loop { @@ -546,7 +536,7 @@ async fn relay( (elapsed as _, total as _, highest_s as _, speed as _), ); total_s = 0; - if elapsed > unsafe { DOWNGRADE_START_CHECK } + if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst) && !downgrade && total > elapsed * downgrade_threshold { diff --git a/src/rendezvous_server.rs b/src/rendezvous_server.rs index c5045768..0c0d1075 100644 --- a/src/rendezvous_server.rs +++ b/src/rendezvous_server.rs @@ -35,6 +35,7 @@ use sodiumoxide::crypto::sign; use std::{ collections::HashMap, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::atomic::{AtomicBool, AtomicUsize, Ordering}, sync::Arc, time::Instant, }; @@ -55,10 +56,10 @@ enum Sink { } type Sender = mpsc::UnboundedSender; type Receiver = mpsc::UnboundedReceiver; -static mut ROTATION_RELAY_SERVER: usize = 0; +static ROTATION_RELAY_SERVER: AtomicUsize = AtomicUsize::new(0); type RelayServers = Vec; static CHECK_RELAY_TIMEOUT: u64 = 3_000; -static mut ALWAYS_USE_RELAY: bool = false; +static ALWAYS_USE_RELAY: AtomicBool = AtomicBool::new(false); #[derive(Clone)] struct Inner { @@ -147,13 +148,11 @@ impl RendezvousServer { .to_uppercase() == "Y" { - unsafe { - ALWAYS_USE_RELAY = true; - } + ALWAYS_USE_RELAY.store(true, Ordering::SeqCst); } log::info!( "ALWAYS_USE_RELAY={}", - if unsafe { ALWAYS_USE_RELAY } { + if ALWAYS_USE_RELAY.load(Ordering::SeqCst) { "Y" } else { "N" @@ -711,7 +710,7 @@ impl RendezvousServer { let peer_is_lan = self.is_lan(peer_addr); let is_lan = self.is_lan(addr); let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip()); - if unsafe { ALWAYS_USE_RELAY } || (peer_is_lan ^ is_lan) { + if ALWAYS_USE_RELAY.load(Ordering::SeqCst) || (peer_is_lan ^ is_lan) { if peer_is_lan { // https://github.com/rustdesk/rustdesk-server/issues/24 relay_server = self.inner.local_ip.clone() @@ -905,10 +904,7 @@ impl RendezvousServer { } else if self.relay_servers.len() == 1 { return self.relay_servers[0].clone(); } - let i = unsafe { - ROTATION_RELAY_SERVER += 1; - ROTATION_RELAY_SERVER % self.relay_servers.len() - }; + let i = ROTATION_RELAY_SERVER.fetch_add(1, Ordering::SeqCst) % self.relay_servers.len(); self.relay_servers[i].clone() } @@ -1027,13 +1023,17 @@ impl RendezvousServer { Some("always-use-relay" | "aur") => { if let Some(rs) = fds.next() { if rs.to_uppercase() == "Y" { - unsafe { ALWAYS_USE_RELAY = true }; + ALWAYS_USE_RELAY.store(true, Ordering::SeqCst); } else { - unsafe { ALWAYS_USE_RELAY = false }; + ALWAYS_USE_RELAY.store(false, Ordering::SeqCst); } self.tx.send(Data::RelayServers0(rs.to_owned())).ok(); } else { - let _ = writeln!(res, "ALWAYS_USE_RELAY: {:?}", unsafe { ALWAYS_USE_RELAY }); + let _ = writeln!( + res, + "ALWAYS_USE_RELAY: {:?}", + ALWAYS_USE_RELAY.load(Ordering::SeqCst) + ); } } Some("test-geo" | "tg") => {