Skip to content

Commit

Permalink
Create random_forest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KOSASIH authored Jul 28, 2024
1 parent a6b2337 commit ba9ca3b
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions ai/models/random_forest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

class RandomForestModel:
def __init__(self, n_estimators=100, max_depth=None):
self.model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)

def train(self, X_train, y_train):
self.model.fit(X_train, y_train)

def predict(self, X_test):
return self.model.predict(X_test)

def evaluate(self, X_test, y_test):
y_pred = self.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
return accuracy

0 comments on commit ba9ca3b

Please sign in to comment.