diff --git a/src/dataframe.rs b/src/dataframe.rs
index 69c867e..43b8b78 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -210,6 +210,26 @@ impl DataFrame {
 
         Ok(data.value(0))
     }
 
+    /// Applies `f` to every row of this DataFrame.
+    ///
+    /// The DataFrame is first materialised with `collect`; `f` is then called
+    /// once per row, in order, with a single-row [`RecordBatch`] slice.
+    ///
+    /// # Errors
+    ///
+    /// Returns a [`SparkError`] if collecting the DataFrame fails.
+    pub async fn foreach<F>(self, mut f: F) -> Result<(), SparkError>
+    where
+        F: FnMut(&RecordBatch),
+    {
+        let rows = self.collect().await?;
+        for i in 0..rows.num_rows() {
+            // Hand the callback a one-row slice starting at row `i`.
+            f(&rows.slice(i, 1));
+        }
+        Ok(())
+    }
+
     /// Creates a local temporary view with this DataFrame.
     #[allow(non_snake_case)]
diff --git a/src/functions/mod.rs b/src/functions/mod.rs
index aaf2b92..b97d31c 100644
--- a/src/functions/mod.rs
+++ b/src/functions/mod.rs
@@ -528,6 +528,39 @@ mod tests {
         assert_eq!(&expected, &rows_func_asc);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_func_foreach() -> Result<(), SparkError> {
+        let spark = setup().await;
+
+        // Sorted descending, so the callback is expected to see the ids
+        // in the order 3, 2, 1.
+        let df_col_desc = spark
+            .clone()
+            .range(Some(1), 4, 1, Some(1))
+            .sort([col("id").desc()]);
+
+        let mut result = Vec::new();
+        df_col_desc
+            .foreach(|row: &RecordBatch| result.push(row.clone()))
+            .await?;
+
+        // Each captured batch is a one-row slice; read its single `id` value.
+        let ids: Vec<i64> = result
+            .iter()
+            .map(|row| {
+                row.column(0)
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("`id` column should be an Int64Array")
+                    .value(0)
+            })
+            .collect();
+
+        assert_eq!(ids, vec![3, 2, 1]);
+
+        Ok(())
+    }
 
     #[tokio::test]
     async fn test_func_desc() -> Result<(), SparkError> {