Skip to content

Commit

Permalink
Replace np.ndarray with npt.NDArray in type annotations to supoprt py…
Browse files Browse the repository at this point in the history
…thon versions before 3.9
  • Loading branch information
salaast committed Jun 7, 2023
1 parent ffba5c4 commit 722111b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 45 deletions.
73 changes: 37 additions & 36 deletions compiler_opt/es/combined_blackbox_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@
import numpy as np
from sklearn import linear_model
from typing import List, Dict, Callable, Tuple, Any, Optional
import numpy.typing as npt

import gradient_ascent_optimization_algorithms


def filter_top_directions(
perturbations: np.ndarray, function_values: np.ndarray, est_type: str,
num_top_directions: int) -> Tuple[np.ndarray, np.ndarray]:
perturbations: npt.NDArray, function_values: npt.NDArray, est_type: str,
num_top_directions: int) -> Tuple[npt.NDArray, npt.NDArray]:
"""Select the subset of top-performing perturbations.
TODO(b/139662389): In the future, we may want (either here or inside the
Expand Down Expand Up @@ -94,8 +95,8 @@ class BlackboxOptimizer(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
current_input: np.ndarray, current_value: float) -> np.ndarray:
def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
current_input: npt.NDArray, current_value: float) -> npt.NDArray:
"""Conducts a single step of blackbox optimization procedure.
Conducts a single step of blackbox optimization procedure, given values of
Expand Down Expand Up @@ -141,7 +142,7 @@ def get_state(self) -> List:
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def update_state(self, evaluation_stats: List | np.ndarray) -> None:
def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
"""Updates the state for blackbox function runs.
Updates the state of the optimizer for blackbox function runs.
Expand All @@ -154,7 +155,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def set_state(self, state: List | np.ndarray) -> None:
def set_state(self, state: List | npt.NDArray) -> None:
"""Sets up the internal state of the optimizer.
Sets up the internal state of the optimizer.
Expand Down Expand Up @@ -204,8 +205,8 @@ def __init__(
self.ga_optimizer = ga_optimizer
super().__init__()

def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
current_input: np.ndarray, current_value: float) -> np.ndarray:
def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
current_input: npt.NDArray, current_value: float) -> npt.NDArray:
dim = len(current_input)
if self.normalize_fvalues:
values = function_values.tolist()
Expand Down Expand Up @@ -255,7 +256,7 @@ def get_state(self) -> List:
else:
return ga_state

def update_state(self, evaluation_stats: List | np.ndarray) -> None:
def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
if self.hyperparameters_update_method == "state_normalization":
self.nb_steps += evaluation_stats[0]
evaluation_stats = evaluation_stats[1:]
Expand All @@ -279,7 +280,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
for a, b in zip(mean_squares_state_vector, self.mean_state_vector)
]

def set_state(self, state: List | np.ndarray) -> None:
def set_state(self, state: List | npt.NDArray) -> None:
if self.hyperparameters_update_method == "state_normalization":
self.nb_steps = state[0]
state = state[1:]
Expand Down Expand Up @@ -345,8 +346,8 @@ def set_state(self, state: List | np.ndarray) -> None:
"""


def normalize_function_values(function_values: np.ndarray,
current_value: float) -> Tuple[np.ndarray, List]:
def normalize_function_values(function_values: npt.NDArray,
current_value: float) -> Tuple[npt.NDArray, List]:
values = function_values.tolist()
values.append(current_value)
mean = sum(values) / float(len(values))
Expand All @@ -357,10 +358,10 @@ def normalize_function_values(function_values: np.ndarray,

def mc_gradient(precision_parameter: float,
est_type: str,
perturbations: np.ndarray,
function_values: np.ndarray,
perturbations: npt.NDArray,
function_values: npt.NDArray,
current_value: float,
energy: float = 0) -> np.ndarray:
energy: float = 0) -> npt.NDArray:
"""Calculates Monte Carlo gradient.
There are several ways of estimating the gradient. This is specified by the
Expand Down Expand Up @@ -401,9 +402,9 @@ def mc_gradient(precision_parameter: float,


def sklearn_regression_gradient(clf: linear_model, est_type: str,
perturbations: np.ndarray,
function_values: np.ndarray,
current_value: float) -> np.ndarray:
perturbations: npt.NDArray,
function_values: npt.NDArray,
current_value: float) -> npt.NDArray:
"""Calculates gradient by function difference regression.
Args:
Expand Down Expand Up @@ -469,7 +470,7 @@ class QuadraticModel(object):
f(x) = 1/2x^TAx + b^Tx + c
"""

def __init__(self, Av: Callable, b: np.ndarray, c: float = 0):
def __init__(self, Av: Callable, b: npt.NDArray, c: float = 0):
"""Initialize quadratic function.
Args:
Expand All @@ -482,7 +483,7 @@ def __init__(self, Av: Callable, b: np.ndarray, c: float = 0):
self.b = b
self.c = c

def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
"""Evaluate the quadratic function.
Args:
Expand All @@ -492,7 +493,7 @@ def f(self, x: np.ndarray) -> float:
"""
return 0.5 * np.dot(x, self.quad_v(x)) + np.dot(x, self.b) + self.c

def grad(self, x: np.ndarray) -> np.ndarray:
def grad(self, x: npt.NDArray) -> npt.NDArray:
"""Evaluate the gradient of the quadratic, Ax + b.
Args:
Expand All @@ -517,7 +518,7 @@ class ProjectedGradientOptimizer(object):

def __init__(self, function_object: QuadraticModel,
projection_operator: Callable, pgd_params: Dict[str, Any],
x_init: np.ndarray[float]):
x_init: npt.NDArray[np.float32]):
self.f = function_object
self.proj = projection_operator
if pgd_params is not None:
Expand Down Expand Up @@ -560,7 +561,7 @@ def run_step(self) -> None:
self.x = x_next
self.k += 1

def get_solution(self) -> np.ndarray:
def get_solution(self) -> npt.NDArray:
return self.x

def get_x_diff_norm(self) -> float:
Expand Down Expand Up @@ -608,7 +609,7 @@ class TrustRegionSubproblemOptimizer(object):
def __init__(self,
model_function: QuadraticModel,
trust_region_params: Dict[str, Any],
x_init: Optional[np.ndarray[float]] = None):
x_init: Optional[npt.NDArray[np.float32]] = None):
self.mf = model_function
self.params = trust_region_params
self.center = x_init
Expand Down Expand Up @@ -643,7 +644,7 @@ def solve_trust_region_subproblem(self) -> None:

self.x = pgd_solver.get_solution()

def get_solution(self) -> np.ndarray:
def get_solution(self) -> npt.NDArray:
return self.x


Expand Down Expand Up @@ -779,7 +780,7 @@ def __init__(self, precision_parameter: float, est_type: str,
self.clf = linear_model.Lasso(alpha=self.params['grad_reg_alpha'])
self.is_returned_step = False

def trust_region_test(self, current_input: np.ndarray,
def trust_region_test(self, current_input: npt.NDArray,
current_value: float) -> bool:
"""Test the next step to determine how to update the trust region.
Expand Down Expand Up @@ -848,8 +849,8 @@ def trust_region_test(self, current_input: np.ndarray,
print('Unchanged: ' + str(self.radius) + log_message)
return True

def update_hessian_part(self, perturbations: np.ndarray,
function_values: np.ndarray, current_value: float,
def update_hessian_part(self, perturbations: npt.NDArray,
function_values: npt.NDArray, current_value: float,
is_update: bool) -> None:
"""Updates the internal state which stores Hessian information.
Expand Down Expand Up @@ -918,7 +919,7 @@ def create_hessv_function(self) -> Callable:
"""
if self.params['dense_hessian']:

def hessv_func(x: np.ndarray) -> np.ndarray:
def hessv_func(x: npt.NDArray) -> npt.NDArray:
"""Calculates Hessian-vector product from dense Hessian.
Args:
Expand All @@ -934,7 +935,7 @@ def hessv_func(x: np.ndarray) -> np.ndarray:
return hessv
else:

def hessv_func(x: np.ndarray) -> np.ndarray:
def hessv_func(x: npt.NDArray) -> npt.NDArray:
"""Calculates Hessian-vector product from perturbation/value pairs.
Args:
Expand All @@ -961,8 +962,8 @@ def hessv_func(x: np.ndarray) -> np.ndarray:

return hessv_func

def update_quadratic_model(self, perturbations: np.ndarray,
function_values: np.ndarray, current_value: float,
def update_quadratic_model(self, perturbations: npt.NDArray,
function_values: npt.NDArray, current_value: float,
is_update: bool) -> QuadraticModel:
"""Updates the internal state of the optimizer with new perturbations.
Expand Down Expand Up @@ -1009,8 +1010,8 @@ def update_quadratic_model(self, perturbations: np.ndarray,
is_update)
return QuadraticModel(self.create_hessv_function(), self.saved_gradient)

def run_step(self, perturbations: np.ndarray, function_values: np.ndarray,
current_input: np.ndarray, current_value: float) -> np.ndarray:
def run_step(self, perturbations: npt.NDArray, function_values: npt.NDArray,
current_input: npt.NDArray, current_value: float) -> npt.NDArray:
"""Run a single step of trust region optimizer.
Args:
Expand Down Expand Up @@ -1088,7 +1089,7 @@ def get_state(self) -> List[float]:
else:
return []

def update_state(self, evaluation_stats: List | np.ndarray) -> None:
def update_state(self, evaluation_stats: List | npt.NDArray) -> None:
if self.hyperparameters_update_method == 'state_normalization':
self.nb_steps += evaluation_stats[0]
evaluation_stats = evaluation_stats[1:]
Expand All @@ -1112,7 +1113,7 @@ def update_state(self, evaluation_stats: List | np.ndarray) -> None:
for a, b in zip(mean_squares_state_vector, self.mean_state_vector)
]

def set_state(self, state: List | np.ndarray) -> None:
def set_state(self, state: List | npt.NDArray) -> None:
if self.hyperparameters_update_method == 'state_normalization':
self.nb_steps = state[0]
state = state[1:]
Expand Down
19 changes: 10 additions & 9 deletions compiler_opt/es/gradient_ascent_optimization_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import abc
import numpy as np
from typing import List
import numpy.typing as npt


# TODO(kchoro): Borrow JAXs optimizer library here. Integrated into Blackbox-v2.
Expand All @@ -41,8 +42,8 @@ class GAOptimizer(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def run_step(self, current_input: np.ndarray,
gradient: np.ndarray[np.float32]) -> np.ndarray:
def run_step(self, current_input: npt.NDArray,
gradient: npt.NDArray[np.float32]) -> npt.NDArray:
"""Conducts a single step of gradient ascent optimization.
Conduct a single step of gradient ascent optimization procedure, given the
Expand Down Expand Up @@ -71,7 +72,7 @@ def get_state(self) -> List[np.float32]:
raise NotImplementedError("Abstract method")

@abc.abstractmethod
def set_state(self, state: np.ndarray[np.float32]) -> None:
def set_state(self, state: npt.NDArray[np.float32]) -> None:
"""Sets up the internal state of the optimizer.
Sets up the internal state of the optimizer.
Expand All @@ -98,8 +99,8 @@ def __init__(self, step_size: float, momentum: float):
self.moving_average = np.asarray([], dtype=np.float32)
super().__init__()

def run_step(self, current_input: np.ndarray,
gradient: np.ndarray[np.float32]) -> np.ndarray:
def run_step(self, current_input: npt.NDArray,
gradient: npt.NDArray[np.float32]) -> npt.NDArray:
if self.moving_average.size == 0:
# Initialize the moving average
self.moving_average = np.zeros(len(current_input), dtype=np.float32)
Expand All @@ -119,7 +120,7 @@ def run_step(self, current_input: np.ndarray,
def get_state(self) -> List[np.float32]:
return self.moving_average.tolist()

def set_state(self, state: np.ndarray[np.float32]) -> None:
def set_state(self, state: npt.NDArray[np.float32]) -> None:
self.moving_average = np.asarray(state, dtype=np.float32)


Expand All @@ -141,8 +142,8 @@ def __init__(self,
self.t = 0
super().__init__()

def run_step(self, current_input: np.ndarray,
gradient: np.ndarray[np.float32]) -> np.ndarray:
def run_step(self, current_input: npt.NDArray,
gradient: npt.NDArray[np.float32]) -> npt.NDArray:
if self.first_moment_moving_average.size == 0:
# Initialize the moving averages
self.first_moment_moving_average = np.zeros(
Expand Down Expand Up @@ -177,7 +178,7 @@ def get_state(self) -> List[float]:
return (self.first_moment_moving_average.tolist() +
self.second_moment_moving_average.tolist() + [self.t])

def set_state(self, state: np.ndarray[np.float32]) -> None:
def set_state(self, state: npt.NDArray[np.float32]) -> None:
total_len = len(state)
if total_len % 2 != 1:
raise ValueError("The dimension of the state should be odd")
Expand Down

0 comments on commit 722111b

Please sign in to comment.