Skip to content

Commit

Permalink
updated linear regression
Browse files Browse the repository at this point in the history
  • Loading branch information
ghimiresdp committed Sep 27, 2023
1 parent d613a4b commit e4bcee9
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions challenge_mid/src/c2_linear_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ pub mod losses {
struct LinearRegressionModel {
w0: f64,
w: Vec<f64>,
e: f64,
}

impl LinearRegressionModel {
fn new() -> Self {
Self {
w0: 0.1,
w: vec![], // coefficients
e: 0.0,
}
}
fn _predict(&self, x: &Vec<f64>) -> f64 {
Expand All @@ -72,9 +70,8 @@ impl LinearRegressionModel {
}
self.w = vec![0.0; x.get(0).unwrap().len()];
for epoch in 0..epochs {
// use gradient descent method to optimize the algorithm
let mut gradients = vec![0.0; self.w.len()];
println!("Epoch: {epoch}");

for idx in 0..x.len() {
let prediction = self._predict(&x[idx]);
let error = prediction - y[idx];
Expand All @@ -91,10 +88,7 @@ impl LinearRegressionModel {
.zip(gradients.clone())
.map(|(v, g)| v - (learning_rate / gradients.len() as f64) * g)
.collect();

let predictions: Vec<f64> = x.iter().map(|row| self._predict(row)).collect();
let error = losses::mean_squared_error(y.clone(), predictions.clone());
println!("actual: {y:?}\n predicted: {predictions:?}\n loss: {error}\n\n");
println!("Epoch: {epoch}\t loss: {}", self.test(x.clone(), y.clone()));
}
}

Expand All @@ -107,6 +101,12 @@ impl LinearRegressionModel {
}
self._predict(x)
}

fn test(&mut self, x: Vec<Vec<f64>>, y: Vec<f64>) -> f64 {
let predictions: Vec<f64> = x.iter().map(|row| self._predict(row)).collect();
let error = losses::mean_squared_error(y.clone(), predictions.clone());
error
}
}

fn main() {
Expand All @@ -126,3 +126,18 @@ fn main() {
let out = model.predict(vec![1.0, 2.0].as_ref());
println!("Actual: 5.0, Prediction: {out}");
}

#[cfg(test)]
mod tests {
use crate::LinearRegressionModel;

#[test]
fn test_correct_prediction() {
let x = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]];
let y = vec![2.0, 4.0, 6.0];
let mut model = LinearRegressionModel::new();
model.fit(x, y, 0.001, 1000);
let loss = (model.predict(vec![4.0, 4.0].as_ref()) - 8.0).abs();
assert_eq!(loss < 0.2, true);
}
}

0 comments on commit e4bcee9

Please sign in to comment.