Skip to content

Commit

Permalink
fixing initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaud de mattia committed Feb 5, 2024
1 parent a54b07b commit 74a0c43
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 107 deletions.
9 changes: 5 additions & 4 deletions desilike/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def callback(calculator):
self.calculators.append(calculator)
for require in calculator.runtime_info.requires:
require.runtime_info.initialize()
require.runtime_info._initialized_for_required_by.append(id(calculator))
require.runtime_info._initialized_for_pipeline.append(id(self))
if require in self.calculators:
del self.calculators[self.calculators.index(require)] # we want first dependencies at the end
callback(require)
Expand Down Expand Up @@ -663,7 +663,7 @@ def __init__(self, calculator, init=None):
if not isinstance(init, InitConfig):
self.init = InitConfig(init)
self._initialized = False
self._initialized_for_required_by = []
self._initialized_for_pipeline = []
self._tocalculate = True
self.calculated = False
self.name = self.calculator.__class__.__name__
Expand Down Expand Up @@ -722,7 +722,7 @@ def pipeline(self):
self._pipeline = BasePipeline(self.calculator)
else:
for calculator in self._pipeline.calculators[:-1]:
if not calculator.runtime_info._initialized_for_required_by:
if not calculator.runtime_info.initialized or id(self._pipeline) not in calculator.runtime_info._initialized_for_pipeline:
self._pipeline = BasePipeline(self.calculator)
break
return self._pipeline
Expand Down Expand Up @@ -750,6 +750,7 @@ def initialized(self):
"""Has this calculator been initialized?"""
if self.init.updated:
self._initialized = False
self._initialized_for_pipeline.clear()
return self._initialized

@initialized.setter
Expand Down Expand Up @@ -786,7 +787,7 @@ def initialize(self):
self.params = self.init.params
self.init.params = bak
self.initialized = True
self._initialized_for_required_by = []
self._initialized_for_pipeline = []
self._initialization = False
if getattr(self, '_requires', None) is None:
self._requires = []
Expand Down
1 change: 1 addition & 0 deletions desilike/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_wrapper(func):

return get_wrapper(args[0])


def use_jax(array):
"""Whether to use jax.numpy depending on whether array is jax's object."""
return isinstance(array, tuple(array_types))
Expand Down
Loading

0 comments on commit 74a0c43

Please sign in to comment.