Skip to content

Commit

Permalink
Merge pull request #6 from perib/dev
Browse files Browse the repository at this point in the history
allow survival selector to be none
  • Loading branch information
nickotto authored Apr 20, 2023
2 parents d0129f0 + 00819dd commit 388ce58
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 64 deletions.
11 changes: 6 additions & 5 deletions tpot2/base_evolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,11 +526,12 @@ def step(self,):

def one_generation_step(self, ): #your EA Algorithm goes here

n_survivors = max(1,int(self.cur_population_size*self.survival_percentage)) #always keep at least one individual
#Get survivors from current population
weighted_scores = self.population.get_column(self.population.population, column_names=self.objective_names) * self.objective_function_weights
new_population_index = np.ravel(self.survival_selector(weighted_scores, k=n_survivors)) #TODO make it clear that we are concatenating scores...
self.population.set_population(np.array(self.population.population)[new_population_index])
if self.survival_selector is not None:
n_survivors = max(1,int(self.cur_population_size*self.survival_percentage)) #always keep at least one individual
#Get survivors from current population
weighted_scores = self.population.get_column(self.population.population, column_names=self.objective_names) * self.objective_function_weights
new_population_index = np.ravel(self.survival_selector(weighted_scores, k=n_survivors)) #TODO make it clear that we are concatenating scores...
self.population.set_population(np.array(self.population.population)[new_population_index])
weighted_scores = self.population.get_column(self.population.population, column_names=self.objective_names) * self.objective_function_weights

#number of crossover pairs and mutation only parent to generate
Expand Down
2 changes: 1 addition & 1 deletion tpot2/evolutionary_algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .simple_evolver import SimpleEvolver


57 changes: 0 additions & 57 deletions tpot2/evolutionary_algorithms/simple_evolver.py

This file was deleted.

9 changes: 8 additions & 1 deletion tpot2/tpot_estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, scorers,
verbose = 0,
periodic_checkpoint_folder = None,
callback: tpot2.CallBackInterface = None,
processes = True,
):

'''
Expand Down Expand Up @@ -376,6 +377,10 @@ def __init__(self, scorers,
callback : tpot2.CallBackInterface, default=None
Callback object. Not implemented
processes : bool, default=True
If True, will use multiprocessing to parallelize the optimization process. If False, will use threading.
True seems to perform better. However, False is required for interactive debugging.
'''

# sklearn BaseEstimator must have a corresponding attribute for each parameter.
Expand Down Expand Up @@ -443,6 +448,8 @@ def __init__(self, scorers,

self.objective_function_names = objective_function_names

self.processes = processes

#Initialize other used params


Expand Down Expand Up @@ -508,7 +515,7 @@ def fit(self, X, y):
silence_logs = 50
cluster = LocalCluster(n_workers=self.n_jobs, #if no client is passed in and no global client exists, create our own
threads_per_worker=1,
processes=True,
processes=self.processes,
silence_logs=silence_logs,
memory_limit=self.memory_limit)
_client = Client(cluster)
Expand Down
4 changes: 4 additions & 0 deletions tpot2/tpot_estimator/templates/tpottemplates.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__( self,
verbose=0,
periodic_checkpoint_folder=None,
callback: tpot2.CallBackInterface=None,
processes = True,
):
super(TPOTRegressor,self).__init__(
scorers=scorers,
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__( self,
verbose=verbose,
periodic_checkpoint_folder=periodic_checkpoint_folder,
callback=callback,
processes=processes,
)


Expand Down Expand Up @@ -191,6 +193,7 @@ def __init__( self,
verbose=0,
periodic_checkpoint_folder=None,
callback: tpot2.CallBackInterface=None,
processes = True,
):
super(TPOTClassifier,self).__init__(
scorers=scorers,
Expand Down Expand Up @@ -252,6 +255,7 @@ def __init__( self,
verbose=verbose,
periodic_checkpoint_folder=periodic_checkpoint_folder,
callback=callback,
processes = processes,
)


Expand Down

0 comments on commit 388ce58

Please sign in to comment.