fix bug in PredictorEvaluator for ZeroCost proxies #186

Merged 1 commit on Jul 3, 2024

naslib/defaults/predictor_evaluator.py: 21 changes (16 additions, 5 deletions)
@@ -10,9 +10,12 @@
 from sklearn import metrics
 import math
 
+from naslib.predictors.zerocost import ZeroCost
 from naslib.search_spaces.core.query_metrics import Metric
 from naslib.utils import generate_kfold, cross_validation
 
+from naslib import utils
+
 logger = logging.getLogger(__name__)
 

@@ -47,6 +50,9 @@ def __init__(self, predictor, config=None):
         self.num_arches_to_mutate = 5
         self.max_mutation_rate = 3
 
+        # For ZeroCost proxies
+        self.dataloader = None
+
     def adapt_search_space(
         self, search_space, load_labeled, scope=None, dataset_api=None
     ):
@@ -70,6 +76,9 @@ def adapt_search_space(
                 "This search space is not yet implemented in PredictorEvaluator."
             )
 
+        if isinstance(self.predictor, ZeroCost):
+            self.dataloader, _, _, _, _ = utils.get_train_val_loaders(self.config)
+
     def get_full_arch_info(self, arch):
         """
         Given an arch, return the accuracy, train_time,
@@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}):
                 arch.load_labeled_architecture(dataset_api=self.dataset_api)
 
             arch_hash = arch.get_hash()
-            if False: # removing this for consistency, for now
-                continue
-            else:
-                arch_hash_map[arch_hash] = True
+
+            arch_hash_map[arch_hash] = True
 
             accuracy, train_time, info_dict = self.get_full_arch_info(arch)
             xdata.append(arch)
@@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity):
         hyperparams = self.predictor.get_hyperparams()
 
         fit_time_end = time.time()
-        test_pred = self.predictor.query(xtest, test_info)
+        if isinstance(self.predictor, ZeroCost):
+            [g.parse() for g in xtest]  # parse the graphs because they will be used
+            test_pred = self.predictor.query_batch(xtest, self.dataloader)
+        else:
+            test_pred = self.predictor.query(xtest, test_info)
         query_time_end = time.time()
 
         # If the predictor is an ensemble, take the mean
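
For context, a rough usage sketch of the path this change enables: build a ZeroCost predictor, hand it to PredictorEvaluator, and let the patched adapt_search_space create the dataloader. Only the identifiers visible in the diff above (ZeroCost, PredictorEvaluator, adapt_search_space, utils.get_train_val_loaders) come from this PR; the config helpers, the search-space import, and the final evaluate() call are assumptions about the surrounding NASLib API.

# Hypothetical usage sketch, not part of this PR; lines marked "assumed" rely on
# the wider NASLib API and may differ from the actual code.
from naslib.defaults.predictor_evaluator import PredictorEvaluator
from naslib.predictors.zerocost import ZeroCost
from naslib.search_spaces import NasBench201SearchSpace  # assumed import path
from naslib import utils

config = utils.get_config_from_args()  # assumed helper that builds the run config
dataset_api = utils.get_dataset_api(config.search_space, config.dataset)  # assumed helper

predictor = ZeroCost(method_type="jacov")  # constructor signature shown in the zerocost.py diff
evaluator = PredictorEvaluator(predictor, config=config)

# With this PR, adapt_search_space() detects the ZeroCost predictor and builds
# self.dataloader via utils.get_train_val_loaders(config).
evaluator.adapt_search_space(
    NasBench201SearchSpace(), load_labeled=False, dataset_api=dataset_api
)
results = evaluator.evaluate()  # assumed entry point that drives single_evaluate()
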
naslib/predictors/zerocost.py: 14 changes (12 additions, 2 deletions)
@@ -4,6 +4,7 @@
 based on https://github.com/mohsaied/zero-cost-nas
 """
 import torch
+import numpy as np
 import logging
 import math
 
@@ -24,12 +25,21 @@ def __init__(self, method_type="jacov"):
         self.num_imgs_or_batches = 1
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-    def query(self, graph, dataloader=None, info=None):
+    def query_batch(self, graphs, dataloader):
+        scores = []
+
+        for graph in graphs:
+            score = self.query(graph, dataloader)
+            scores.append(score)
+
+        return np.array(scores)
+
+    def query(self, graph, dataloader):
         loss_fn = graph.get_loss_fn()
 
         n_classes = graph.num_classes
         score = predictive.find_measures(
-            net_orig=graph,
+            net_orig=graph.to(self.device),
             dataloader=dataloader,
             dataload_info=(self.dataload, self.num_imgs_or_batches, n_classes),
             device=self.device,
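
To make the new batch interface concrete, a brief sketch of calling query_batch directly; the graphs list and config object are placeholders, and only ZeroCost, query, query_batch, get_train_val_loaders, and the parse() step are taken from the diffs above.

# Hypothetical sketch of the new query_batch interface; "config" and "graphs"
# are placeholders that would come from the surrounding experiment setup.
from naslib.predictors.zerocost import ZeroCost
from naslib import utils

predictor = ZeroCost(method_type="jacov")  # default proxy per the constructor signature
dataloader, _, _, _, _ = utils.get_train_val_loaders(config)  # same call the evaluator now makes

for g in graphs:
    g.parse()  # as in single_evaluate(), graphs are parsed before scoring

scores = predictor.query_batch(graphs, dataloader)  # numpy array with one proxy score per graph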