From de4037b2c2116654b649acf15bf9674376cc61c2 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Mon, 14 Oct 2024 22:51:31 +0200 Subject: [PATCH 1/3] docs(tracers): add types to simple tracers (#1245) --- openfisca_core/populations/_errors.py | 7 +- openfisca_core/populations/types.py | 6 +- openfisca_core/tracers/__init__.py | 28 ++-- openfisca_core/tracers/computation_log.py | 66 ++++---- openfisca_core/tracers/flat_trace.py | 50 +++--- openfisca_core/tracers/full_tracer.py | 189 +++++++++++----------- openfisca_core/tracers/simple_tracer.py | 66 +++++--- openfisca_core/tracers/trace_node.py | 117 +++++++++++--- openfisca_core/types.py | 3 + 9 files changed, 319 insertions(+), 213 deletions(-) diff --git a/openfisca_core/populations/_errors.py b/openfisca_core/populations/_errors.py index 77e378663..0aad0d11d 100644 --- a/openfisca_core/populations/_errors.py +++ b/openfisca_core/populations/_errors.py @@ -57,4 +57,9 @@ def __init__( super().__init__(msg) -__all__ = ["InvalidArraySizeError", "PeriodValidityError"] +__all__ = [ + "IncompatibleOptionsError", + "InvalidArraySizeError", + "InvalidOptionError", + "PeriodValidityError", +] diff --git a/openfisca_core/populations/types.py b/openfisca_core/populations/types.py index e6b14c5a8..07f34d2f5 100644 --- a/openfisca_core/populations/types.py +++ b/openfisca_core/populations/types.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, MutableMapping, Sequence from typing import NamedTuple, Union -from typing_extensions import NewType, TypeAlias, TypedDict +from typing_extensions import TypeAlias, TypedDict from openfisca_core.types import ( Array, @@ -14,6 +14,7 @@ Holder, MemoryUsage, Period, + PeriodInt, PeriodStr, Role, Simulation, @@ -52,9 +53,6 @@ # Periods -#: New type for a period integer. -PeriodInt = NewType("PeriodInt", int) - #: Type alias for a period-like object. 
PeriodLike: TypeAlias = Union[Period, PeriodStr, PeriodInt] diff --git a/openfisca_core/tracers/__init__.py b/openfisca_core/tracers/__init__.py index e59d0122a..76e36b55c 100644 --- a/openfisca_core/tracers/__init__.py +++ b/openfisca_core/tracers/__init__.py @@ -21,12 +21,22 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from .computation_log import ComputationLog # noqa: F401 -from .flat_trace import FlatTrace # noqa: F401 -from .full_tracer import FullTracer # noqa: F401 -from .performance_log import PerformanceLog # noqa: F401 -from .simple_tracer import SimpleTracer # noqa: F401 -from .trace_node import TraceNode # noqa: F401 -from .tracing_parameter_node_at_instant import ( # noqa: F401 - TracingParameterNodeAtInstant, -) +from . import types +from .computation_log import ComputationLog +from .flat_trace import FlatTrace +from .full_tracer import FullTracer +from .performance_log import PerformanceLog +from .simple_tracer import SimpleTracer +from .trace_node import TraceNode +from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant + +__all__ = [ + "ComputationLog", + "FlatTrace", + "FullTracer", + "PerformanceLog", + "SimpleTracer", + "TraceNode", + "TracingParameterNodeAtInstant", + "types", +] diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index 6310eb884..9fcc09e25 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -1,39 +1,24 @@ from __future__ import annotations -import typing -from typing import Union +import sys import numpy from openfisca_core.indexed_enums import EnumArray -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core import tracers - - Array = Union[EnumArray, ArrayLike] +from . 
import types as t class ComputationLog: - _full_tracer: tracers.FullTracer + _full_tracer: t.FullTracer - def __init__(self, full_tracer: tracers.FullTracer) -> None: + def __init__(self, full_tracer: t.FullTracer) -> None: self._full_tracer = full_tracer - def display( - self, - value: Array | None, - ) -> str: - if isinstance(value, EnumArray): - value = value.decode_to_str() - - return numpy.array2string(value, max_line_width=float("inf")) - def lines( self, aggregate: bool = False, - max_depth: int | None = None, + max_depth: int = sys.maxsize, ) -> list[str]: depth = 1 @@ -44,7 +29,7 @@ def lines( return self._flatten(lines_by_tree) - def print_log(self, aggregate=False, max_depth=None) -> None: + def print_log(self, aggregate: bool = False, max_depth: int = sys.maxsize) -> None: """Print the computation log of a simulation. If ``aggregate`` is ``False`` (default), print the value of each @@ -60,20 +45,20 @@ def print_log(self, aggregate=False, max_depth=None) -> None: If ``max_depth`` is set, for example to ``3``, only print computed vectors up to a depth of ``max_depth``. 
""" - for _line in self.lines(aggregate, max_depth): + for _ in self.lines(aggregate, max_depth): pass def _get_node_log( self, - node: tracers.TraceNode, + node: t.TraceNode, depth: int, aggregate: bool, - max_depth: int | None, + max_depth: int = sys.maxsize, ) -> list[str]: - if max_depth is not None and depth > max_depth: + if depth > max_depth: return [] - node_log = [self._print_line(depth, node, aggregate, max_depth)] + node_log = [self._print_line(depth, node, aggregate)] children_logs = [ self._get_node_log(child, depth + 1, aggregate, max_depth) @@ -82,13 +67,7 @@ def _get_node_log( return node_log + self._flatten(children_logs) - def _print_line( - self, - depth: int, - node: tracers.TraceNode, - aggregate: bool, - max_depth: int | None, - ) -> str: + def _print_line(self, depth: int, node: t.TraceNode, aggregate: bool) -> str: indent = " " * depth value = node.value @@ -97,9 +76,11 @@ def _print_line( elif aggregate: try: - formatted_value = str( + formatted_value = str( # pyright: ignore[reportCallIssue] { - "avg": numpy.mean(value), + "avg": numpy.mean( + value + ), # pyright: ignore[reportArgumentType,reportCallIssue] "max": numpy.max(value), "min": numpy.min(value), }, @@ -113,8 +94,15 @@ def _print_line( return f"{indent}{node.name}<{node.period}> >> {formatted_value}" - def _flatten( - self, - lists: list[list[str]], - ) -> list[str]: + @staticmethod + def display(value: t.VarArray, max_depth: int = sys.maxsize) -> str: + if isinstance(value, EnumArray): + value = value.decode_to_str() + return numpy.array2string(value, max_line_width=max_depth) + + @staticmethod + def _flatten(lists: list[list[str]]) -> list[str]: return [item for list_ in lists for item in list_] + + +__all__ = ["ComputationLog"] diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index 2090d537b..aea9288e3 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -1,34 +1,20 @@ from __future__ import 
annotations -import typing -from typing import Union - import numpy from openfisca_core.indexed_enums import EnumArray -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core import tracers - - Array = Union[EnumArray, ArrayLike] - Trace = dict[str, dict] +from . import types as t class FlatTrace: - _full_tracer: tracers.FullTracer + _full_tracer: t.FullTracer - def __init__(self, full_tracer: tracers.FullTracer) -> None: + def __init__(self, full_tracer: t.FullTracer) -> None: self._full_tracer = full_tracer - def key(self, node: tracers.TraceNode) -> str: - name = node.name - period = node.period - return f"{name}<{period}>" - - def get_trace(self) -> dict: - trace = {} + def get_trace(self) -> t.FlatNodeMap: + trace: t.FlatNodeMap = {} for node in self._full_tracer.browse_trace(): # We don't want cache read to overwrite data about the initial @@ -45,7 +31,7 @@ def get_trace(self) -> dict: return trace - def get_serialized_trace(self) -> dict: + def get_serialized_trace(self) -> t.SerializedNodeMap: return { key: {**flat_trace, "value": self.serialize(flat_trace["value"])} for key, flat_trace in self.get_trace().items() @@ -53,26 +39,29 @@ def get_serialized_trace(self) -> dict: def serialize( self, - value: Array | None, - ) -> Array | None | list: + value: None | t.VarArray | t.ArrayLike[object], + ) -> None | t.ArrayLike[object]: + if value is None: + return None + if isinstance(value, EnumArray): - value = value.decode_to_str() + return value.decode_to_str().tolist() if isinstance(value, numpy.ndarray) and numpy.issubdtype( value.dtype, numpy.dtype(bytes), ): - value = value.astype(numpy.dtype(str)) + return value.astype(numpy.dtype(str)).tolist() if isinstance(value, numpy.ndarray): - value = value.tolist() + return value.tolist() return value def _get_flat_trace( self, - node: tracers.TraceNode, - ) -> Trace: + node: t.TraceNode, + ) -> t.FlatNodeMap: key = self.key(node) return { @@ -87,3 +76,10 @@ def _get_flat_trace( 
"formula_time": node.formula_time(), }, } + + @staticmethod + def key(node: t.TraceNode) -> t.NodeKey: + """Return the key of a node.""" + name = node.name + period = node.period + return t.NodeKey(f"{name}<{period}>") diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 9fa94d5ab..56109cc9f 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -1,46 +1,118 @@ from __future__ import annotations -import typing -from typing import Union +from collections.abc import Iterator +import sys import time -from openfisca_core import tracers - -if typing.TYPE_CHECKING: - from collections.abc import Iterator - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period - - Stack = list[dict[str, Union[str, Period]]] +from . import types as t +from .computation_log import ComputationLog +from .flat_trace import FlatTrace +from .performance_log import PerformanceLog +from .simple_tracer import SimpleTracer +from .trace_node import TraceNode class FullTracer: - _simple_tracer: tracers.SimpleTracer - _trees: list - _current_node: tracers.TraceNode | None + _simple_tracer: t.SimpleTracer + _trees: list[t.TraceNode] + _current_node: None | t.TraceNode def __init__(self) -> None: - self._simple_tracer = tracers.SimpleTracer() + self._simple_tracer = SimpleTracer() self._trees = [] self._current_node = None + @property + def stack(self) -> t.SimpleStack: + """Return the stack of traces.""" + return self._simple_tracer.stack + + @property + def trees(self) -> list[t.TraceNode]: + """Return the tree of traces.""" + return self._trees + + @property + def computation_log(self) -> t.ComputationLog: + """Return the computation log.""" + return ComputationLog(self) + + @property + def performance_log(self) -> t.PerformanceLog: + """Return the performance log.""" + return PerformanceLog(self) + + @property + def flat_trace(self) -> t.FlatTrace: + """Return the flat trace.""" + return 
FlatTrace(self) + def record_calculation_start( self, - variable: str, - period: Period | int, + variable: t.VariableName, + period: t.PeriodInt | t.Period, ) -> None: self._simple_tracer.record_calculation_start(variable, period) self._enter_calculation(variable, period) self._record_start_time() + def record_parameter_access( + self, + parameter: str, + period: t.Period, + value: t.VarArray, + ) -> None: + if self._current_node is not None: + self._current_node.parameters.append( + TraceNode(name=parameter, period=period, value=value), + ) + + def record_calculation_result(self, value: t.VarArray) -> None: + if self._current_node is not None: + self._current_node.value = value + + def record_calculation_end(self) -> None: + self._simple_tracer.record_calculation_end() + self._record_end_time() + self._exit_calculation() + + def print_computation_log( + self, aggregate: bool = False, max_depth: int = sys.maxsize + ) -> None: + self.computation_log.print_log(aggregate, max_depth) + + def generate_performance_graph(self, dir_path: str) -> None: + self.performance_log.generate_graph(dir_path) + + def generate_performance_tables(self, dir_path: str) -> None: + self.performance_log.generate_performance_tables(dir_path) + + def get_nb_requests(self, variable: str) -> int: + return sum(self._get_nb_requests(tree, variable) for tree in self.trees) + + def get_flat_trace(self) -> dict: + return self.flat_trace.get_trace() + + def get_serialized_flat_trace(self) -> dict: + return self.flat_trace.get_serialized_trace() + + def browse_trace(self) -> Iterator[t.TraceNode]: + def _browse_node(node: t.TraceNode) -> Iterator[t.TraceNode]: + yield node + + for child in node.children: + yield from _browse_node(child) + + for node in self._trees: + yield from _browse_node(node) + def _enter_calculation( self, - variable: str, - period: Period, + variable: t.VariableName, + period: t.PeriodInt | t.Period, ) -> None: - new_node = tracers.TraceNode( + new_node = TraceNode( 
name=variable, period=period, parent=self._current_node, @@ -54,17 +126,6 @@ def _enter_calculation( self._current_node = new_node - def record_parameter_access( - self, - parameter: str, - period: Period, - value: ArrayLike, - ) -> None: - if self._current_node is not None: - self._current_node.parameters.append( - tracers.TraceNode(name=parameter, period=period, value=value), - ) - def _record_start_time( self, time_in_s: float | None = None, @@ -75,18 +136,9 @@ def _record_start_time( if self._current_node is not None: self._current_node.start = time_in_s - def record_calculation_result(self, value: ArrayLike) -> None: - if self._current_node is not None: - self._current_node.value = value - - def record_calculation_end(self) -> None: - self._simple_tracer.record_calculation_end() - self._record_end_time() - self._exit_calculation() - def _record_end_time( self, - time_in_s: float | None = None, + time_in_s: None | t.Time = None, ) -> None: if time_in_s is None: time_in_s = self._get_time_in_sec() @@ -98,39 +150,7 @@ def _exit_calculation(self) -> None: if self._current_node is not None: self._current_node = self._current_node.parent - @property - def stack(self) -> Stack: - return self._simple_tracer.stack - - @property - def trees(self) -> list[tracers.TraceNode]: - return self._trees - - @property - def computation_log(self) -> tracers.ComputationLog: - return tracers.ComputationLog(self) - - @property - def performance_log(self) -> tracers.PerformanceLog: - return tracers.PerformanceLog(self) - - @property - def flat_trace(self) -> tracers.FlatTrace: - return tracers.FlatTrace(self) - - def _get_time_in_sec(self) -> float: - return time.time_ns() / (10**9) - - def print_computation_log(self, aggregate=False, max_depth=None) -> None: - self.computation_log.print_log(aggregate, max_depth) - - def generate_performance_graph(self, dir_path: str) -> None: - self.performance_log.generate_graph(dir_path) - - def generate_performance_tables(self, dir_path: str) -> 
None: - self.performance_log.generate_performance_tables(dir_path) - - def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: + def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int: tree_call = tree.name == variable children_calls = sum( self._get_nb_requests(child, variable) for child in tree.children @@ -138,21 +158,6 @@ def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: return tree_call + children_calls - def get_nb_requests(self, variable: str) -> int: - return sum(self._get_nb_requests(tree, variable) for tree in self.trees) - - def get_flat_trace(self) -> dict: - return self.flat_trace.get_trace() - - def get_serialized_flat_trace(self) -> dict: - return self.flat_trace.get_serialized_trace() - - def browse_trace(self) -> Iterator[tracers.TraceNode]: - def _browse_node(node): - yield node - - for child in node.children: - yield from _browse_node(child) - - for node in self._trees: - yield from _browse_node(node) + @staticmethod + def _get_time_in_sec() -> t.Time: + return time.time_ns() / (10**9) diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 84328730e..174dd3119 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,34 +1,64 @@ from __future__ import annotations -import typing -from typing import Union - -if typing.TYPE_CHECKING: - from numpy.typing import ArrayLike - - from openfisca_core.periods import Period - - Stack = list[dict[str, Union[str, Period]]] +from . import types as t class SimpleTracer: - _stack: Stack + """A simple tracer that records a stack of traces.""" + + #: The stack of traces. 
+ _stack: t.SimpleStack def __init__(self) -> None: self._stack = [] - def record_calculation_start(self, variable: str, period: Period | int) -> None: + @property + def stack(self) -> t.SimpleStack: + """Return the stack of traces.""" + return self._stack + + def record_calculation_start( + self, variable: t.VariableName, period: t.PeriodInt | t.Period + ) -> None: + """Record the start of a calculation. + + Args: + variable: The variable being calculated. + period: The period for which the variable is being calculated. + + Examples: + >>> from openfisca_core import tracers + + >>> tracer = tracers.SimpleTracer() + >>> tracer.record_calculation_start("variable", 2020) + >>> tracer.stack + [{'name': 'variable', 'period': 2020}] + + """ self.stack.append({"name": variable, "period": period}) - def record_calculation_result(self, value: ArrayLike) -> None: - pass # ignore calculation result + def record_calculation_result(self, value: t.ArrayLike[object]) -> None: + """Ignore calculation result.""" - def record_parameter_access(self, parameter: str, period, value) -> None: - pass + def record_parameter_access( + self, parameter: str, period: t.Period, value: t.ArrayLike[object] + ) -> None: + """Ignore parameter access.""" def record_calculation_end(self) -> None: + """Record the end of a calculation. 
+ + Examples: + >>> from openfisca_core import tracers + + >>> tracer = tracers.SimpleTracer() + >>> tracer.record_calculation_start("variable", 2020) + >>> tracer.record_calculation_end() + >>> tracer.stack + [] + + """ self.stack.pop() - @property - def stack(self) -> Stack: - return self._stack + +__all__ = ["SimpleTracer"] diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index ff55a5714..de81825e8 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -1,31 +1,61 @@ from __future__ import annotations -import typing - import dataclasses -if typing.TYPE_CHECKING: - import numpy - - from openfisca_core.indexed_enums import EnumArray - from openfisca_core.periods import Period - - Array = typing.Union[EnumArray, numpy.typing.ArrayLike] - Time = typing.Union[float, int] +from . import types as t @dataclasses.dataclass class TraceNode: + """A node in the tracing tree.""" + + #: The name of the node. name: str - period: Period - parent: TraceNode | None = None - children: list[TraceNode] = dataclasses.field(default_factory=list) - parameters: list[TraceNode] = dataclasses.field(default_factory=list) - value: Array | None = None - start: float = 0 - end: float = 0 - - def calculation_time(self, round_: bool = True) -> Time: + + #: The period of the node. + period: t.PeriodInt | t.Period + + #: The parent of the node. + parent: None | t.TraceNode = None + + #: The children of the node. + children: list[t.TraceNode] = dataclasses.field(default_factory=list) + + #: The parameters of the node. + parameters: list[t.TraceNode] = dataclasses.field(default_factory=list) + + #: The value of the node. + value: None | t.VarArray = None + + #: The start time of the node. + start: t.Time = 0.0 + + #: The end time of the node. + end: t.Time = 0.0 + + def calculation_time(self, round_: bool = True) -> t.Time: + """Calculate the time spent in the node. + + Args: + round_: Whether to round the result. 
+ + Returns: + float: The time spent in the node. + + Examples: + >>> from openfisca_core import tracers + + >>> node = tracers.TraceNode("variable", 2020) + >>> node.start = 1.123122313 + >>> node.end = 1.12312313123 + + >>> node.calculation_time() + 8.182e-07 + + >>> node.calculation_time(round_=False) + 8.182299999770493e-07 + + """ result = self.end - self.start if round_: @@ -33,7 +63,29 @@ def calculation_time(self, round_: bool = True) -> Time: return result - def formula_time(self) -> float: + def formula_time(self) -> t.Time: + """Calculate the time spent on the formula. + + Returns: + float: The time spent on the formula. + + Examples: + >>> from openfisca_core import tracers + + >>> node = tracers.TraceNode("variable", 2020) + >>> node.start = 1.123122313 * 11 + >>> node.end = 1.12312313123 * 11 + >>> child = tracers.TraceNode("variable", 2020) + >>> child.start = 1.123122313 + >>> child.end = 1.12312313123 + + >>> for i in range(10): + ... node.children = [child, *node.children] + + >>> node.formula_time() + 8.182e-07 + + """ children_calculation_time = sum( child.calculation_time(round_=False) for child in self.children ) @@ -42,9 +94,28 @@ def formula_time(self) -> float: return self.round(result) - def append_child(self, node: TraceNode) -> None: + def append_child(self, node: t.TraceNode) -> None: + """Append a child to the node.""" self.children.append(node) @staticmethod - def round(time: Time) -> float: - return float(f"{time:.4g}") # Keep only 4 significant figures + def round(time: t.Time) -> t.Time: + """Keep only 4 significant figures. + + Args: + time: The time to round. + + Returns: + float: The rounded time. 
+ + Examples: + >>> from openfisca_core import tracers + + >>> tracers.TraceNode.round(0.000123456789) + 0.0001235 + + """ + return float(f"{time:.4g}") + + +__all__ = ["TraceNode"] diff --git a/openfisca_core/types.py b/openfisca_core/types.py index 85adf7b51..9fb94d1f5 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -157,6 +157,9 @@ class ParameterNodeAtInstant(Protocol): ... #: For example "2000-01". InstantStr = NewType("InstantStr", str) +#: For example 2020. +PeriodInt = NewType("PeriodInt", int) + #: For example "1:2000-01-01:day". PeriodStr = NewType("PeriodStr", str) From 4b8f00ed5e2b95a2be6e911333f17994c6c1628a Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Mon, 14 Oct 2024 23:27:43 +0200 Subject: [PATCH 2/3] docs(tracers): add types to full tracer (#1245) --- openfisca_core/tracers/flat_trace.py | 45 ++++++----- openfisca_core/tracers/full_tracer.py | 7 +- openfisca_core/tracers/types.py | 108 ++++++++++++++++++++++++++ openfisca_core/types.py | 25 +++++- 4 files changed, 161 insertions(+), 24 deletions(-) create mode 100644 openfisca_core/tracers/types.py diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index aea9288e3..412ac8b02 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -37,27 +37,6 @@ def get_serialized_trace(self) -> t.SerializedNodeMap: for key, flat_trace in self.get_trace().items() } - def serialize( - self, - value: None | t.VarArray | t.ArrayLike[object], - ) -> None | t.ArrayLike[object]: - if value is None: - return None - - if isinstance(value, EnumArray): - return value.decode_to_str().tolist() - - if isinstance(value, numpy.ndarray) and numpy.issubdtype( - value.dtype, - numpy.dtype(bytes), - ): - return value.astype(numpy.dtype(str)).tolist() - - if isinstance(value, numpy.ndarray): - return value.tolist() - - return value - def _get_flat_trace( self, node: t.TraceNode, @@ -83,3 +62,27 @@ def key(node: t.TraceNode) -> 
t.NodeKey: name = node.name period = node.period return t.NodeKey(f"{name}<{period}>") + + @staticmethod + def serialize( + value: None | t.VarArray | t.ArrayLike[object], + ) -> None | t.ArrayLike[object]: + if value is None: + return None + + if isinstance(value, EnumArray): + return value.decode_to_str().tolist() + + if isinstance(value, numpy.ndarray) and numpy.issubdtype( + value.dtype, + numpy.dtype(bytes), + ): + return value.astype(numpy.dtype(str)).tolist() + + if isinstance(value, numpy.ndarray): + return value.tolist() + + return value + + +__all__ = ["FlatTrace"] diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index 56109cc9f..f6f793e19 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -91,10 +91,10 @@ def generate_performance_tables(self, dir_path: str) -> None: def get_nb_requests(self, variable: str) -> int: return sum(self._get_nb_requests(tree, variable) for tree in self.trees) - def get_flat_trace(self) -> dict: + def get_flat_trace(self) -> t.FlatNodeMap: return self.flat_trace.get_trace() - def get_serialized_flat_trace(self) -> dict: + def get_serialized_flat_trace(self) -> t.SerializedNodeMap: return self.flat_trace.get_serialized_trace() def browse_trace(self) -> Iterator[t.TraceNode]: @@ -161,3 +161,6 @@ def _get_nb_requests(self, tree: t.TraceNode, variable: str) -> int: @staticmethod def _get_time_in_sec() -> t.Time: return time.time_ns() / (10**9) + + +__all__ = ["FullTracer"] diff --git a/openfisca_core/tracers/types.py b/openfisca_core/tracers/types.py new file mode 100644 index 000000000..f26c85424 --- /dev/null +++ b/openfisca_core/tracers/types.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import NewType, Protocol +from typing_extensions import TypeAlias, TypedDict + +from openfisca_core.types import ( + Array, + ArrayLike, + ParameterNode, + ParameterNodeChild, + Period, + 
PeriodInt, + VariableName, +) + +from numpy import generic as VarDType + +#: A type of a generic array. +VarArray: TypeAlias = Array[VarDType] + +#: A type representing a unit time. +Time: TypeAlias = float + +#: A type representing a mapping of flat traces. +FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"] + +#: A type representing a mapping of serialized traces. +SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"] + +#: A stack of simple traces. +SimpleStack: TypeAlias = list["SimpleTraceMap"] + +#: Key of a trace. +NodeKey = NewType("NodeKey", str) + + +class FlatTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | VarArray + calculation_time: Time + formula_time: Time + + +class SerializedTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | ArrayLike[object] + calculation_time: Time + formula_time: Time + + +class SimpleTraceMap(TypedDict, total=True): + name: VariableName + period: int | Period + + +class ComputationLog(Protocol): + def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ... + + +class FlatTrace(Protocol): + def get_trace(self, /) -> FlatNodeMap: ... + def get_serialized_trace(self, /) -> SerializedNodeMap: ... + + +class FullTracer(Protocol): + @property + def trees(self, /) -> list[TraceNode]: ... + def browse_trace(self, /) -> Iterator[TraceNode]: ... + + +class PerformanceLog(Protocol): + def generate_graph(self, dir_path: str, /) -> None: ... + def generate_performance_tables(self, dir_path: str, /) -> None: ... + + +class SimpleTracer(Protocol): + @property + def stack(self, /) -> SimpleStack: ... + def record_calculation_start( + self, variable: VariableName, period: PeriodInt | Period, / + ) -> None: ... + def record_calculation_end(self, /) -> None: ... 
+ + +class TraceNode(Protocol): + children: list[TraceNode] + end: Time + name: str + parameters: list[TraceNode] + parent: None | TraceNode + period: PeriodInt | Period + start: Time + value: None | VarArray + + def calculation_time(self, *, round_: bool = ...) -> Time: ... + def formula_time(self, /) -> Time: ... + def append_child(self, node: TraceNode, /) -> None: ... + + +__all__ = [ + "ArrayLike", + "ParameterNode", + "ParameterNodeChild", + "PeriodInt", +] diff --git a/openfisca_core/types.py b/openfisca_core/types.py index 9fb94d1f5..9c8105741 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -148,8 +148,31 @@ class MemoryUsage(TypedDict, total=False): # Parameters +#: A type representing a node of parameters. +ParameterNode: TypeAlias = Union[ + "ParameterNodeAtInstant", "VectorialParameterNodeAtInstant" +] -class ParameterNodeAtInstant(Protocol): ... +#: A type representing a child of a parameter node (a sub-node or an array of values). +ParameterNodeChild: TypeAlias = Union[ParameterNode, ArrayLike[object]] + + +class ParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, __item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... + + +class VectorialParameterNodeAtInstant(Protocol): + _instant_str: InstantStr + + def __contains__(self, item: object, /) -> bool: ... + def __getitem__( + self, __index: str | Array[DTypeGeneric], / + ) -> ParameterNodeChild: ... 
# Periods From d859a62b8ee0807cca3d9a2e743b8666188f8e70 Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Mon, 14 Oct 2024 23:36:39 +0200 Subject: [PATCH 3/3] build: version bump --- CHANGELOG.md | 6 ++++++ setup.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bed82be51..43d8d9d6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +### 43.2.2 [#1280](https://github.com/openfisca/openfisca-core/pull/1280) + +#### Documentation + +- Add types to common tracers (`SimpleTracer`, `FlatTrace`, etc.) + ### 43.2.1 [#1283](https://github.com/openfisca/openfisca-core/pull/1283) #### Technical changes diff --git a/setup.py b/setup.py index 43a4c49d1..40c16dbff 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ setup( name="OpenFisca-Core", - version="43.2.1", + version="43.2.2", author="OpenFisca Team", author_email="contact@openfisca.org", classifiers=[