
Commit

add EvoFed with parallel clients on multiple GPUs
mahi97 committed Aug 24, 2023
1 parent 7e23aa0 commit 4d2d391
Showing 1 changed file with 175 additions and 0 deletions.
evofed_parallel.py (+175, -0)
@@ -0,0 +1,175 @@
import chex
import jax
import jax.numpy as jnp # JAX NumPy
import numpy as np
import tensorflow_datasets as tfds # TFDS for MNIST
import wandb
from backprop import sl
from args import get_args
from utils import helpers, evo
from evosax import NetworkMapper, ParameterReshaper, FitnessShaper
from flax.core import FrozenDict

import os

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# cosine similarity between flat parameter vectors
def cosine(x, y):
    return jnp.sum(x * y) / (jnp.sqrt(jnp.sum(x ** 2)) * jnp.sqrt(jnp.sum(y ** 2)))


def cosine2(x, y):
    return jnp.sum(x * y) / (jnp.sqrt(jnp.sum(x ** 2)) * jnp.sqrt(jnp.sum(y ** 2)))


# negative L2 (Euclidean) distance; negated so that larger fitness means closer
def l2(x, y):
    return -1 * jnp.sqrt(jnp.sum((x - y) ** 2))


# negative L1 (Manhattan) distance
def l1(x, y):
    return -1 * jnp.sum(jnp.abs(x - y))


# negative p-norm distance; reduces to l1 at p=1 and l2 at p=2
def pnorm(x, y, p):
    x = jnp.abs(x - y)
    return -1 * jnp.sum(x ** p) ** (1 / p)


# weighted mix of Chebyshev (max) and L2 distance
def max_dist(x, y):
    return -1 * 0.02 * jnp.max(jnp.abs(x - y)) + 0.98 * l2(x, y)
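

# A minimal sanity-check sketch (the helper below is hypothetical and never
# called by the training code): pnorm should agree with l1 at p=1 and with
# l2 at p=2 on the same inputs.
def _pnorm_demo():
    x = jnp.array([1.0, -2.0, 3.0])
    y = jnp.array([0.0, 1.0, 1.0])
    return pnorm(x, y, 1), l1(x, y), pnorm(x, y, 2), l2(x, y)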

# def l2_std(x, y):
# return l2(x, y) +
# zero out the smallest `percentage` fraction of entries by magnitude
def sparsify(array, percentage):
    original = array
    array = jnp.abs(array.flatten())
    array = jnp.sort(array)
    threshold = array[int(len(array) * percentage)]
    array = jnp.where(jnp.abs(original) < threshold, 0, original)
    return array
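

# A minimal usage sketch (hypothetical helper, not called anywhere): with
# percentage=0.9, sparsify keeps only the largest ~10% of entries by magnitude.
def _sparsify_demo():
    v = jnp.array([0.01, -0.5, 0.03, 2.0, -0.002, 0.9, 0.1, -1.2, 0.04, 0.2])
    return sparsify(v, 0.9)  # only the 2.0 entry survives; the rest become 0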

# uniform quantization of `array` into 2 ** n_bits levels over [min_val, max_val]
def quantize(array, min_val, max_val, n_bits):
    # max_val = array.max()
    # min_val = array.min()
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = ((array - min_val) / step).round()
    return array


# inverse of quantize: map integer levels back into [min_val, max_val]
def dequantize(array, min_val, max_val, n_bits):
    step = (max_val - min_val) / (2 ** n_bits - 1)
    array = array * step + min_val
    return array
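

# A minimal round-trip sketch (hypothetical helper, not called anywhere):
# dequantize(quantize(x)) should recover x up to half a quantization step,
# i.e. (max_val - min_val) / (2 ** n_bits - 1) / 2.
def _quantize_roundtrip_demo():
    x = jnp.linspace(-0.5, 0.5, 8)
    q = quantize(x, -0.5, 0.5, 4)
    x_hat = dequantize(q, -0.5, 0.5, 4)
    return jnp.max(jnp.abs(x - x_hat))  # bounded by 1.0 / 15 / 2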

# identity map; jax.pmap(pfun) shards an array's leading axis across devices
def pfun(x):
    return x


# returns its second argument; mapped over an index array with pmap/vmap,
# it replicates one pytree (e.g. the train state) per device / per client
def vfun(x, y):
    return y


num_devices = jax.local_device_count()
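

# A minimal placement sketch (hypothetical helper and shapes, not called by
# the training code): pmap(pfun) consumes the leading device axis, leaving
# one (clients_per_device, features) shard on each local device.
def _device_placement_demo():
    data = jnp.zeros((num_devices, 4, 8))  # (devices, clients_per_device, features)
    sharded = jax.pmap(pfun)(data)
    return sharded.shape  # (num_devices, 4, 8), one shard per device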


class TaskManager:
    def __init__(self, rng: chex.PRNGKey, args):
        wandb.run.name = '{}-{}-{} b{} c{} s{} p{} r{} q{} -- {}' \
            .format(args.dataset, args.algo,
                    args.dist,
                    args.batch_size, args.n_clients,
                    args.seed,
                    args.percentage,
                    args.rank_factor,
                    args.quantize_bits,
                    wandb.run.id)
        wandb.run.save()
        # self.train_ds, self.test_ds = sl.get_datasets_non_iid(args.dataset, args.n_clients) \
        #     if args.dist == 'NON-IID' else sl.get_datasets_iid(args.dataset, args.n_clients)
        self.train_ds, self.test_ds = sl.get_fed_datasets_pmap(args.dataset, args.n_clients, 2, args.dist == 'IID')
        rng = jax.random.PRNGKey(0)  # note: overrides the rng passed to __init__ with a fixed seed
        rng, init_rng = jax.random.split(rng)
        self.learning_rate = wandb.config.lr
        self.momentum = wandb.config.momentum
        network = NetworkMapper[wandb.config.network_name](**wandb.config.network_config)

        self.state = sl.create_train_state(init_rng, network, self.learning_rate, self.momentum)
        self.param_reshaper = ParameterReshaper(self.state.params, n_devices=1)
        self.test_param_reshaper = ParameterReshaper(self.state.params, n_devices=1)
        self.strategy, self.es_params = evo.get_strategy_and_params(args.pop_size, self.param_reshaper.total_params,
                                                                    args)
        self.fit_shaper = FitnessShaper(centered_rank=args.centered_rank, z_score=args.z_score,
                                        w_decay=args.w_decay, maximize=args.maximize, rank_factor=args.rank_factor)
        server = self.strategy.initialize(init_rng, self.es_params)
        # start the ES search from the freshly initialized network parameters
        self.server = server.replace(mean=self.test_param_reshaper.network_to_flat(self.state.params))
        del init_rng  # Must not be used anymore.

        self.param_count = sum(x.size for x in jax.tree_leaves(self.state.params))
        self.num_epochs = wandb.config.n_rounds
        self.batch_size = wandb.config.batch_size
        self.n_clients = args.n_clients

        # min_cut = 10000
        # self.X = jnp.array([train['image'][:min_cut] for train in self.train_ds])
        # self.y = jnp.array([train['label'][:min_cut] for train in self.train_ds])
        self.X = self.train_ds['image']
        self.y = self.train_ds['label']
        self.args = args
        self.n_bits = args.quantize_bits
    def run(self, rng: chex.PRNGKey):
        # shard the client data across devices: (devices, clients_per_device, ...)
        self.X = jax.pmap(pfun)(self.X.reshape(num_devices, self.args.n_clients // num_devices, *self.X.shape[1:]))
        self.y = jax.pmap(pfun)(self.y.reshape(num_devices, self.args.n_clients // num_devices, *self.y.shape[1:]))
        for epoch in range(0, self.num_epochs + 1):

            rng, input_rng, rng_ask = jax.random.split(rng, 3)
            # replicate the global train state once per device and per client
            clients = jax.pmap(vfun, in_axes=(0, None))(
                jnp.arange(num_devices),
                jax.vmap(vfun, in_axes=(0, None))(jnp.arange(self.args.n_clients // num_devices), self.state))

            # clients, _, _ = jax.vmap(sl.train_epoch, in_axes=(None, 0, 0, None, None))(self.state,
            #                                                                            self.X,
            #                                                                            self.y,
            #                                                                            self.batch_size, input_rng)
            # one local training epoch per client, in parallel across devices
            clients, loss, acc = sl.train_epoch_pmap(clients, self.X, self.y, self.batch_size, input_rng)
            # flatten each client's updated parameters into a single vector
            target_server = jax.pmap(jax.vmap(self.param_reshaper.network_to_flat))(clients.params)
            # sample an ES population and score every member by its (negative L2)
            # distance to each client's parameter vector
            x, self.server = self.strategy.ask(rng_ask, self.server, self.es_params)
            fitness = jax.pmap(jax.vmap(jax.vmap(l2, in_axes=(0, None)), in_axes=(None, 0)), in_axes=(None, 0))(x, target_server)
            fitness = jax.pmap(jax.vmap(self.fit_shaper.apply, in_axes=(None, 0)), in_axes=(None, 0))(x, fitness)
            # fitness = jax.vmap(sparsify, in_axes=(0, None))(fitness, self.args.percentage)
            # fitness = jax.vmap(quantize, in_axes=(0, None, None, None))(fitness, -0.5, 0.5, self.n_bits)
            # average the shaped fitness over clients, then over devices
            fitness = jax.pmap(lambda x: x.mean(0))(fitness).mean(0)
            # fitness = dequantize(fitness, -0.5, 0.5, self.n_bits)
            # fitness = sparsify(fitness, self.args.percentage)
            self.server = self.strategy.tell(x, fitness, self.server, self.es_params)
            # the updated ES mean becomes the new global model
            self.state = self.state.replace(params=FrozenDict(self.test_param_reshaper.reshape_single_net(self.server.mean)))

            rng, eval_rng = jax.random.split(rng)
            test_loss, test_accuracy = sl.eval_model(self.state.params, self.test_ds, eval_rng)
            wandb.log({
                'Round': epoch,
                'Test Loss': test_loss,
                'Global Accuracy': test_accuracy,
                # 'Communication': epoch * 2 * self.args.pop_size,
                # 'Communication': epoch * 2 * self.args.pop_size * (1 - self.args.percentage) * (1 + np.log2(self.args.pop_size)),
                # 'Communication': epoch * 4 * self.args.pop_size * (1 - self.args.percentage) * (np.log2(self.args.pop_size * np.sqrt((1 - self.args.percentage) * 1/self.args.rank_factor))),
                # cumulative estimate: two transfers of pop_size fitness values per
                # round, reduced by sparsification, each value carrying n_bits plus
                # log2(pop_size) index bits, normalized to 32-bit floats
                'Communication': epoch * 2 * self.args.pop_size * (1 - self.args.percentage) * ((self.n_bits + np.log2(self.args.pop_size)) / 32),
            })
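

# A minimal shape sketch of the fitness computation above (hypothetical helper,
# not called anywhere): nesting vmap over clients and over population members
# yields one distance per (client, member) pair.
def _fitness_shape_demo():
    pop = jnp.zeros((16, 100))    # (pop_size, n_params)
    targets = jnp.ones((4, 100))  # (clients, n_params)
    fit = jax.vmap(jax.vmap(l2, in_axes=(0, None)), in_axes=(None, 0))(pop, targets)
    return fit.shape  # (4, 16): one fitness value per client per member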


def run():
    print(jax.devices())
    args = get_args()
    config = helpers.load_config(args.config)
    wandb.init(project='evofed-publish', config=args)
    wandb.config.update(config)
    args = wandb.config
    rng = jax.random.PRNGKey(args.seed)
    rng, rng_init, rng_run = jax.random.split(rng, 3)
    manager = TaskManager(rng_init, args)
    manager.run(rng_run)


if __name__ == '__main__':
    run()
# wandb.agent('y1lh8ou0', function=run, project='evofed', count=10)
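
# Example launch (hypothetical invocation; assumes get_args exposes a --config
# flag, which helpers.load_config(args.config) above suggests):
#   python evofed_parallel.py --config configs/mnist.yaml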
