diff --git a/compiler_opt/es/combined_blackbox_optimizers.py b/compiler_opt/es/combined_blackbox_optimizers.py
index ec3133d2..7e9d1b59 100644
--- a/compiler_opt/es/combined_blackbox_optimizers.py
+++ b/compiler_opt/es/combined_blackbox_optimizers.py
@@ -36,13 +36,14 @@ import numpy as np
 from sklearn import linear_model
 from typing import List, Dict, Callable, Tuple, Any, Optional
+import numpy.typing as npt
 
 import gradient_ascent_optimization_algorithms
 
 
 def filter_top_directions(
-    perturbations: np.ndarray, function_values: np.ndarray, est_type: str,
-    num_top_directions: int) -> Tuple[np.ndarray, np.ndarray]:
+    perturbations: npt.NDArray, function_values: npt.NDArray, est_type: str,
+    num_top_directions: int) -> Tuple[npt.NDArray, npt.NDArray]:
   """Select the subset of top-performing perturbations.
 
   TODO(b/139662389): In the future, we may want (either here or inside the
@@ -94,8 +95,8 @@ class BlackboxOptimizer(metaclass=abc.ABCMeta):
   """
 
   @abc.abstractmethod
-  def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
-               current_input: np.ndarray, current_value: float) -> np.ndarray:
+  def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
+               current_input: npt.NDArray, current_value: float) -> npt.NDArray:
     """Conducts a single step of blackbox optimization procedure.
 
     Conducts a single step of blackbox optimization procedure, given values of
@@ -141,7 +142,7 @@ def get_state(self) -> List:
     raise NotImplementedError("Abstract method")
 
   @abc.abstractmethod
-  def update_state(self, evaluation_stats: List | np.ndarray) -> None:
+  def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
     """Updates the state for blackbox function runs.
 
     Updates the state of the optimizer for blackbox function runs.
@@ -154,7 +155,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
     raise NotImplementedError("Abstract method")
 
   @abc.abstractmethod
-  def set_state(self, state: List | np.ndarray) -> None:
+  def set_state(self, state: List | npt.NDArray) -> None:
     """Sets up the internal state of the optimizer.
 
     Sets up the internal state of the optimizer.
@@ -204,8 +205,8 @@ def __init__(
     self.ga_optimizer = ga_optimizer
     super().__init__()
 
-  def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
-               current_input: np.ndarray, current_value: float) -> np.ndarray:
+  def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
+               current_input: npt.NDArray, current_value: float) -> npt.NDArray:
     dim = len(current_input)
     if self.normalize_fvalues:
       values = function_values.tolist()
@@ -255,7 +256,7 @@ def get_state(self) -> List:
     else:
       return ga_state
 
-  def update_state(self, evaluation_stats: List | np.ndarray) -> None:
+  def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
     if self.hyperparameters_update_method == "state_normalization":
       self.nb_steps += evaluation_stats[0]
       evaluation_stats = evaluation_stats[1:]
@@ -279,7 +280,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
           for a, b in zip(mean_squares_state_vector, self.mean_state_vector)
       ]
 
-  def set_state(self, state: List | np.ndarray) -> None:
+  def set_state(self, state: List | npt.NDArray) -> None:
     if self.hyperparameters_update_method == "state_normalization":
       self.nb_steps = state[0]
       state = state[1:]
@@ -345,8 +346,8 @@ def set_state(self, state: List | np.ndarray) -> None:
 """
 
 
-def normalize_function_values(function_values: np.ndarray,
-                              current_value: float) -> Tuple[np.ndarray, List]:
+def normalize_function_values(function_values: npt.NDArray,
+                              current_value: float) -> Tuple[npt.NDArray, List]:
   values = function_values.tolist()
   values.append(current_value)
   mean = sum(values) / float(len(values))
@@ -357,10 +358,10 @@ def normalize_function_values(function_values: np.ndarray,
 
 def mc_gradient(precision_parameter: float,
                 est_type: str,
-                perturbations: np.ndarray,
-                function_values: np.ndarray,
+                perturbations: npt.NDArray,
+                function_values: npt.NDArray,
                 current_value: float,
-                energy: float = 0) -> np.ndarray:
+                energy: float = 0) -> npt.NDArray:
   """Calculates Monte Carlo gradient.
 
   There are several ways of estimating the gradient. This is specified by the
@@ -401,9 +402,9 @@ def mc_gradient(precision_parameter: float,
 
 def sklearn_regression_gradient(clf: linear_model, est_type: str,
-                                perturbations: np.ndarray,
-                                function_values: np.ndarray,
-                                current_value: float) -> np.ndarray:
+                                perturbations: npt.NDArray,
+                                function_values: npt.NDArray,
+                                current_value: float) -> npt.NDArray:
   """Calculates gradient by function difference regression.
 
   Args:
@@ -469,7 +470,7 @@ class QuadraticModel(object):
   f(x) = 1/2x^TAx + b^Tx + c
   """
 
-  def __init__(self, Av: Callable, b: np.ndarray, c: float = 0):
+  def __init__(self, Av: Callable, b: npt.NDArray, c: float = 0):
     """Initialize quadratic function.
 
     Args:
@@ -482,7 +483,7 @@ def __init__(self, Av: Callable, b: np.ndarray, c: float = 0):
     self.b = b
     self.c = c
 
-  def f(self, x: np.ndarray) -> float:
+  def f(self, x: npt.NDArray) -> float:
     """Evaluate the quadratic function.
 
     Args:
@@ -492,7 +493,7 @@ def f(self, x: np.ndarray) -> float:
     """
     return 0.5 * np.dot(x, self.quad_v(x)) + np.dot(x, self.b) + self.c
 
-  def grad(self, x: np.ndarray) -> np.ndarray:
+  def grad(self, x: npt.NDArray) -> npt.NDArray:
     """Evaluate the gradient of the quadratic, Ax + b.
     Args:
@@ -517,7 +518,7 @@ class ProjectedGradientOptimizer(object):
   def __init__(self, function_object: QuadraticModel,
                projection_operator: Callable, pgd_params: Dict[str, Any],
-               x_init: np.ndarray[float]):
+               x_init: npt.NDArray[np.float32]):
     self.f = function_object
     self.proj = projection_operator
     if pgd_params is not None:
@@ -560,7 +561,7 @@ def run_step(self) -> None:
     self.x = x_next
     self.k += 1
 
-  def get_solution(self) -> np.ndarray:
+  def get_solution(self) -> npt.NDArray:
     return self.x
 
   def get_x_diff_norm(self) -> float:
@@ -608,7 +609,7 @@ class TrustRegionSubproblemOptimizer(object):
   def __init__(self, model_function: QuadraticModel,
                trust_region_params: Dict[str, Any],
-               x_init: Optional[np.ndarray[float]] = None):
+               x_init: Optional[npt.NDArray[np.float32]] = None):
     self.mf = model_function
     self.params = trust_region_params
     self.center = x_init
@@ -643,7 +644,7 @@ def solve_trust_region_subproblem(self) -> None:
       self.x = pgd_solver.get_solution()
 
-  def get_solution(self) -> np.ndarray:
+  def get_solution(self) -> npt.NDArray:
     return self.x
 
@@ -779,7 +780,7 @@ def __init__(self, precision_parameter: float, est_type: str,
       self.clf = linear_model.Lasso(alpha=self.params['grad_reg_alpha'])
     self.is_returned_step = False
 
-  def trust_region_test(self, current_input: np.ndarray,
+  def trust_region_test(self, current_input: npt.NDArray,
                         current_value: float) -> bool:
     """Test the next step to determine how to update the trust region.
 
@@ -848,8 +849,8 @@ def trust_region_test(self, current_input: np.ndarray,
       print('Unchanged: ' + str(self.radius) + log_message)
       return True
 
-  def update_hessian_part(self, perturbations: np.ndarray,
-                          function_values: np.ndarray, current_value: float,
+  def update_hessian_part(self, perturbations: npt.NDArray,
+                          function_values: npt.NDArray, current_value: float,
                           is_update: bool) -> None:
     """Updates the internal state which stores Hessian information.
 
@@ -918,7 +919,7 @@ def create_hessv_function(self) -> Callable:
     """
     if self.params['dense_hessian']:
 
-      def hessv_func(x: np.ndarray) -> np.ndarray:
+      def hessv_func(x: npt.NDArray) -> npt.NDArray:
         """Calculates Hessian-vector product from dense Hessian.
 
         Args:
@@ -934,7 +935,7 @@ def hessv_func(x: np.ndarray) -> np.ndarray:
         return hessv
     else:
 
-      def hessv_func(x: np.ndarray) -> np.ndarray:
+      def hessv_func(x: npt.NDArray) -> npt.NDArray:
         """Calculates Hessian-vector product from perturbation/value pairs.
 
         Args:
@@ -961,8 +962,8 @@ def hessv_func(x: np.ndarray) -> np.ndarray:
     return hessv_func
 
-  def update_quadratic_model(self, perturbations: np.ndarray,
-                             function_values: np.ndarray, current_value: float,
+  def update_quadratic_model(self, perturbations: npt.NDArray,
+                             function_values: npt.NDArray, current_value: float,
                              is_update: bool) -> QuadraticModel:
     """Updates the internal state of the optimizer with new perturbations.
 
@@ -1009,8 +1010,8 @@ def update_quadratic_model(self, perturbations: np.ndarray,
                                is_update)
     return QuadraticModel(self.create_hessv_function(), self.saved_gradient)
 
-  def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
-               current_input: np.ndarray, current_value: float) -> np.ndarray:
+  def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
+               current_input: npt.NDArray, current_value: float) -> npt.NDArray:
     """Run a single step of trust region optimizer.
     Args:
@@ -1088,7 +1089,7 @@ def get_state(self) -> List[float]:
     else:
       return []
 
-  def update_state(self, evaluation_stats: List | np.ndarray) -> None:
+  def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
     if self.hyperparameters_update_method == 'state_normalization':
       self.nb_steps += evaluation_stats[0]
       evaluation_stats = evaluation_stats[1:]
@@ -1112,7 +1113,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
           for a, b in zip(mean_squares_state_vector, self.mean_state_vector)
       ]
 
-  def set_state(self, state: List | np.ndarray) -> None:
+  def set_state(self, state: List | npt.NDArray) -> None:
     if self.hyperparameters_update_method == 'state_normalization':
       self.nb_steps = state[0]
       state = state[1:]
diff --git a/compiler_opt/es/gradient_ascent_optimization_algorithms.py b/compiler_opt/es/gradient_ascent_optimization_algorithms.py
index afbcef06..a9a0b632 100644
--- a/compiler_opt/es/gradient_ascent_optimization_algorithms.py
+++ b/compiler_opt/es/gradient_ascent_optimization_algorithms.py
@@ -30,6 +30,7 @@ import abc
 import numpy as np
 from typing import List
+import numpy.typing as npt
 
 # TODO(kchoro): Borrow JAXs optimizer library here. Integrated into Blackbox-v2.
@@ -41,8 +42,8 @@ class GAOptimizer(metaclass=abc.ABCMeta):
   """
 
   @abc.abstractmethod
-  def run_step(self, current_input: np.ndarray,
-               gradient: np.ndarray[np.float32]) -> np.ndarray:
+  def run_step(self, current_input: npt.NDArray,
+               gradient: npt.NDArray[np.float32]) -> npt.NDArray:
     """Conducts a single step of gradient ascent optimization.
 
     Conduct a single step of gradient ascent optimization procedure, given the
@@ -71,7 +72,7 @@ def get_state(self) -> List[np.float32]:
     raise NotImplementedError("Abstract method")
 
   @abc.abstractmethod
-  def set_state(self, state: np.ndarray[np.float32]) -> None:
+  def set_state(self, state: npt.NDArray[np.float32]) -> None:
     """Sets up the internal state of the optimizer.
 
     Sets up the internal state of the optimizer.
@@ -98,8 +99,8 @@ def __init__(self, step_size: float, momentum: float):
     self.moving_average = np.asarray([], dtype=np.float32)
     super().__init__()
 
-  def run_step(self, current_input: np.ndarray,
-               gradient: np.ndarray[np.float32]) -> np.ndarray:
+  def run_step(self, current_input: npt.NDArray,
+               gradient: npt.NDArray[np.float32]) -> npt.NDArray:
     if self.moving_average.size == 0:
       # Initialize the moving average
       self.moving_average = np.zeros(len(current_input), dtype=np.float32)
@@ -119,7 +120,7 @@ def run_step(self, current_input: np.ndarray,
   def get_state(self) -> List[np.float32]:
     return self.moving_average.tolist()
 
-  def set_state(self, state: np.ndarray[np.float32]) -> None:
+  def set_state(self, state: npt.NDArray[np.float32]) -> None:
     self.moving_average = np.asarray(state, dtype=np.float32)
 
@@ -141,8 +142,8 @@ def __init__(self,
     self.t = 0
     super().__init__()
 
-  def run_step(self, current_input: np.ndarray,
-               gradient: np.ndarray[np.float32]) -> np.ndarray:
+  def run_step(self, current_input: npt.NDArray,
+               gradient: npt.NDArray[np.float32]) -> npt.NDArray:
     if self.first_moment_moving_average.size == 0:
       # Initialize the moving averages
       self.first_moment_moving_average = np.zeros(
@@ -177,7 +178,7 @@ def get_state(self) -> List[float]:
     return (self.first_moment_moving_average.tolist() +
             self.second_moment_moving_average.tolist() + [self.t])
 
-  def set_state(self, state: np.ndarray[np.float32]) -> None:
+  def set_state(self, state: npt.NDArray[np.float32]) -> None:
     total_len = len(state)
     if total_len % 2 != 1:
       raise ValueError("The dimension of the state should be odd")
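For reference, a minimal sketch (not part of the patch) of the numpy.typing annotation
style the diff adopts. The scale helper below is hypothetical and only illustrates why
the annotations move from np.ndarray and np.ndarray[np.float32] to npt.NDArray and
npt.NDArray[np.float32]:

import numpy as np
import numpy.typing as npt

def scale(v: npt.NDArray[np.float32], alpha: float) -> npt.NDArray[np.float32]:
  # npt.NDArray[np.float32] expands to np.ndarray[Any, np.dtype[np.float32]],
  # which static type checkers understand; subscripting np.ndarray directly
  # with a scalar type, as in np.ndarray[np.float32], is not a supported
  # parameterization for type checking.
  return (alpha * v).astype(np.float32)

x = scale(np.ones(3, dtype=np.float32), 2.0)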