From 5d4dd1f3a5c33c373440d28e1cb8dccad910e0df Mon Sep 17 00:00:00 2001 From: Jae-Won Chung Date: Sun, 26 May 2024 22:09:33 -0400 Subject: [PATCH] A macro to implement APIs more easily --- .github/workflows/zeusd_fmt_lint_test.yaml | 2 +- zeusd/Cargo.lock | 1 + zeusd/Cargo.toml | 5 +- zeusd/src/devices/gpu.rs | 41 ++- zeusd/src/routes/gpu.rs | 285 ++++++++------------- zeusd/src/routes/mod.rs | 13 +- zeusd/tests/gpu.rs | 19 +- zeusd/tests/helpers/mod.rs | 83 +++++- 8 files changed, 233 insertions(+), 216 deletions(-) diff --git a/.github/workflows/zeusd_fmt_lint_test.yaml b/.github/workflows/zeusd_fmt_lint_test.yaml index 192701f3..b7e295af 100644 --- a/.github/workflows/zeusd_fmt_lint_test.yaml +++ b/.github/workflows/zeusd_fmt_lint_test.yaml @@ -32,5 +32,5 @@ jobs: run: cargo clippy --all -- -D warnings working-directory: zeusd - name: Run tests - run: cargo test --features testing + run: cargo test working-directory: zeusd diff --git a/zeusd/Cargo.lock b/zeusd/Cargo.lock index 6e962f19..208c590b 100644 --- a/zeusd/Cargo.lock +++ b/zeusd/Cargo.lock @@ -2032,6 +2032,7 @@ dependencies = [ "nix", "nvml-wrapper", "once_cell", + "paste", "reqwest", "serde", "thiserror", diff --git a/zeusd/Cargo.toml b/zeusd/Cargo.toml index 4b3d3d95..931fb3fa 100644 --- a/zeusd/Cargo.toml +++ b/zeusd/Cargo.toml @@ -5,6 +5,7 @@ authors = ["Jae-Won Chung "] edition = "2021" [lib] +name = "zeusd" path = "src/lib.rs" doctest = false @@ -12,9 +13,6 @@ doctest = false path = "src/main.rs" name = "zeusd" -[features] -testing = [] - [dependencies] actix-web = "4" nvml-wrapper = "0.10" @@ -28,6 +26,7 @@ tracing-subscriber = { version = "0.3", features = ["registry", "env-filter"] } tracing-log = "0.2.0" tracing-actix-web = "0.7.10" nix = { version = "0.29", default-features = false, features = ["user"] } +paste = "1" [dev-dependencies] once_cell = "1.7.2" diff --git a/zeusd/src/devices/gpu.rs b/zeusd/src/devices/gpu.rs index 8c0a6be3..9576963c 100644 --- a/zeusd/src/devices/gpu.rs +++ b/zeusd/src/devices/gpu.rs @@ -7,7 +7,11 @@ use tracing::Span; use crate::error::ZeusdError; -pub trait GpuManager: Send + 'static { +/// A trait for structs that manage one GPU. +/// +/// This trait can be used to abstract over different GPU management libraries. +/// Currently, this was done to facilitate testing. +pub trait GpuManager { fn device_count() -> Result where Self: Sized; @@ -18,11 +22,13 @@ pub trait GpuManager: Send + 'static { min_clock_mhz: u32, max_clock_mhz: u32, ) -> Result<(), ZeusdError>; + fn reset_gpu_locked_clocks(&mut self) -> Result<(), ZeusdError>; fn set_mem_locked_clocks( &mut self, min_clock_mhz: u32, max_clock_mhz: u32, ) -> Result<(), ZeusdError>; + fn reset_mem_locked_clocks(&mut self) -> Result<(), ZeusdError>; } pub struct NvmlGpu<'n> { @@ -34,6 +40,7 @@ impl NvmlGpu<'static> { pub fn init(index: u32) -> Result { // `Device` needs to hold a reference to `Nvml`, meaning that `Nvml` must outlive `Device`. // We can achieve this by leaking a `Box` containing `Nvml` and holding a reference to it. + // `Nvml` will actually live until the server terminates inside the GPU management task. let _nvml = Box::leak(Box::new(Nvml::init()?)); let device = _nvml.device_by_index(index)?; Ok(Self { _nvml, device }) @@ -69,6 +76,11 @@ impl GpuManager for NvmlGpu<'static> { Ok(self.device.set_gpu_locked_clocks(setting)?) } + #[inline] + fn reset_gpu_locked_clocks(&mut self) -> Result<(), ZeusdError> { + Ok(self.device.reset_gpu_locked_clocks()?) + } + #[inline] fn set_mem_locked_clocks( &mut self, @@ -79,6 +91,11 @@ impl GpuManager for NvmlGpu<'static> { .device .set_mem_locked_clocks(min_clock_mhz, max_clock_mhz)?) } + + #[inline] + fn reset_mem_locked_clocks(&mut self) -> Result<(), ZeusdError> { + Ok(self.device.reset_mem_locked_clocks()?) + } } /// A request to execute a GPU command. @@ -183,11 +200,15 @@ pub enum GpuCommand { min_clock_mhz: u32, max_clock_mhz: u32, }, + /// Reset the GPU's locked clocks. + ResetGpuLockedClocks, /// Set the GPU's memory locked clock range in MHz. SetMemLockedClocks { min_clock_mhz: u32, max_clock_mhz: u32, }, + /// Reset the GPU's memory locked clocks. + ResetMemLockedClocks, } impl GpuCommand { @@ -242,6 +263,15 @@ impl GpuCommand { } result } + Self::ResetGpuLockedClocks => { + let result = device.reset_gpu_locked_clocks(); + if result.is_ok() { + tracing::info!("GPU locked clocks reset"); + } else { + tracing::warn!("Cannot reset GPU locked clocks"); + } + result + } Self::SetMemLockedClocks { min_clock_mhz, max_clock_mhz, @@ -262,6 +292,15 @@ impl GpuCommand { } result } + Self::ResetMemLockedClocks => { + let result = device.reset_mem_locked_clocks(); + if result.is_ok() { + tracing::info!("Memory locked clocks reset"); + } else { + tracing::warn!("Cannot reset memory locked clocks"); + } + result + } } } } diff --git a/zeusd/src/routes/gpu.rs b/zeusd/src/routes/gpu.rs index 301317be..bd8bb212 100644 --- a/zeusd/src/routes/gpu.rs +++ b/zeusd/src/routes/gpu.rs @@ -1,211 +1,136 @@ //! Routes for interacting with GPUs use actix_web::{web, HttpResponse}; +use paste::paste; +use serde::{Deserialize, Serialize}; use crate::devices::gpu::{GpuCommand, GpuManagementTasks}; use crate::error::ZeusdError; -#[derive(serde::Deserialize, Debug)] -#[cfg_attr(feature = "testing", derive(serde::Serialize))] -pub struct SetPersistentModeRequest { - pub enabled: bool, - pub block: bool, -} +macro_rules! impl_handler_for_gpu_command { + ($action:ident, $api:ident, $path:expr, $($field:ident),*) => { + // Implement conversion to the GpuCommand variant. + paste! { + impl From<[<$action:camel $api:camel>]> for GpuCommand { + // Prefixing with underscore to avoid lint errors when $field is empty. + fn from(_request: [<$action:camel $api:camel>]) -> Self { + GpuCommand::[<$action:camel $api:camel>] { + $($field: _request.$field),* + } + } + } -impl From for GpuCommand { - fn from(request: SetPersistentModeRequest) -> Self { - GpuCommand::SetPersistentMode { - enabled: request.enabled, + // Generate the request handler. + #[actix_web::post($path)] + #[tracing::instrument( + skip(gpu, request, device_tasks), + fields( + gpu_id = %gpu, + block = %request.block, + $($field = %request.$field),* + ) + )] + pub async fn [<$action:snake _ $api:snake _handler>]( + gpu: web::Path, + request: web::Json<[<$action:camel $api:camel>]>, + device_tasks: web::Data, + ) -> Result { + let gpu = gpu.into_inner(); + let request = request.into_inner(); + + tracing::info!( + "Received reqeust to GPU {} ({:?})", + gpu, + request, + ); + + if request.block { + device_tasks + .send_command_blocking(gpu, request.into()) + .await?; + } else { + device_tasks.send_command_nonblocking(gpu, request.into())?; + } + + Ok(HttpResponse::Ok().finish()) + } } - } + }; } -#[actix_web::post("/{gpu_id}/persistent_mode")] -#[tracing::instrument( - skip(gpu, request, device_tasks), - fields( - gpu_id = %gpu, - enabled = %request.enabled, - block = %request.block - ) -)] -pub async fn set_persistent_mode_handler( - gpu: web::Path, - request: web::Json, - device_tasks: web::Data, -) -> Result { - let gpu = gpu.into_inner(); - let request = request.into_inner(); - - tracing::info!( - "Received reqeust to set GPU {}'s persistent mode to {} W", - gpu, - if request.enabled { - "enabled" - } else { - "disabled" - }, - ); - - if request.block { - device_tasks - .send_command_blocking(gpu, request.into()) - .await?; - } else { - device_tasks.send_command_nonblocking(gpu, request.into())?; - } - - Ok(HttpResponse::Ok().finish()) +#[derive(Serialize, Deserialize, Debug)] +pub struct SetPersistentMode { + pub enabled: bool, + pub block: bool, } -#[derive(serde::Deserialize, Debug)] -struct SetPowerLimitRequest { - power_limit_uw: u32, +impl_handler_for_gpu_command!(set, persistent_mode, "/{gpu_id}/set_persistent_mode", enabled); + +#[derive(Serialize, Deserialize, Debug)] +pub struct SetPowerLimit { + power_limit_mw: u32, block: bool, } -impl From for GpuCommand { - fn from(request: SetPowerLimitRequest) -> Self { - GpuCommand::SetPowerLimit { - power_limit_mw: request.power_limit_uw, - } - } -} +impl_handler_for_gpu_command!(set, power_limit, "/{gpu_id}/set_power_limit", power_limit_mw); -#[actix_web::post("/{gpu_id}/power_limit")] -#[tracing::instrument( - skip(gpu, request, device_tasks), - fields( - gpu_id = %gpu, - power_limit = %request.power_limit_uw, - block = %request.block - ) -)] -pub async fn set_power_limit_handler( - gpu: web::Path, - request: web::Json, - device_tasks: web::Data, -) -> Result { - let gpu = gpu.into_inner(); - let request = request.into_inner(); - - tracing::info!( - "Received reqeust to set GPU {}'s power limit to {} W", - gpu, - request.power_limit_uw / 1000, - ); - - if request.block { - device_tasks - .send_command_blocking(gpu, request.into()) - .await?; - } else { - device_tasks.send_command_nonblocking(gpu, request.into())?; - } - - Ok(HttpResponse::Ok().finish()) -} - -#[derive(serde::Deserialize, Debug)] -struct SetGpuLockedClocksRequest { +#[derive(Serialize, Deserialize, Debug)] +pub struct SetGpuLockedClocks { min_clock_mhz: u32, max_clock_mhz: u32, block: bool, } -impl From for GpuCommand { - fn from(request: SetGpuLockedClocksRequest) -> Self { - GpuCommand::SetGpuLockedClocks { - min_clock_mhz: request.min_clock_mhz, - max_clock_mhz: request.max_clock_mhz, - } - } -} - -#[actix_web::post("/{gpu_id}/gpu_locked_clocks")] -#[tracing::instrument( - skip(gpu, request, device_tasks), - fields( - gpu_id = %gpu, - min_clock_mhz = %request.min_clock_mhz, - max_clock_mhz = %request.max_clock_mhz, - block = %request.block - ) -)] -pub async fn set_gpu_locked_clocks_handler( - gpu: web::Path, - request: web::Json, - device_tasks: web::Data, -) -> Result { - let gpu = gpu.into_inner(); - let request = request.into_inner(); - - tracing::info!( - "Received reqeust to set GPU {}'s gpu locked clocks to [{}, {}] MHz", - gpu, - request.min_clock_mhz, - request.max_clock_mhz, - ); - - if request.block { - device_tasks - .send_command_blocking(gpu, request.into()) - .await?; - } else { - device_tasks.send_command_nonblocking(gpu, request.into())?; - } - - Ok(HttpResponse::Ok().finish()) -} +impl_handler_for_gpu_command!( + set, + gpu_locked_clocks, + "/{gpu_id}/set_gpu_locked_clocks", + min_clock_mhz, + max_clock_mhz +); -#[derive(serde::Deserialize, Debug)] -struct SetMemLockedClocksRequest { +#[derive(Serialize, Deserialize, Debug)] +pub struct SetMemLockedClocks { min_clock_mhz: u32, max_clock_mhz: u32, block: bool, } -impl From for GpuCommand { - fn from(request: SetMemLockedClocksRequest) -> Self { - GpuCommand::SetMemLockedClocks { - min_clock_mhz: request.min_clock_mhz, - max_clock_mhz: request.max_clock_mhz, - } - } +#[derive(Serialize, Deserialize, Debug)] +pub struct ResetGpuLockedClocks { + block: bool, +} + +impl_handler_for_gpu_command!( + reset, + gpu_locked_clocks, + "/{gpu_id}/reset_gpu_locked_clocks", +); + + +impl_handler_for_gpu_command!( + set, + mem_locked_clocks, + "/{gpu_id}/set_mem_locked_clocks", + min_clock_mhz, + max_clock_mhz +); +#[derive(Serialize, Deserialize, Debug)] +pub struct ResetMemLockedClocks { + block: bool, } -#[actix_web::post("/{gpu_id}/mem_locked_clocks")] -#[tracing::instrument( - skip(gpu, request, device_tasks), - fields( - gpu_id = %gpu, - min_clock_mhz = %request.min_clock_mhz, - max_clock_mhz = %request.max_clock_mhz, - block = %request.block - ) -)] -pub async fn set_mem_locked_clocks_handler( - gpu: web::Path, - request: web::Json, - device_tasks: web::Data, -) -> Result { - let gpu = gpu.into_inner(); - let request = request.into_inner(); - - tracing::info!( - "Received reqeust to set GPU {}'s memory locked clocks to [{}, {}] MHz", - gpu, - request.min_clock_mhz, - request.max_clock_mhz, - ); - - if request.block { - device_tasks - .send_command_blocking(gpu, request.into()) - .await?; - } else { - device_tasks.send_command_nonblocking(gpu, request.into())?; - } - - Ok(HttpResponse::Ok().finish()) +impl_handler_for_gpu_command!( + reset, + mem_locked_clocks, + "/{gpu_id}/reset_mem_locked_clocks", +); + +pub fn gpu_routes(cfg: &mut web::ServiceConfig) { + cfg.service(set_persistent_mode_handler) + .service(set_power_limit_handler) + .service(set_gpu_locked_clocks_handler) + .service(reset_gpu_locked_clocks_handler) + .service(set_mem_locked_clocks_handler) + .service(reset_mem_locked_clocks_handler); } diff --git a/zeusd/src/routes/mod.rs b/zeusd/src/routes/mod.rs index 1184ee18..67783626 100644 --- a/zeusd/src/routes/mod.rs +++ b/zeusd/src/routes/mod.rs @@ -1,14 +1,3 @@ pub mod gpu; -use actix_web::web; -use gpu::{ - set_gpu_locked_clocks_handler, set_mem_locked_clocks_handler, set_persistent_mode_handler, - set_power_limit_handler, -}; - -pub fn gpu_routes(cfg: &mut web::ServiceConfig) { - cfg.service(set_persistent_mode_handler) - .service(set_power_limit_handler) - .service(set_gpu_locked_clocks_handler) - .service(set_mem_locked_clocks_handler); -} +pub use gpu::gpu_routes; diff --git a/zeusd/tests/gpu.rs b/zeusd/tests/gpu.rs index daa59282..d19a3e8e 100644 --- a/zeusd/tests/gpu.rs +++ b/zeusd/tests/gpu.rs @@ -1,16 +1,25 @@ mod helpers; +use zeusd::routes::gpu::SetPersistentMode; + use crate::helpers::TestApp; #[tokio::test] async fn test_set_persistent_mode() { let mut app = TestApp::start().await; - let resp = app.set_persistent_mode(0, true, true).await; + let resp = app + .send( + 0, + SetPersistentMode { + enabled: true, + block: true, + }, + ) + .await; assert_eq!(resp.status(), 200); - assert_eq!( - app.observers[0].persistent_mode_rx.recv().await.unwrap(), - true - ); + let history = app.persistent_mode_history_for_gpu(0); + assert_eq!(history.len(), 1); + assert_eq!(history[0], true); } diff --git a/zeusd/tests/helpers/mod.rs b/zeusd/tests/helpers/mod.rs index 4adb6878..abbb38d0 100644 --- a/zeusd/tests/helpers/mod.rs +++ b/zeusd/tests/helpers/mod.rs @@ -1,9 +1,14 @@ +//! Helpers for running integration tests. +//! +//! It has to be under `tests/helpers/mod.rs` instead of `tests/helpers.rs` +//! to avoid it from being treated as another test module. + +use paste::paste; use once_cell::sync::Lazy; use std::net::TcpListener; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use zeusd::devices::gpu::{GpuManagementTasks, GpuManager}; use zeusd::error::ZeusdError; -use zeusd::routes::gpu::SetPersistentModeRequest; use zeusd::startup::{init_tracing, start_server_tcp}; static NUM_GPUS: u32 = 4; @@ -81,6 +86,11 @@ impl GpuManager for TestGpu { Ok(()) } + fn reset_gpu_locked_clocks(&mut self) -> Result<(), ZeusdError> { + self.gpu_locked_clocks_tx.send((0, 0)).unwrap(); + Ok(()) + } + fn set_mem_locked_clocks( &mut self, min_clock_mhz: u32, @@ -91,6 +101,11 @@ impl GpuManager for TestGpu { .unwrap(); Ok(()) } + + fn reset_mem_locked_clocks(&mut self) -> Result<(), ZeusdError> { + self.mem_locked_clocks_tx.send((0, 0)).unwrap(); + Ok(()) + } } pub fn start_test_tasks() -> anyhow::Result<(GpuManagementTasks, Vec)> { @@ -107,9 +122,38 @@ pub fn start_test_tasks() -> anyhow::Result<(GpuManagementTasks, Vec String; +} + +macro_rules! impl_zeusd_request { + ($api:ident) => { + paste! { + impl ZeusdRequest for zeusd::routes::gpu::[<$api:camel>] { + fn build_url(app: &TestApp, gpu_id: u32) -> String { + format!( + "http://127.0.0.1:{}/gpu/{}/{}", + app.port, gpu_id, stringify!([<$api:snake>]), + ) + } + } + } + } +} + +impl_zeusd_request!(SetPersistentMode); +impl_zeusd_request!(SetPowerLimit); +impl_zeusd_request!(SetGpuLockedClocks); +impl_zeusd_request!(ResetGpuLockedClocks); +impl_zeusd_request!(SetMemLockedClocks); +impl_zeusd_request!(ResetMemLockedClocks); + +/// A test application that starts a server over TCP and provides helper methods +/// for sending requests and fetching what happened to the fake GPUs. pub struct TestApp { port: u16, - pub observers: Vec, + observers: Vec, } impl TestApp { @@ -130,24 +174,35 @@ impl TestApp { } } - pub async fn set_persistent_mode( - &mut self, - gpu_id: u32, - enabled: bool, - block: bool, - ) -> reqwest::Response { + pub async fn send(&mut self, gpu_id: u32, payload: T) -> reqwest::Response { let client = reqwest::Client::new(); - let url = format!( - "http://127.0.0.1:{}/gpu/{}/persistent_mode", - self.port, gpu_id - ); - let payload = SetPersistentModeRequest { enabled, block }; + let url = T::build_url(self, gpu_id); client - .post(&url) + .post(url) .json(&payload) .send() .await .expect("Failed to send request") } + + pub fn persistent_mode_history_for_gpu(&mut self, gpu_id: usize) -> Vec { + let rx = &mut self.observers[gpu_id].persistent_mode_rx; + std::iter::from_fn(|| rx.try_recv().ok()).collect() + } + + pub fn power_limit_history_for_gpu(&mut self, gpu_id: usize) -> Vec { + let rx = &mut self.observers[gpu_id].power_limit_rx; + std::iter::from_fn(|| rx.try_recv().ok()).collect() + } + + pub fn gpu_locked_clocks_history_for_gpu(&mut self, gpu_id: usize) -> Vec<(u32, u32)> { + let rx = &mut self.observers[gpu_id].gpu_locked_clocks_rx; + std::iter::from_fn(|| rx.try_recv().ok()).collect() + } + + pub fn mem_locked_clocks_history_for_gpu(&mut self, gpu_id: usize) -> Vec<(u32, u32)> { + let rx = &mut self.observers[gpu_id].mem_locked_clocks_rx; + std::iter::from_fn(|| rx.try_recv().ok()).collect() + } }