From 946560695b1a32bae6dbb6660ee633f3e42d40f7 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 14 Aug 2023 15:24:49 +0800 Subject: [PATCH 1/9] update fate nn, support FedIPR Signed-off-by: weijingchen --- python/federatedml/nn/homo/client.py | 4 ++++ .../nn/homo/trainer/fedavg_trainer.py | 24 +++++++------------ .../nn/homo/trainer/trainer_base.py | 3 ++- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/python/federatedml/nn/homo/client.py b/python/federatedml/nn/homo/client.py index f705d145d6..8397b8f523 100644 --- a/python/federatedml/nn/homo/client.py +++ b/python/federatedml/nn/homo/client.py @@ -298,6 +298,10 @@ def fit(self, train_input, validate_input=None): self.trainer_inst.set_model(model) self.trainer_inst.set_tracker(self.tracker) self.trainer_inst.set_model_exporter(self.exporter) + party_id_list = [self.component_properties.guest_partyid] + for i in self.component_properties.host_party_idlist: + party_id_list.append(i) + self.trainer_inst.set_party_id_list(party_id_list) # load dataset class dataset_inst = load_dataset( diff --git a/python/federatedml/nn/homo/trainer/fedavg_trainer.py b/python/federatedml/nn/homo/trainer/fedavg_trainer.py index 9b2597f17f..47c6e3f456 100644 --- a/python/federatedml/nn/homo/trainer/fedavg_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedavg_trainer.py @@ -191,7 +191,7 @@ def _select_model(self): else: return self.model - def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): + def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): epoch_loss = 0.0 batch_idx = 0 @@ -210,19 +210,10 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): batch_label = None for _batch_iter in to_iterate: _batch_iter = self._decode(_batch_iter) - if isinstance(_batch_iter, list): + if isinstance(_batch_iter, list) or isinstance(_batch_iter, tuple): batch_data, batch_label = _batch_iter else: batch_data = _batch_iter - """ - if self.task_type in [consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]: - batch_data = _batch_iter - else: - batch_data, batch_label = _batch_iter - - batch_data = self._decode(batch_data) - batch_label = self._decode(batch_label) - """ if self.cuda is not None or self._enable_deepspeed: device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device @@ -237,17 +228,17 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): pred = model(batch_data) - if not loss and hasattr(pred, "loss"): + if not loss_func and hasattr(pred, "loss"): batch_loss = pred.loss - elif loss is not None: + elif loss_func is not None: if batch_label is None: raise ValueError( "When loss is set, please provide label to calculate loss" ) if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): pred = pred.logits - batch_loss = loss(pred, batch_label) + batch_loss = loss_func(pred, batch_label) else: raise ValueError( 'FedAVGTrainer requires a loss function, but got None, please specify loss function in the' @@ -293,7 +284,7 @@ def train( if optimizer is None: raise ValueError( - 'FedAVGTrainer requires an optimizer, but got None, please specify optimizer in the ' + 'An optimizer is required, but got None, please specify optimizer in the ' 'job configuration') if self.batch_size > len(train_set) or self.batch_size == -1: @@ -309,7 +300,7 @@ def train( need_stop = False evaluation_summary = {} - self._get_train_data_loader(train_set) + self.data_loader = self._get_train_data_loader(train_set) # training process for i in range(self.epochs): @@ 
-608,3 +599,4 @@ def _sync_loss(self, loss): else: dist.gather(loss, dst=0, async_op=False) # LOGGER.info(f"Loss on rank{dist.get_rank()}={loss}") + diff --git a/python/federatedml/nn/homo/trainer/trainer_base.py b/python/federatedml/nn/homo/trainer/trainer_base.py index 3af07dd355..424cbed94d 100644 --- a/python/federatedml/nn/homo/trainer/trainer_base.py +++ b/python/federatedml/nn/homo/trainer/trainer_base.py @@ -52,6 +52,7 @@ def __init__(self, **kwargs): self._model_checkpoint = None self._exporter = None self._evaluation_summary = {} + self._client_num = None # running status self._set_model_checkpoint_epoch = set() @@ -273,7 +274,7 @@ def save( if self._exporter: LOGGER.debug('save model to fate') - model_dict = self._exporter.export_model_dict(model=model, + model_dict = self._exporter.export_model_dict(model=modedel, optimizer=optimizer, model_define=self.nn_define, optimizer_define=self.opt_define, From 41b548fc48f47c6bbf5747b6b3275891037d5b43 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 14 Aug 2023 16:57:15 +0800 Subject: [PATCH 2/9] Add FedIPR trainer & sign block support ipr fix typo Signed-off-by: weijingchen --- .../nn/homo/trainer/fedipr_trainer.py | 489 ++++++++++++++++++ .../nn/homo/trainer/trainer_base.py | 2 +- python/federatedml/nn/model_zoo/sign_block.py | 187 +++++++ 3 files changed, 677 insertions(+), 1 deletion(-) create mode 100644 python/federatedml/nn/homo/trainer/fedipr_trainer.py create mode 100644 python/federatedml/nn/model_zoo/sign_block.py diff --git a/python/federatedml/nn/homo/trainer/fedipr_trainer.py b/python/federatedml/nn/homo/trainer/fedipr_trainer.py new file mode 100644 index 0000000000..74f4dda49f --- /dev/null +++ b/python/federatedml/nn/homo/trainer/fedipr_trainer.py @@ -0,0 +1,489 @@ +import torch as t +import tqdm +import numpy as np +import torch +from typing import Literal +from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer +from federatedml.nn.backend.utils import distributed_util +from torch.utils.data import DataLoader, DistributedSampler +import torch.distributed as dist +from federatedml.nn.dataset.watermark import WaterMarkImageDataset, WaterMarkDataset +from federatedml.util import LOGGER +from federatedml.nn.model_zoo.sign_block import generate_signature, is_sign_block +from federatedml.nn.model_zoo.sign_block import SignatureBlock +from sklearn.metrics import accuracy_score +from federatedml.nn.dataset.base import Dataset +from federatedml.util import consts + + +def get_sign_blocks(model: torch.nn.Module): + + record_sign_block = {} + for name, m in model.named_modules(): + if is_sign_block(m): + record_sign_block[name] = m + + return record_sign_block + + +def get_keys(sign_block_dict: dict, num_bits: int): + + key_pairs = {} + param_len = [] + sum_allocated_bits = 0 + # Iterate through each layer and compute the flattened parameter lengths + for k, v in sign_block_dict.items(): + param_len.append(len(v.embeded_param.flatten())) + total_param_len = sum(param_len) + + alloc_bits = {} + + for i, (k, v) in enumerate(sign_block_dict.items()): + allocated_bits = int((param_len[i] / total_param_len) * num_bits) + alloc_bits[k] = allocated_bits + sum_allocated_bits += allocated_bits + + rest_bits = num_bits - sum_allocated_bits + if rest_bits > 0: + alloc_bits[k] += rest_bits + + for k, v in sign_block_dict.items(): + key_pairs[k] = generate_signature(v, alloc_bits[k]) + + return key_pairs + + +""" +Verify Tools +""" + +def to_cuda(var, device=0): + if hasattr(var, 'cuda'): + return var.cuda(device) + elif 
isinstance(var, tuple) or isinstance(var, list): + ret = tuple(to_cuda(i) for i in var) + return ret + elif isinstance(var, dict): + for k in var: + if hasattr(var[k], 'cuda'): + var[k] = var[k].cuda(device) + return var + else: + return var + + +def _verify_sign_blocks(sign_blocks, keys, cuda=False, device=None): + + signature_correct_count = 0 + total_bit = 0 + for name, block in sign_blocks.items(): + block: SignatureBlock = block + W, signature = keys[name] + if cuda: + W = to_cuda(W, device=device) + signature = to_cuda(signature, device=device) + extract_bits = block.extract_sign(W) + total_bit += len(extract_bits) + signature_correct_count += (extract_bits == signature).sum().detach().cpu().item() + + sign_acc = signature_correct_count / total_bit + return sign_acc + + +def _suggest_sign_bit(param_num, client_num): + max_signbit = param_num // client_num + max_signbit -= 1 # not to exceed + if max_signbit <= 0: + raise ValueError('not able to add feature based watermark, param_num is {}, client num is {}, computed max bit is {} <=0'.format(param_num, client_num, max_signbit)) + return max_signbit + + +def compute_sign_bit(model, client_num): + total_param_num = 0 + blocks = get_sign_blocks(model) + for k, v in blocks.items(): + total_param_num += v.embeded_param_num() + if total_param_num == 0: + return 0 + return _suggest_sign_bit(total_param_num, client_num) + + +def verify_feature_based_signature(model, keys): + + model = model.cpu() + sign_blocks = get_sign_blocks(model) + return _verify_sign_blocks(sign_blocks, keys, cuda=False) + + + +class FedIPRTrainer(FedAVGTrainer): + + def __init__(self, epochs=10, noraml_dataset_batch_size=32, watermark_dataset_batch_size=2, + early_stop=None, tol=0.0001, secure_aggregate=True, weighted_aggregation=True, + aggregate_every_n_epoch=None, cuda=None, pin_memory=True, shuffle=True, + data_loader_worker=0, validation_freqs=None, checkpoint_save_freqs=None, + task_type='auto', save_to_local_dir=False, collate_fn=None, collate_fn_params=None, + alpha=0.01, verify_freqs=1, backdoor_verify_method: Literal['accuracy', 'loss'] = 'accuracy' + ): + + super().__init__(epochs, noraml_dataset_batch_size, early_stop, tol, secure_aggregate, weighted_aggregation, + aggregate_every_n_epoch, cuda, pin_memory, shuffle, data_loader_worker, + validation_freqs, checkpoint_save_freqs, task_type, save_to_local_dir, collate_fn, collate_fn_params) + + self.normal_train_set = None + self.watermark_set = None + self.data_loader = None + self.normal_dataset_batch_size = noraml_dataset_batch_size + self.watermark_dataset_batch_size = watermark_dataset_batch_size + self.alpha = alpha + self.verify_freqs = verify_freqs + self.backdoor_verify_method = backdoor_verify_method + self._sign_keys = None + self._sign_blocks = None + self._client_num = None + self._sign_bits = None + + assert self.alpha > 0, 'alpha must be greater than 0' + assert self.verify_freqs > 0 and isinstance(self.verify_freqs, int), 'verify_freqs must be greater than 0' + assert self.backdoor_verify_method in ['accuracy', 'loss'], 'backdoor_verify_method must be accuracy or loss' + + def local_mode(self): + self.fed_mode = False + self._client_num = 1 + + def _handle_dataset(self, train_set, collate_fn): + + if not distributed_util.is_distributed() or distributed_util.get_num_workers() <= 1: + return DataLoader( + train_set, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + shuffle=self.shuffle, + num_workers=self.data_loader_worker, + collate_fn=collate_fn + ) + else: + train_sampler = 
DistributedSampler( + train_set, + num_replicas=dist.get_world_size(), + rank=dist.get_rank() + ) + return DataLoader( + train_set, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + num_workers=self.data_loader_worker, + collate_fn=collate_fn, + sampler=train_sampler + ) + + + def _get_train_data_loader(self, train_set): + + collate_fn = self._get_collate_fn(train_set) + + if isinstance(train_set, WaterMarkDataset): + LOGGER.info('detect watermark dataset, split watermark dataset and normal dataset') + normal_train_set = train_set.get_normal_dataset() + watermark_set = train_set.get_watermark_dataset() + if normal_train_set is None: + raise ValueError('normal dataset must not be None in FedIPR algo') + train_dataloder = self._handle_dataset(normal_train_set, collate_fn) + + if watermark_set is not None: + watermark_dataloader = self._handle_dataset(watermark_set, collate_fn) + else: + watermark_dataloader = None + self.normal_train_set = normal_train_set + self.watermark_set = watermark_set + dataloaders = {'train': train_dataloder, 'watermark': watermark_dataloader} + return dataloaders + else: + LOGGER.info('detect non-watermark dataset') + train_dataloder = self._handle_dataset(train_set, collate_fn) + dataloaders = {'train': train_dataloder, 'watermark': None} + return dataloaders + + def _get_device(self): + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + return device + else: + return None + + def verify(self, sign_blocks: dict, keys: dict): + + return _verify_sign_blocks(sign_blocks, keys, self.cuda is not None, self._get_device()) + + def get_loss_from_pred(self, loss, pred, batch_label): + + if not loss and hasattr(pred, "loss"): + batch_loss = pred.loss + + elif loss is not None: + if batch_label is None: + raise ValueError( + "When loss is set, please provide label to calculate loss" + ) + if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): + pred = pred.logits + batch_loss = loss(pred, batch_label) + else: + raise ValueError( + 'Trainer requires a loss function, but got None, please specify loss function in the' + ' job configuration') + + return batch_loss + + def _get_keys(self, sign_blocks): + + if self._sign_keys is None: + self._sign_keys = get_keys(sign_blocks, self._sign_bits) + return self._sign_keys + + def _get_sign_blocks(self): + if self._sign_blocks is None: + sign_blocks = get_sign_blocks(self.model) + self._sign_blocks = sign_blocks + + return self._sign_blocks + + def train(self, train_set: Dataset, validate_set: Dataset = None, optimizer = None, loss=None, extra_dict={}): + + if 'keys' in extra_dict: + self._sign_keys = extra_dict['keys'] + self._sign_bits = extra_dict['num_bits'] + else: + LOGGER.info('computing feature based sign bits') + if self._client_num is None and self.party_id_list is not None: + self._client_num = len(self.party_id_list) + self._sign_bits = compute_sign_bit(self.model, self._client_num) + + + LOGGER.info('client num {}, party id list {}'.format(self._client_num, self.party_id_list)) + LOGGER.info('will assign {} bits for feature based watermark'.format(self._sign_bits)) + return super().train(train_set, validate_set, optimizer, loss, extra_dict) + + def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): + + epoch_loss = 0.0 + batch_idx = 0 + acc_num = 0 + + sign_blocks = self._get_sign_blocks() + keys = self._get_keys(sign_blocks) + + dl, watermark_dl = self.data_loader['train'], 
self.data_loader['watermark'] + if isinstance(dl, DistributedSampler): + dl.sampler.set_epoch(epoch_idx) + if isinstance(watermark_dl, DistributedSampler): + watermark_dl.sampler.set_epoch(epoch_idx) + + if not self.fed_mode: + trainset_iterator = tqdm.tqdm(dl) + else: + trainset_iterator = dl + batch_label = None + + # collect watermark data and mix them into the training data + watermark_collect = [] + if watermark_dl is not None: + for watermark_batch in watermark_dl: + watermark_collect.append(watermark_batch) + + for _batch_iter in trainset_iterator: + + _batch_iter = self._decode(_batch_iter) + + if isinstance(_batch_iter, list) or isinstance(_batch_iter, tuple): + batch_data, batch_label = _batch_iter + else: + batch_data = _batch_iter + + if watermark_dl is not None: + # Mix the backdoor sample into the training data + wm_batch_idx = int(batch_idx % len(watermark_collect)) + wm_batch = watermark_collect[wm_batch_idx] + if isinstance(wm_batch, list): + wm_batch_data, wm_batch_label = wm_batch + batch_data = torch.cat([batch_data, wm_batch_data], dim=0) + batch_label = torch.cat([batch_label, wm_batch_label], dim=0) + else: + wm_batch_data = wm_batch + batch_data = torch.cat([batch_data, wm_batch_data], dim=0) + + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + batch_data = self.to_cuda(batch_data, device) + if batch_label is not None: + batch_label = self.to_cuda(batch_label, device) + + if not self._enable_deepspeed: + optimizer.zero_grad() + else: + model.zero_grad() + + pred = model(batch_data) + + sign_loss = 0 + # Get the sign loss of model + for name, block in sign_blocks.items(): + + block: SignatureBlock = block + W, signature = keys[name] + if self.cuda is not None: + device = self._get_device() + W = self.to_cuda(W, device) + signature = self.to_cuda(signature, device) + sign_loss += self.alpha * block.sign_loss(W, signature) + + + batch_loss = self.get_loss_from_pred(loss_func, pred, batch_label) + batch_loss += sign_loss + + if not self._enable_deepspeed: + + batch_loss.backward() + optimizer.step() + batch_loss_np = np.array(batch_loss.detach().tolist()) if self.cuda is None \ + else np.array(batch_loss.cpu().detach().tolist()) + + if acc_num + self.batch_size > len(train_set): + batch_len = len(train_set) - acc_num + else: + batch_len = self.batch_size + + epoch_loss += batch_loss_np * batch_len + else: + batch_loss = model.backward(batch_loss) + batch_loss_np = np.array(batch_loss.cpu().detach().tolist()) + model.step() + batch_loss_np = self._sync_loss(batch_loss_np * self._get_batch_size(batch_data)) + if distributed_util.is_rank_0(): + epoch_loss += batch_loss_np + + batch_idx += 1 + + if self.fed_mode: + LOGGER.debug( + 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) + + epoch_loss = epoch_loss / len(train_set) + + # verify the sign of model during training + if epoch_idx % self.verify_freqs == 0: + # verify feature-based signature + sign_acc = self.verify(sign_blocks, keys) + LOGGER.info(f"epoch {epoch_idx} sign accuracy: {sign_acc}") + # verify backdoor signature + if self.watermark_set is not None: + _, pred, label = self._predict(self.watermark_set) + pred = pred.detach().cpu() + label = label.detach().cpu() + if self.backdoor_verify_method == 'accuracy': + if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): + pred = pred.logits + pred = pred.numpy().reshape((len(label), -1)) + label = label.numpy() + pred_label = np.argmax(pred, axis=1) 
+ metric = accuracy_score(pred_label.flatten(), label.flatten()) + else: + metric = self.get_loss_from_pred(loss_func, pred, label) + + LOGGER.info(f"epoch {epoch_idx} backdoor {self.backdoor_verify_method}: {metric}") + + return epoch_loss + + def _predict(self, dataset: Dataset): + pred_result = [] + + # switch eval mode + dataset.eval() + model = self._select_model() + model.eval() + + if not dataset.has_sample_ids(): + dataset.init_sid_and_getfunc(prefix=dataset.get_type()) + + labels = [] + with torch.no_grad(): + for _batch_iter in DataLoader( + dataset, self.batch_size + ): + if isinstance(_batch_iter, list): + batch_data, batch_label = _batch_iter + else: + batch_label = _batch_iter.pop("labels") + batch_data = _batch_iter + + if self.cuda is not None or self._enable_deepspeed: + device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device + batch_data = self.to_cuda(batch_data, device) + + pred = model(batch_data) + + if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): + pred = pred.logits + + pred_result.append(pred) + labels.append(batch_label) + + ret_rs = torch.concat(pred_result, axis=0) + ret_label = torch.concat(labels, axis=0) + + # switch back to train mode + dataset.train() + model.train() + + return dataset.get_sample_ids(), ret_rs, ret_label + + def predict(self, dataset: Dataset): + + if self.task_type in [consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]: + LOGGER.warning(f"Not support prediction of task_types={[consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]}") + return + + if distributed_util.is_distributed() and not distributed_util.is_rank_0(): + return + + if isinstance(dataset, WaterMarkDataset): + normal_train_set = dataset.get_normal_dataset() + if normal_train_set is None: + raise ValueError('normal train set is None in FedIPR algo predict function') + else: + normal_train_set = normal_train_set + + ids, ret_rs, ret_label = self._predict(normal_train_set) + + if self.fed_mode: + return self.format_predict_result( + ids, ret_rs, ret_label, task_type=self.task_type) + else: + return ret_rs, ret_label + + def save( + self, + model=None, + epoch_idx=-1, + optimizer=None, + converge_status=False, + loss_history=None, + best_epoch=-1, + extra_data={}): + + extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} + super().save(model, epoch_idx, optimizer, converge_status, loss_history, best_epoch, extra_data) + + def local_save(self, + model=None, + epoch_idx=-1, + optimizer=None, + converge_status=False, + loss_history=None, + best_epoch=-1, + extra_data={}): + + extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} + super().local_save(model, epoch_idx, optimizer, converge_status, loss_history, best_epoch, extra_data) + + diff --git a/python/federatedml/nn/homo/trainer/trainer_base.py b/python/federatedml/nn/homo/trainer/trainer_base.py index 424cbed94d..405b8cd161 100644 --- a/python/federatedml/nn/homo/trainer/trainer_base.py +++ b/python/federatedml/nn/homo/trainer/trainer_base.py @@ -274,7 +274,7 @@ def save( if self._exporter: LOGGER.debug('save model to fate') - model_dict = self._exporter.export_model_dict(model=modedel, + model_dict = self._exporter.export_model_dict(model=model, optimizer=optimizer, model_define=self.nn_define, optimizer_define=self.opt_define, diff --git a/python/federatedml/nn/model_zoo/sign_block.py b/python/federatedml/nn/model_zoo/sign_block.py new file mode 100644 index 0000000000..338afeefa7 --- /dev/null +++ b/python/federatedml/nn/model_zoo/sign_block.py @@ -0,0 +1,187 @@ 
+import torch +import torch.nn as nn +import torch.nn.init as init +from torch.nn import functional as F +from federatedml.util import LOGGER + +""" +Base +""" + + +class SignatureBlock(nn.Module): + + def __init__(self) -> None: + super().__init__() + + @property + def embeded_param(self): + return None + + def extract_sign(self, W): + pass + + def sign_loss(self, W, sign): + pass + + def embeded_param_num(self): + pass + + +def is_sign_block(block): + return issubclass(type(block), SignatureBlock) + + +class ConvBlock(nn.Module): + def __init__(self, i, o, ks=3, s=1, pd=1, relu=True): + super().__init__() + + self.conv = nn.Conv2d(i, o, ks, s, pd, bias= False) + + if relu: + self.relu = nn.ReLU(inplace=True) + else: + self.relu = None + + self.reset_parameters() + + def reset_parameters(self): + init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') + + def forward(self, x): + x = self.conv(x) + if self.relu is not None: + x = self.relu(x) + return x + + +def generate_signature(conv_block: SignatureBlock, num_bits): + + sign = torch.sign(torch.rand(num_bits) - 0.5) + W = torch.randn(len(conv_block.embeded_param.flatten()), num_bits) + + return (W, sign) + + +""" +Function & Class for Conv Layer +""" + + +class SignatureConv(SignatureBlock): + + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False): + super(SignatureConv, self).__init__() + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) + self.weight = self.conv.weight + + self._embed_para_num = None + self.init_scale() + self.init_bias() + self.bn = nn.BatchNorm2d(out_channels, affine=False) + self.relu = nn.ReLU(inplace=True) + self.reset_parameters() + + def embeded_param_num(self): + return self._embed_para_num + + def init_bias(self): + self.bias = nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) + init.zeros_(self.bias) + + def init_scale(self): + self.scale = nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) + init.ones_(self.scale) + self._embed_para_num = self.scale.shape[0] + + def reset_parameters(self): + init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') + + @property + def embeded_param(self): + # embedded in the BatchNorm param, as the same in the paper + return self.scale + + def extract_sign(self, W): + # W is the linear weight for extracting signature + with torch.no_grad(): + return self.scale.view([1, -1]).mm(W).sign().flatten() + + def sign_loss(self, W, sign): + loss = F.relu(-self.scale.view([1, -1]).mm(W).mul(sign.view(-1))).sum() + return loss + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = x * self.scale[None, :, None, None] + self.bias[None, :, None, None] + x = self.relu(x) + return x + + +""" +Function & Class for LM +""" + + +def recursive_replace_layernorm(module, layer_name_set=None): + + """ + Recursively replaces the LayerNorm layers of a given module with SignatureLayerNorm layers. + + Parameters: + module (torch.nn.Module): The module in which LayerNorm layers should be replaced. + layer_name_set (set[str], optional): A set of layer names to be replaced. If None, + all LayerNorm layers in the module will be replaced. 
+ """ + + for name, sub_module in module.named_children(): + if isinstance(sub_module, nn.LayerNorm): + if layer_name_set is not None and name not in layer_name_set: + continue + setattr(module, name, SignatureLayerNorm.from_layer_norm_layer(sub_module)) + LOGGER.debug(f"Replace {name} with SignatureLayerNorm") + recursive_replace_layernorm(sub_module, layer_name_set) + + +class SignatureLayerNorm(SignatureBlock): + + def __init__(self, normalized_shape=None, eps=1e-5, elementwise_affine=True, layer_norm_inst=None): + super(SignatureLayerNorm, self).__init__() + if layer_norm_inst is not None and isinstance(layer_norm_inst, nn.LayerNorm): + self.ln = layer_norm_inst + else: + self.ln = nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + self._embed_param_num = self.ln.weight.numel() + + @property + def embeded_param(self): + return self.ln.weight + + def embeded_param_num(self): + return self._embed_param_num + + @staticmethod + def from_layer_norm_layer(layer_norm_layer: nn.LayerNorm): + return SignatureLayerNorm(layer_norm_inst=layer_norm_layer) + + def extract_sign(self, W): + # W is the linear weight for extracting signature + with torch.no_grad(): + return self.ln.weight.view([1, -1]).mm(W).sign().flatten() + + def sign_loss(self, W, sign): + loss = F.relu(-self.ln.weight.view([1, -1]).mm(W).mul(sign.view(-1))).sum() + return loss + + def forward(self, x): + return self.ln(x) + + +if __name__ == "__main__": + conv = SignatureConv(3, 384, 3, 1, 1) + layer_norm = SignatureLayerNorm((768, )) + layer_norm_2 = SignatureLayerNorm.from_layer_norm_layer(layer_norm.ln) + \ No newline at end of file From 6a939199cd879d233acdea83f4c55a31c92a3611 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 21 Aug 2023 19:57:34 +0800 Subject: [PATCH 3/9] Add support for fate-llm: 1. Update homonn framework, support arbiter side model 2. Fix bug & log format 3. 
Update aggregator framework Signed-off-by: weijingchen --- .../homo/aggregator/aggregator_base.py | 9 +- .../homo/aggregator/secure_aggregator.py | 12 +- python/federatedml/framework/homo/blocks.py | 8 +- .../nn/backend/torch/cust_model.py | 55 ------ python/federatedml/nn/homo/_init.py | 140 +++++++++++++++ python/federatedml/nn/homo/client.py | 156 +++------------- python/federatedml/nn/homo/server.py | 60 ++++++- .../nn/homo/trainer/fedavg_trainer.py | 167 ++++++++++++------ .../nn/homo/trainer/trainer_base.py | 65 ++++--- python/federatedml/param/homo_nn_param.py | 4 +- 10 files changed, 386 insertions(+), 290 deletions(-) delete mode 100644 python/federatedml/nn/backend/torch/cust_model.py create mode 100644 python/federatedml/nn/homo/_init.py diff --git a/python/federatedml/framework/homo/aggregator/aggregator_base.py b/python/federatedml/framework/homo/aggregator/aggregator_base.py index d00f8d52f0..7dbe56b282 100644 --- a/python/federatedml/framework/homo/aggregator/aggregator_base.py +++ b/python/federatedml/framework/homo/aggregator/aggregator_base.py @@ -1,4 +1,5 @@ from federatedml.framework.homo.blocks import ServerCommunicator, ClientCommunicator +from federatedml.util import consts class AutoSuffix(object): @@ -19,7 +20,7 @@ def __call__(self): class AggregatorBaseClient(object): - def __init__(self, communicate_match_suffix: str = None): + def __init__(self, communicate_match_suffix: str = None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): """Base class of client aggregator Parameters @@ -28,7 +29,7 @@ def __init__(self, communicate_match_suffix: str = None): To make sure that client and server can communicate correctly, the server-side and client-side aggregators need to have the same suffix """ - self.communicator = ClientCommunicator(prefix=communicate_match_suffix) + self.communicator = ClientCommunicator(prefix=communicate_match_suffix, server=server, clients=clients) self.suffix = {} def _get_suffix(self, var_name, user_suffix=tuple()): @@ -52,7 +53,7 @@ def get(self, suffix): class AggregatorBaseServer(object): - def __init__(self, communicate_match_suffix=None): + def __init__(self, communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): """Base class of server aggregator Parameters @@ -61,7 +62,7 @@ def __init__(self, communicate_match_suffix=None): To make sure that client and server can communicate correctly, the server-side and client-side aggregators need to have the same suffix """ - self.communicator = ServerCommunicator(prefix=communicate_match_suffix) + self.communicator = ServerCommunicator(prefix=communicate_match_suffix, server=server, clients=clients) self.suffix = {} def _get_suffix(self, var_name, user_suffix=tuple()): diff --git a/python/federatedml/framework/homo/aggregator/secure_aggregator.py b/python/federatedml/framework/homo/aggregator/secure_aggregator.py index ce6a7ec545..a1776fb948 100644 --- a/python/federatedml/framework/homo/aggregator/secure_aggregator.py +++ b/python/federatedml/framework/homo/aggregator/secure_aggregator.py @@ -15,10 +15,10 @@ class SecureAggregatorClient(AggregatorBaseClient): def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggregate_weight=1.0, - communicate_match_suffix=None): + communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): super(SecureAggregatorClient, self).__init__( - communicate_match_suffix=communicate_match_suffix) + communicate_match_suffix=communicate_match_suffix, 
clients=clients, server=server) self.secure_aggregate = secure_aggregate self.suffix = { "local_loss": AutoSuffix("local_loss"), @@ -31,7 +31,7 @@ def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggreg # init secure aggregate random padding: if self.secure_aggregate: self._random_padding_cipher: PadsCipher = RandomPaddingCipherClient( - trans_var=RandomPaddingCipherTransVar(prefix=communicate_match_suffix)).create_cipher() + trans_var=RandomPaddingCipherTransVar(prefix=communicate_match_suffix, clients=clients, server=server)).create_cipher() LOGGER.info('initialize secure aggregator done') # compute weight @@ -186,9 +186,9 @@ def loss_aggregation(self, loss, suffix=tuple()): class SecureAggregatorServer(AggregatorBaseServer): - def __init__(self, secure_aggregate=True, communicate_match_suffix=None): + def __init__(self, secure_aggregate=True, communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): super(SecureAggregatorServer, self).__init__( - communicate_match_suffix=communicate_match_suffix) + communicate_match_suffix=communicate_match_suffix, clients=clients, server=server) self.suffix = { "local_loss": AutoSuffix("local_loss"), "agg_loss": AutoSuffix("agg_loss"), @@ -199,7 +199,7 @@ def __init__(self, secure_aggregate=True, communicate_match_suffix=None): self.secure_aggregate = secure_aggregate if self.secure_aggregate: RandomPaddingCipherServer(trans_var=RandomPaddingCipherTransVar( - prefix=communicate_match_suffix)).exchange_secret_keys() + prefix=communicate_match_suffix, clients=clients, server=server)).exchange_secret_keys() LOGGER.info('initialize secure aggregator done') agg_weights = self.collect(suffix=('agg_weight', )) diff --git a/python/federatedml/framework/homo/blocks.py b/python/federatedml/framework/homo/blocks.py index 7d2d730e4c..c894172819 100644 --- a/python/federatedml/framework/homo/blocks.py +++ b/python/federatedml/framework/homo/blocks.py @@ -60,8 +60,8 @@ def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST) class ServerCommunicator(object): - def __init__(self, prefix=None): - self.trans_var = CommunicatorTransVar(prefix=prefix) + def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + self.trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients) self._client_parties = self.trans_var.client_parties def get_parties(self, party_idx): @@ -85,8 +85,8 @@ def broadcast_obj(self, obj, suffix=tuple(), party_idx=-1): class ClientCommunicator(object): - def __init__(self, prefix=None): - trans_var = CommunicatorTransVar(prefix=prefix) + def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients) self.trans_var = trans_var self._server_parties = trans_var.server_parties diff --git a/python/federatedml/nn/backend/torch/cust_model.py b/python/federatedml/nn/backend/torch/cust_model.py deleted file mode 100644 index e9fff34839..0000000000 --- a/python/federatedml/nn/backend/torch/cust_model.py +++ /dev/null @@ -1,55 +0,0 @@ -import importlib - -from torch import nn - -from federatedml.nn.backend.torch.base import FateTorchLayer -from federatedml.nn.backend.utils.common import ML_PATH - -PATH = '{}.model_zoo'.format(ML_PATH) - - -class CustModel(FateTorchLayer, nn.Module): - - def __init__(self, module_name, class_name, **kwargs): - super(CustModel, self).__init__() - assert isinstance( - 
module_name, str), 'name must be a str, specify the module in the model_zoo' - assert isinstance( - class_name, str), 'class name must be a str, specify the class in the module' - self.param_dict = { - 'module_name': module_name, - 'class_name': class_name, - 'param': kwargs} - self._model = None - - def init_model(self): - if self._model is None: - self._model = self.get_pytorch_model() - - def forward(self, x): - if self._model is None: - raise ValueError('model not init, call init_model() function') - return self._model(x) - - def get_pytorch_model(self): - - module_name: str = self.param_dict['module_name'] - class_name = self.param_dict['class_name'] - module_param: dict = self.param_dict['param'] - if module_name.endswith('.py'): - module_name = module_name.replace('.py', '') - nn_modules = importlib.import_module('{}.{}'.format(PATH, module_name)) - try: - for k, v in nn_modules.__dict__.items(): - if isinstance(v, type): - if issubclass( - v, nn.Module) and v is not nn.Module and v.__name__ == class_name: - return v(**module_param) - raise ValueError( - 'Did not find any class in {}.py that is pytorch nn.Module and named {}'. format( - module_name, class_name)) - except ValueError as e: - raise e - - def __repr__(self): - return 'CustModel({})'.format(str(self.param_dict)) diff --git a/python/federatedml/nn/homo/_init.py b/python/federatedml/nn/homo/_init.py new file mode 100644 index 0000000000..3dc96bab11 --- /dev/null +++ b/python/federatedml/nn/homo/_init.py @@ -0,0 +1,140 @@ +import json +import torch +import inspect +from federatedml.nn.homo.trainer.trainer_base import get_trainer_class, TrainerBase +from federatedml.util import LOGGER +from federatedml.nn.backend.torch import serialization as s +from federatedml.nn.backend.torch.base import FateTorchOptimizer +from federatedml.nn.backend.utils.common import recover_model_bytes +from federatedml.nn.backend.utils import deepspeed_util + + +def init(trainer, trainer_param, nn_define, config_optimizer, config_loss, torch_seed, model_loaded_flag, loaded_model, ds_config): + + warm_start_iter = None + + if ds_config: + deepspeed_util.init_deepspeed_env(ds_config) + + # load trainer class + if trainer is None: + raise ValueError( + 'Trainer is not specified, please specify your trainer') + + trainer_class = get_trainer_class(trainer) + LOGGER.info('trainer class is {}'.format(trainer_class)) + + # recover model from model config / or recover from saved model param + loaded_model_dict = None + + # if has model protobuf, load model config from protobuf + load_opt_state_dict = False + + if model_loaded_flag: + param, meta = get_homo_param_meta(loaded_model) + LOGGER.info('save path is {}'.format(param.local_save_path)) + if param.local_save_path == '': + LOGGER.info('Load model from model protobuf') + warm_start_iter = param.epoch_idx + if param is None or meta is None: + raise ValueError( + 'model protobuf is None, make sure' + 'that your trainer calls export_model() function to save models') + + if meta.nn_define[0] is None: + raise ValueError( + 'nn_define is None, model protobuf has no nn-define, make sure' + 'that your trainer calls export_model() function to save models') + + nn_define = json.loads(meta.nn_define[0]) + loss = json.loads(meta.loss_func_define[0]) + optimizer = json.loads(meta.optimizer_define[0]) + loaded_model_dict = recover_model_bytes(param.model_bytes) + extra_data = recover_model_bytes(param.extra_data_bytes) + + else: + LOGGER.info('Load model from local save path') + save_dict = 
torch.load(open(param.local_save_path, 'rb')) + warm_start_iter = save_dict['epoch_idx'] + nn_define = save_dict['model_define'] + loss = save_dict['loss_define'] + optimizer = save_dict['optimizer_define'] + loaded_model_dict = save_dict + extra_data = save_dict['extra_data'] + + if config_optimizer is not None and optimizer != config_optimizer: + LOGGER.info('optimizer updated') + else: + config_optimizer = optimizer + load_opt_state_dict = True + + if config_loss is not None and config_loss != loss: + LOGGER.info('loss updated') + else: + config_loss = loss + else: + extra_data = {} + + # check key param + if nn_define is None: + raise ValueError( + 'Model structure is not defined, nn_define is None, please check your param') + + # get model from nn define + model = s.recover_sequential_from_dict(nn_define) + if loaded_model_dict: + model.load_state_dict(loaded_model_dict['model']) + LOGGER.info('load model state dict from check point') + + LOGGER.info('model structure is {}'.format(model)) + # init optimizer + if config_optimizer is not None and not ds_config: + optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict( + config_optimizer) + # pass model parameters to optimizer + optimizer = optimizer_.to_torch_instance(model.parameters()) + if load_opt_state_dict: + LOGGER.info('load optimizer state dict') + optimizer.load_state_dict(loaded_model_dict['optimizer']) + LOGGER.info('optimizer is {}'.format(optimizer)) + else: + optimizer = None + LOGGER.info('optimizer is not specified') + + # init loss + if config_loss is not None: + loss_fn = s.recover_loss_fn_from_dict(config_loss) + LOGGER.info('loss function is {}'.format(loss_fn)) + else: + loss_fn = None + LOGGER.info('loss function is not specified') + + # init trainer + trainer_inst: TrainerBase = trainer_class(**trainer_param) + LOGGER.info('trainer class is {}'.format(trainer_class)) + + trainer_train_args = inspect.getfullargspec(trainer_inst.train).args + args_format = [ + 'self', + 'train_set', + 'validate_set', + 'optimizer', + 'loss', + 'extra_data' + ] + if len(trainer_train_args) < 6: + raise ValueError( + 'Train function of trainer should take 6 arguments :{}, but current trainer.train ' + 'only takes {} arguments: {}'.format( + args_format, len(trainer_train_args), trainer_train_args)) + + trainer_inst.set_nn_config(nn_define, config_optimizer, config_loss) + trainer_inst.fed_mode = True + + if ds_config: + model, optimizer = deepspeed_util.deepspeed_init(model, ds_config) + trainer_inst.enable_deepspeed(is_zero_3=deepspeed_util.is_zero3(ds_config)) + if deepspeed_util.is_zero3(ds_config): + model.train() + + return trainer_inst, model, optimizer, loss_fn, extra_data, config_optimizer, config_loss, warm_start_iter \ No newline at end of file diff --git a/python/federatedml/nn/homo/client.py b/python/federatedml/nn/homo/client.py index 8397b8f523..8062d421e6 100644 --- a/python/federatedml/nn/homo/client.py +++ b/python/federatedml/nn/homo/client.py @@ -23,6 +23,7 @@ from federatedml.nn.backend.utils.data import add_match_id from federatedml.protobuf.generated.homo_nn_model_param_pb2 import HomoNNParam as HomoNNParamPB from federatedml.protobuf.generated.homo_nn_model_meta_pb2 import HomoNNMeta as HomoNNMetaPB +from federatedml.nn.homo._init import init class NNModelExporter(ExporterBase): @@ -85,6 +86,12 @@ def export_model_dict( return get_homo_model_dict(param, meta) +def default_client_post_process(trainer): + model = trainer.get_cached_model() + summary = trainer.get_summary() + return model, summary + + 
class HomoNNClient(ModelBase): def __init__(self): @@ -122,6 +129,7 @@ def __init__(self): self._ds_stage = -1 self.model_save_flag = False + def _init_model(self, param: HomoNNParam): train_param = param.trainer.to_dict() @@ -136,136 +144,6 @@ def _init_model(self, param: HomoNNParam): self.optimizer = param.optimizer self.ds_config = param.ds_config - def init(self): - - # set random seed - global_seed(self.torch_seed) - - if self.ds_config: - deepspeed_util.init_deepspeed_env(self.ds_config) - - # load trainer class - if self.trainer is None: - raise ValueError( - 'Trainer is not specified, please specify your trainer') - - trainer_class = get_trainer_class(self.trainer) - LOGGER.info('trainer class is {}'.format(trainer_class)) - - # recover model from model config / or recover from saved model param - loaded_model_dict = None - - # if has model protobuf, load model config from protobuf - load_opt_state_dict = False - if self.model_loaded: - - param, meta = get_homo_param_meta(self.model) - LOGGER.info('save path is {}'.format(param.local_save_path)) - if param.local_save_path == '': - LOGGER.info('Load model from model protobuf') - self.warm_start_iter = param.epoch_idx - if param is None or meta is None: - raise ValueError( - 'model protobuf is None, make sure' - 'that your trainer calls export_model() function to save models') - - if meta.nn_define[0] is None: - raise ValueError( - 'nn_define is None, model protobuf has no nn-define, make sure' - 'that your trainer calls export_model() function to save models') - - self.nn_define = json.loads(meta.nn_define[0]) - loss = json.loads(meta.loss_func_define[0]) - optimizer = json.loads(meta.optimizer_define[0]) - loaded_model_dict = recover_model_bytes(param.model_bytes) - extra_data = recover_model_bytes(param.extra_data_bytes) - else: - LOGGER.info('Load model from local save path') - save_dict = torch.load(open(param.local_save_path, 'rb')) - self.warm_start_iter = save_dict['epoch_idx'] - self.nn_define = save_dict['model_define'] - loss = save_dict['loss_define'] - optimizer = save_dict['optimizer_define'] - loaded_model_dict = save_dict - extra_data = save_dict['extra_data'] - - if self.optimizer is not None and optimizer != self.optimizer: - LOGGER.info('optimizer updated') - else: - self.optimizer = optimizer - load_opt_state_dict = True - - if self.loss is not None and self.loss != loss: - LOGGER.info('loss updated') - else: - self.loss = loss - else: - extra_data = {} - - # check key param - if self.nn_define is None: - raise ValueError( - 'Model structure is not defined, nn_define is None, please check your param') - - # get model from nn define - model = s.recover_sequential_from_dict(self.nn_define) - if loaded_model_dict: - model.load_state_dict(loaded_model_dict['model']) - LOGGER.info('load model state dict from check point') - - LOGGER.info('model structure is {}'.format(model)) - # init optimizer - if self.optimizer is not None and not self.ds_config: - optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict( - self.optimizer) - # pass model parameters to optimizer - optimizer = optimizer_.to_torch_instance(model.parameters()) - if load_opt_state_dict: - LOGGER.info('load optimizer state dict') - optimizer.load_state_dict(loaded_model_dict['optimizer']) - LOGGER.info('optimizer is {}'.format(optimizer)) - else: - optimizer = None - LOGGER.info('optimizer is not specified') - - # init loss - if self.loss is not None: - loss_fn = s.recover_loss_fn_from_dict(self.loss) - LOGGER.info('loss function is 
{}'.format(loss_fn)) - else: - loss_fn = None - LOGGER.info('loss function is not specified') - - # init trainer - trainer_inst: TrainerBase = trainer_class(**self.trainer_param) - LOGGER.info('trainer class is {}'.format(trainer_class)) - - trainer_train_args = inspect.getfullargspec(trainer_inst.train).args - args_format = [ - 'self', - 'train_set', - 'validate_set', - 'optimizer', - 'loss', - 'extra_data' - ] - if len(trainer_train_args) < 6: - raise ValueError( - 'Train function of trainer should take 6 arguments :{}, but current trainer.train ' - 'only takes {} arguments: {}'.format( - args_format, len(trainer_train_args), trainer_train_args)) - - trainer_inst.set_nn_config(self.nn_define, self.optimizer, self.loss) - trainer_inst.fed_mode = True - - if self.ds_config: - model, optimizer = deepspeed_util.deepspeed_init(model, self.ds_config) - trainer_inst.enable_deepspeed(is_zero_3=deepspeed_util.is_zero3(self.ds_config)) - if deepspeed_util.is_zero3(self.ds_config): - model.train() - - return trainer_inst, model, optimizer, loss_fn, extra_data - def fit(self, train_input, validate_input=None): LOGGER.debug('train input is {}'.format(train_input)) @@ -294,13 +172,21 @@ def fit(self, train_input, validate_input=None): # set random seed global_seed(self.torch_seed) - self.trainer_inst, model, optimizer, loss_fn, extra_data = self.init() + # init + self.trainer_inst, model, optimizer, loss_fn, extra_data, self.optimizer, self.loss, self.warm_start_iter = init( + trainer=self.trainer, trainer_param=self.trainer_param, nn_define=self.nn_define, + config_optimizer=self.optimizer, config_loss=self.loss, torch_seed=self.torch_seed, model_loaded_flag=self.model_loaded, + loaded_model=self.model, ds_config=self.ds_config + ) + + # prepare to train self.trainer_inst.set_model(model) self.trainer_inst.set_tracker(self.tracker) self.trainer_inst.set_model_exporter(self.exporter) party_id_list = [self.component_properties.guest_partyid] - for i in self.component_properties.host_party_idlist: - party_id_list.append(i) + if self.component_properties.host_party_idlist is not None: + for i in self.component_properties.host_party_idlist: + party_id_list.append(i) self.trainer_inst.set_party_id_list(party_id_list) # load dataset class @@ -343,8 +229,8 @@ def fit(self, train_input, validate_input=None): ) # training is done, get exported model - self.model = self.trainer_inst.get_cached_model() - self.set_summary(self.trainer_inst.get_summary()) + self.model, summary = default_client_post_process(self.trainer_inst) + self.set_summary(summary) def predict(self, cpn_input): diff --git a/python/federatedml/nn/homo/server.py b/python/federatedml/nn/homo/server.py index 106ab0757f..27fb226e32 100644 --- a/python/federatedml/nn/homo/server.py +++ b/python/federatedml/nn/homo/server.py @@ -6,6 +6,9 @@ from federatedml.nn.homo.client import NNModelExporter from federatedml.callbacks.model_checkpoint import ModelCheckpoint from federatedml.nn.backend.utils.common import get_homo_param_meta, recover_model_bytes +from federatedml.nn.homo._init import init +from federatedml.util import consts +from federatedml.nn.backend.utils.common import global_seed class HomoNNServer(ModelBase): @@ -13,7 +16,7 @@ class HomoNNServer(ModelBase): def __init__(self): super(HomoNNServer, self).__init__() self.model_param = HomoNNParam() - self.trainer = None + self.trainer = consts.FEDAVG_TRAINER self.trainer_param = None # arbiter side models @@ -24,13 +27,25 @@ def __init__(self): self.exporter = NNModelExporter() self.extra_data = 
{} # warm start + self.model_loaded = False self.warm_start_iter = None + # server init: if arbiter need to load model, loss, optimizer from config + self.server_init = False + + self.dataset_module = None + self.dataset = None + self.dataset_param = {} + self.torch_seed = None + self.loss = None + self.optimizer = None + self.nn_define = None + self.ds_config = None def export_model(self): if self.model is None: LOGGER.debug('export an empty model') - return self.exporter.export_model_dict() # return an exporter + return self.exporter.export_model_dict() # return an empyty model return self.model @@ -43,13 +58,24 @@ def load_model(self, model_dict): # load extra data self.extra_data = recover_model_bytes(param.extra_data_bytes) self.warm_start_iter = param.epoch_idx + self.model_loaded = True def _init_model(self, param: HomoNNParam()): + train_param = param.trainer.to_dict() + dataset_param = param.dataset.to_dict() self.trainer = train_param['trainer_name'] + self.dataset = dataset_param['dataset_name'] self.trainer_param = train_param['param'] + self.torch_seed = param.torch_seed + self.nn_define = param.nn_define + self.loss = param.loss + self.optimizer = param.optimizer + self.ds_config = param.ds_config + LOGGER.debug('trainer and trainer param {} {}'.format( self.trainer, self.trainer_param)) + self.server_init = param.server_init def fit(self, data_instance=None, validate_data=None): @@ -63,17 +89,37 @@ def fit(self, data_instance=None, validate_data=None): if self.component_properties.is_warm_start: self.callback_warm_start_init_iter(self.warm_start_iter) - # initialize trainer - trainer_class = get_trainer_class(self.trainer) - LOGGER.info('trainer class is {}'.format(trainer_class)) - # init trainer - trainer_inst = trainer_class(**self.trainer_param) + if self.server_init: + LOGGER.info('server try to load model, loss, optimizer from config') + # init + global_seed(self.torch_seed) + + trainer_inst, model, optimizer, loss_fn, extra_data, optimizer, loss, self.warm_start_iter = init( + trainer=self.trainer, trainer_param=self.trainer_param, nn_define=self.nn_define, + config_optimizer=self.optimizer, config_loss=self.loss, torch_seed=self.torch_seed, model_loaded_flag=self.model_loaded, + loaded_model=self.model, ds_config=self.ds_config + ) + trainer_inst.set_model(model) + + else: + # initialize trainer only + trainer_class = get_trainer_class(self.trainer) + trainer_inst = trainer_class(**self.trainer_param) + LOGGER.info('trainer class is {}'.format(trainer_class)) + # set tracker for fateboard callback trainer_inst.set_tracker(self.tracker) # set exporter trainer_inst.set_model_exporter(self.exporter) + # set party info + party_id_list = [self.component_properties.guest_partyid] + if self.component_properties.host_party_idlist is not None: + for i in self.component_properties.host_party_idlist: + party_id_list.append(i) + trainer_inst.set_party_id_list(party_id_list) # set chceckpoint trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1)) + # run trainer server procedure trainer_inst.server_aggregate_procedure(self.extra_data) diff --git a/python/federatedml/nn/homo/trainer/fedavg_trainer.py b/python/federatedml/nn/homo/trainer/fedavg_trainer.py index 47c6e3f456..86dfe00785 100644 --- a/python/federatedml/nn/homo/trainer/fedavg_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedavg_trainer.py @@ -150,6 +150,11 @@ def __init__(self, epochs=10, batch_size=512, # training parameter self.check_trainer_param( [self.tol], ['tol'], self.is_float, '{} is not a 
float') + # federation + self.client_agg = None + self.server_agg = None + self.aggregate_round = None + def _init_aggregator(self, train_set): # compute round to aggregate cur_agg_round = 0 @@ -202,6 +207,8 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): dl = self.data_loader + total_batch_len = len(dl) + if not self.fed_mode: to_iterate = tqdm.tqdm(dl) else: @@ -267,13 +274,87 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): batch_idx += 1 # LOGGER.info(f"finish epoch={epoch_idx}, batch={batch_idx}") - if self.fed_mode: - LOGGER.debug( - 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) + if self.fed_mode: + if batch_idx % (total_batch_len // 100) == 0: + percentage = (batch_idx / total_batch_len) * 100 + LOGGER.info(f"Training progress of epoch {epoch_idx}: {percentage:.1f}%") epoch_loss = epoch_loss / len(train_set) return epoch_loss + def on_loop_begin_client(self, **kwargs): + pass + + def on_loop_end_client(self, **kwargs): + pass + + def on_loop_begin_server(self, **kwargs): + pass + + def on_loop_end_server(self, **kwargs): + pass + + def _client_sends_data(self, epoch_idx, epoch_loss, cur_agg_round): + need_stop = False + if self.client_agg is not None or distributed_util.is_distributed(): + if not (self.aggregate_every_n_epoch is not None and (epoch_idx + 1) % self.aggregate_every_n_epoch != 0): + + # model averaging, only aggregate trainable param + if self._deepspeed_zero_3: + deepspeed_util.gather_model(self.model) + + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + self.model = self.client_agg.model_aggregation(self.model) + if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: + self._share_model() + else: + self._share_model() + + # agg loss and get converge status + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + converge_status = self.client_agg.loss_aggregation(epoch_loss) + cur_agg_round += 1 + if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: + self._sync_converge_status(converge_status) + else: + converge_status = self._sync_converge_status() + + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + LOGGER.info( + 'model averaging finished, aggregate round {}/{}'.format( + cur_agg_round, self.aggregate_round)) + + if converge_status: + LOGGER.info('early stop triggered, stop training') + need_stop = True + + return need_stop + + def _server_aggregates_data(self, epoch_idx, check_converge, converge_func): + + need_stop = False + if not (self.aggregate_every_n_epoch is not None and (epoch_idx + 1) % self.aggregate_every_n_epoch != 0): + + # model aggregate + self.server_agg.model_aggregation() + + # loss aggregate + agg_loss, converge_status = self.server_agg.loss_aggregation( + check_converge=check_converge, converge_func=converge_func) + self.callback_loss(agg_loss, epoch_idx) + + # save check point process + if self.save_freq is not None and ((epoch_idx + 1) % self.save_freq == 0): + self.checkpoint(epoch_idx=epoch_idx) + LOGGER.info('save checkpoint : epoch {}'.format(epoch_idx)) + + # check stop condition + if converge_status: + LOGGER.debug('stop triggered, stop aggregation') + need_stop = True + + return need_stop + def train( self, train_set: Dataset, @@ -292,7 +373,7 @@ def train( # compute round to aggregate cur_agg_round = 0 - client_agg, aggregate_round = self._init_aggregator(train_set) + self.client_agg, self.aggregate_round = 
self._init_aggregator(train_set) # running var cur_epoch = 0 @@ -301,6 +382,9 @@ def train( evaluation_summary = {} self.data_loader = self._get_train_data_loader(train_set) + + self.on_loop_begin_client() + # training process for i in range(self.epochs): @@ -314,37 +398,8 @@ def train( LOGGER.info('epoch loss is {}'.format(epoch_loss)) # federation process, if running local mode, cancel federation - if client_agg is not None or distributed_util.is_distributed(): - if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0): - - # model averaging, only aggregate trainable param - if self._deepspeed_zero_3: - deepspeed_util.gather_model(self.model) - - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - self.model = client_agg.model_aggregation(self.model) - if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: - self._share_model() - else: - self._share_model() - - # agg loss and get converge status - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - converge_status = client_agg.loss_aggregation(epoch_loss) - cur_agg_round += 1 - if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: - self._sync_converge_status(converge_status) - else: - converge_status = self._sync_converge_status() - - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - LOGGER.info( - 'model averaging finished, aggregate round {}/{}'.format( - cur_agg_round, aggregate_round)) - - if converge_status: - LOGGER.info('early stop triggered, stop training') - need_stop = True + need_stop = self._client_sends_data(i, epoch_loss, cur_agg_round) + cur_agg_round += 1 # validation process if self.validation_freq and ((i + 1) % self.validation_freq == 0): @@ -391,6 +446,8 @@ def train( if self._deepspeed_zero_3: deepspeed_util.gather_model(self.model) + self.on_loop_end_client() + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): best_epoch = int(np.array(loss_history).argmin()) @@ -481,31 +538,23 @@ def server_aggregate_procedure(self, extra_data={}): 'check early stop, converge func is {}'.format(converge_func)) LOGGER.info('server running aggregate procedure') - server_agg = SecureAggServer(self.secure_aggregate, communicate_match_suffix=self.comm_suffix) + self.server_agg = SecureAggServer(self.secure_aggregate, communicate_match_suffix=self.comm_suffix) + self.on_loop_begin_server() # aggregate and broadcast models for i in range(self.epochs): - if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0): - - # model aggregate - server_agg.model_aggregation() - - # loss aggregate - agg_loss, converge_status = server_agg.loss_aggregation( - check_converge=check_converge, converge_func=converge_func) - self.callback_loss(agg_loss, i) - - # save check point process - if self.save_freq is not None and ((i + 1) % self.save_freq == 0): - self.checkpoint(epoch_idx=i) - LOGGER.info('save checkpoint : epoch {}'.format(i)) - # check stop condition - if converge_status: - LOGGER.debug('stop triggered, stop aggregation') - break - - LOGGER.info('server aggregation process done') + need_stop = self._server_aggregates_data(i, check_converge, converge_func) + if need_stop: + break + + self.on_loop_end_server() + if self.model is not None: + if self.save_to_local_dir: + self.local_save(model=self.model, epoch_idx=i, converge_status=need_stop) + else: + self.save(model=self.model, epoch_idx=i, converge_status=need_stop) 
+ LOGGER.info('server side model saved') def _decode(self, data): if isinstance(data, transformers.tokenization_utils_base.BatchEncoding): @@ -541,7 +590,7 @@ def _get_train_data_loader(self, train_set): collate_fn = self._get_collate_fn(train_set) if not distributed_util.is_distributed() or distributed_util.get_num_workers() <= 1: - self.data_loader = DataLoader( + data_loader = DataLoader( train_set, batch_size=self.batch_size, pin_memory=self.pin_memory, @@ -555,7 +604,7 @@ def _get_train_data_loader(self, train_set): num_replicas=dist.get_world_size(), rank=dist.get_rank() ) - self.data_loader = DataLoader( + data_loader = DataLoader( train_set, batch_size=self.batch_size, pin_memory=self.pin_memory, @@ -564,6 +613,8 @@ def _get_train_data_loader(self, train_set): sampler=train_sampler ) + return data_loader + def _share_model(self): if distributed_util.is_rank_0(): for p in self.model.parameters(): diff --git a/python/federatedml/nn/homo/trainer/trainer_base.py b/python/federatedml/nn/homo/trainer/trainer_base.py index 405b8cd161..2dc3c71529 100644 --- a/python/federatedml/nn/homo/trainer/trainer_base.py +++ b/python/federatedml/nn/homo/trainer/trainer_base.py @@ -9,7 +9,7 @@ from federatedml.util import consts from federatedml.util import LOGGER from federatedml.model_base import serialize_models -from federatedml.nn.backend.utils.common import ML_PATH +from federatedml.nn.backend.utils.common import ML_PATH, LLM_PATH from federatedml.feature.instance import Instance from federatedml.evaluation.evaluation import Evaluation from federatedml.model_base import Metric, MetricMeta @@ -226,10 +226,17 @@ def _local_save( if hasattr(model, "enable_save_pretrained") and model.enable_save_pretrained: unwrap_model.save_pretrained(save_path) else: - model_state_dict = model.state_dict() + if model is None: + model_state_dict = None + else: + model_state_dict = model.state_dict() + if optimizer is None: + optimizer_state_dict = None + else: + optimizer_state_dict = optimizer.state_dict() model_dict = { 'model': model_state_dict, - 'optimizer': optimizer.state_dict(), + 'optimizer': optimizer_state_dict, 'model_define': self.nn_define, 'optimizer_define': self.opt_define, 'loss_define': self.loss_define, @@ -547,24 +554,42 @@ def unwrap_model(model): Load Trainer """ - def get_trainer_class(trainer_module_name: str): - if trainer_module_name.endswith('.py'): trainer_module_name = trainer_module_name.replace('.py', '') - ds_modules = importlib.import_module( - '{}.homo.trainer.{}'.format( - ML_PATH, trainer_module_name)) + + std_fate_trainer_path = '{}.homo.trainer.{}'.format(ML_PATH, trainer_module_name) + + paths_to_check = [std_fate_trainer_path] + errors = [] try: - trainers = [] - for k, v in ds_modules.__dict__.items(): - if isinstance(v, type): - if issubclass(v, TrainerBase) and v is not TrainerBase: - trainers.append(v) - if len(trainers) == 0: - raise ValueError('Did not find any class in {}.py that is the subclass of Trainer class'.
- format(trainer_module_name)) - else: - return trainers[-1] # return the last defined trainer - except ValueError as e: - raise e + importlib.import_module(LLM_PATH) + fate_llm_trainer_path = '{}.trainer.{}'.format(LLM_PATH, trainer_module_name) + paths_to_check.append(fate_llm_trainer_path) + except Exception as e: + pass + + trainers = [] + ds_modules = None + + for path in paths_to_check: + try: + ds_modules = importlib.import_module(path) + break + except Exception as e: + errors.append(str(e)) + + if ds_modules is None: + raise ImportError('Could not import from any of the paths: {}, error details {}'.format(', '.join(paths_to_check), errors)) + + for k, v in ds_modules.__dict__.items(): + + if isinstance(v, type): + if issubclass(v, TrainerBase) and v is not TrainerBase: + trainers.append(v) + + if len(trainers) == 0: + raise ValueError('Did not find any class in {}.py that is the subclass of Trainer class'. + format(trainer_module_name)) + else: + return trainers[-1] # return the last defined trainer diff --git a/python/federatedml/param/homo_nn_param.py b/python/federatedml/param/homo_nn_param.py index 297ba12c72..64650a5762 100644 --- a/python/federatedml/param/homo_nn_param.py +++ b/python/federatedml/param/homo_nn_param.py @@ -42,7 +42,8 @@ def __init__(self, nn_define: dict = None, loss: dict = None, optimizer: dict = None, - ds_config: dict = None + ds_config: dict = None, + server_init: bool = False ): super(HomoNNParam, self).__init__() @@ -53,6 +54,7 @@ def __init__(self, self.loss = loss self.optimizer = optimizer self.ds_config = ds_config + self.server_init = server_init def check(self): From 9a07af64ae8976703bed81bb798143650db5d74d Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 21 Aug 2023 19:58:55 +0800 Subject: [PATCH 4/9] Add fate-client support Signed-off-by: weijingchen --- .../fate_client/pipeline/component/homo_nn.py | 12 +++++-- .../component/nn/backend/torch/cust.py | 35 +++++++++++++------ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/python/fate_client/pipeline/component/homo_nn.py b/python/fate_client/pipeline/component/homo_nn.py index c6b68adf6a..281bec24d5 100644 --- a/python/fate_client/pipeline/component/homo_nn.py +++ b/python/fate_client/pipeline/component/homo_nn.py @@ -44,7 +44,8 @@ 'loss': None, 'optimizer': None, 'nn_define': None, - 'ds_config': None + 'ds_config': None, + 'server_init': False } except Exception as e: print(e) @@ -65,7 +66,10 @@ class HomoNN(FateComponent): torch_seed, global random seed loss, loss function from fate_torch optimizer, optimizer from fate_torch + ds_config, config for deepspeed model, a fate torch sequential defining the model structure + server_init, whether to initialize the model, loss and optimizer on server, if configs are provided, they will be used. 
In + current version this option is specially designed for offsite-tuning """ @extract_explicit_parameter @@ -82,7 +86,9 @@ def __init__(self, loss=None, optimizer: OptimizerType = None, ds_config: dict = None, - model: Sequential = None, **kwargs): + model: Sequential = None, + server_init: bool = False, + **kwargs): explicit_parameters = copy.deepcopy(DEFAULT_PARAM_DICT) if 'name' not in kwargs["explict_parameters"]: @@ -95,7 +101,7 @@ def __init__(self, self.output = Output(self.name, data_type='single') self._module_name = "HomoNN" self._updated = {'trainer': False, 'dataset': False, - 'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False} + 'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False, 'ds_config': False, 'server_init': False} self._set_param(kwargs["explict_parameters"]) self._check_parameters() diff --git a/python/fate_client/pipeline/component/nn/backend/torch/cust.py b/python/fate_client/pipeline/component/nn/backend/torch/cust.py index 4eba0c54c6..de60736081 100644 --- a/python/fate_client/pipeline/component/nn/backend/torch/cust.py +++ b/python/fate_client/pipeline/component/nn/backend/torch/cust.py @@ -3,9 +3,12 @@ from pipeline.component.nn.backend.torch.base import FateTorchLayer, FateTorchLoss import difflib +ML_PATH = 'federatedml.nn' +LLM_PATH = "fate_llm" -MODEL_PATH = None -LOSS_PATH = None +LLM_MODEL_PATH = '{}.model_zoo'.format(LLM_PATH) +MODEL_PATH = '{}.model_zoo'.format(ML_PATH) +LOSS_PATH = '{}.loss'.format(ML_PATH) def str_simi(str_a, str_b): @@ -45,9 +48,14 @@ class CustModel(FateTorchLayer, nn.Module): def __init__(self, module_name, class_name, **kwargs): super(CustModel, self).__init__() - assert isinstance(module_name, str), 'name must be a str, specify the module in the model_zoo' - assert isinstance(class_name, str), 'class name must be a str, specify the class in the module' - self.param_dict = {'module_name': module_name, 'class_name': class_name, 'param': kwargs} + assert isinstance( + module_name, str), 'name must be a str, specify the module in the model_zoo' + assert isinstance( + class_name, str), 'class name must be a str, specify the class in the module' + self.param_dict = { + 'module_name': module_name, + 'class_name': class_name, + 'param': kwargs} self._model = None def init_model(self): @@ -62,11 +70,18 @@ def forward(self, x): def get_pytorch_model(self, module_path=None): if module_path is None: - return get_class( - self.param_dict['module_name'], - self.param_dict['class_name'], - self.param_dict['param'], - MODEL_PATH) + try: + return get_class( + self.param_dict['module_name'], + self.param_dict['class_name'], + self.param_dict['param'], + MODEL_PATH) + except BaseException: + return get_class( + self.param_dict['module_name'], + self.param_dict['class_name'], + self.param_dict['param'], + LLM_MODEL_PATH) else: return get_class( self.param_dict['module_name'], From b96b392404d9200d01ab1296049d55b4cea5c0e4 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Mon, 21 Aug 2023 20:00:33 +0800 Subject: [PATCH 5/9] Fix pep8 Signed-off-by: weijingchen --- .../homo/aggregator/aggregator_base.py | 5 +- .../homo/aggregator/secure_aggregator.py | 16 +++- python/federatedml/nn/homo/_init.py | 15 +++- python/federatedml/nn/homo/client.py | 1 - .../nn/homo/trainer/fedavg_trainer.py | 11 ++- .../nn/homo/trainer/fedipr_trainer.py | 82 +++++++++++-------- .../nn/homo/trainer/trainer_base.py | 13 +-- 7 files changed, 90 insertions(+), 53 deletions(-) diff --git 
a/python/federatedml/framework/homo/aggregator/aggregator_base.py b/python/federatedml/framework/homo/aggregator/aggregator_base.py index 7dbe56b282..d46e96f195 100644 --- a/python/federatedml/framework/homo/aggregator/aggregator_base.py +++ b/python/federatedml/framework/homo/aggregator/aggregator_base.py @@ -20,7 +20,10 @@ def __call__(self): class AggregatorBaseClient(object): - def __init__(self, communicate_match_suffix: str = None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + def __init__( + self, communicate_match_suffix: str = None, server=( + consts.ARBITER,), clients=( + consts.GUEST, consts.HOST)): """Base class of client aggregator Parameters diff --git a/python/federatedml/framework/homo/aggregator/secure_aggregator.py b/python/federatedml/framework/homo/aggregator/secure_aggregator.py index a1776fb948..4ed8303989 100644 --- a/python/federatedml/framework/homo/aggregator/secure_aggregator.py +++ b/python/federatedml/framework/homo/aggregator/secure_aggregator.py @@ -31,7 +31,10 @@ def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggreg # init secure aggregate random padding: if self.secure_aggregate: self._random_padding_cipher: PadsCipher = RandomPaddingCipherClient( - trans_var=RandomPaddingCipherTransVar(prefix=communicate_match_suffix, clients=clients, server=server)).create_cipher() + trans_var=RandomPaddingCipherTransVar( + prefix=communicate_match_suffix, + clients=clients, + server=server)).create_cipher() LOGGER.info('initialize secure aggregator done') # compute weight @@ -186,7 +189,16 @@ def loss_aggregation(self, loss, suffix=tuple()): class SecureAggregatorServer(AggregatorBaseServer): - def __init__(self, secure_aggregate=True, communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + def __init__( + self, + secure_aggregate=True, + communicate_match_suffix=None, + server=( + consts.ARBITER, + ), + clients=( + consts.GUEST, + consts.HOST)): super(SecureAggregatorServer, self).__init__( communicate_match_suffix=communicate_match_suffix, clients=clients, server=server) self.suffix = { diff --git a/python/federatedml/nn/homo/_init.py b/python/federatedml/nn/homo/_init.py index 3dc96bab11..aa3fc96bc7 100644 --- a/python/federatedml/nn/homo/_init.py +++ b/python/federatedml/nn/homo/_init.py @@ -9,8 +9,17 @@ from federatedml.nn.backend.utils import deepspeed_util -def init(trainer, trainer_param, nn_define, config_optimizer, config_loss, torch_seed, model_loaded_flag, loaded_model, ds_config): - +def init( + trainer, + trainer_param, + nn_define, + config_optimizer, + config_loss, + torch_seed, + model_loaded_flag, + loaded_model, + ds_config): + warm_start_iter = None if ds_config: @@ -137,4 +146,4 @@ def init(trainer, trainer_param, nn_define, config_optimizer, config_loss, torch if deepspeed_util.is_zero3(ds_config): model.train() - return trainer_inst, model, optimizer, loss_fn, extra_data, config_optimizer, config_loss, warm_start_iter \ No newline at end of file + return trainer_inst, model, optimizer, loss_fn, extra_data, config_optimizer, config_loss, warm_start_iter diff --git a/python/federatedml/nn/homo/client.py b/python/federatedml/nn/homo/client.py index 8062d421e6..c2d7d83fcb 100644 --- a/python/federatedml/nn/homo/client.py +++ b/python/federatedml/nn/homo/client.py @@ -129,7 +129,6 @@ def __init__(self): self._ds_stage = -1 self.model_save_flag = False - def _init_model(self, param: HomoNNParam): train_param = param.trainer.to_dict() diff --git 
a/python/federatedml/nn/homo/trainer/fedavg_trainer.py b/python/federatedml/nn/homo/trainer/fedavg_trainer.py index 86dfe00785..97722912a7 100644 --- a/python/federatedml/nn/homo/trainer/fedavg_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedavg_trainer.py @@ -331,7 +331,7 @@ def _client_sends_data(self, epoch_idx, epoch_loss, cur_agg_round): return need_stop def _server_aggregates_data(self, epoch_idx, check_converge, converge_func): - + need_stop = False if not (self.aggregate_every_n_epoch is not None and (epoch_idx + 1) % self.aggregate_every_n_epoch != 0): @@ -382,9 +382,9 @@ def train( evaluation_summary = {} self.data_loader = self._get_train_data_loader(train_set) - + self.on_loop_begin_client() - + # training process for i in range(self.epochs): @@ -547,7 +547,7 @@ def server_aggregate_procedure(self, extra_data={}): need_stop = self._server_aggregates_data(i, check_converge, converge_func) if need_stop: break - + self.on_loop_end_server() if self.model is not None: if self.save_to_local_dir: @@ -614,7 +614,7 @@ def _get_train_data_loader(self, train_set): ) return data_loader - + def _share_model(self): if distributed_util.is_rank_0(): for p in self.model.parameters(): @@ -650,4 +650,3 @@ def _sync_loss(self, loss): else: dist.gather(loss, dst=0, async_op=False) # LOGGER.info(f"Loss on rank{dist.get_rank()}={loss}") - diff --git a/python/federatedml/nn/homo/trainer/fedipr_trainer.py b/python/federatedml/nn/homo/trainer/fedipr_trainer.py index 74f4dda49f..c39523d76a 100644 --- a/python/federatedml/nn/homo/trainer/fedipr_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedipr_trainer.py @@ -17,7 +17,7 @@ def get_sign_blocks(model: torch.nn.Module): - + record_sign_block = {} for name, m in model.named_modules(): if is_sign_block(m): @@ -27,9 +27,9 @@ def get_sign_blocks(model: torch.nn.Module): def get_keys(sign_block_dict: dict, num_bits: int): - - key_pairs = {} - param_len = [] + + key_pairs = {} + param_len = [] sum_allocated_bits = 0 # Iterate through each layer and compute the flattened parameter lengths for k, v in sign_block_dict.items(): @@ -49,7 +49,7 @@ def get_keys(sign_block_dict: dict, num_bits: int): for k, v in sign_block_dict.items(): key_pairs[k] = generate_signature(v, alloc_bits[k]) - + return key_pairs @@ -57,6 +57,7 @@ def get_keys(sign_block_dict: dict, num_bits: int): Verify Tools """ + def to_cuda(var, device=0): if hasattr(var, 'cuda'): return var.cuda(device) @@ -85,7 +86,7 @@ def _verify_sign_blocks(sign_blocks, keys, cuda=False, device=None): extract_bits = block.extract_sign(W) total_bit += len(extract_bits) signature_correct_count += (extract_bits == signature).sum().detach().cpu().item() - + sign_acc = signature_correct_count / total_bit return sign_acc @@ -94,7 +95,9 @@ def _suggest_sign_bit(param_num, client_num): max_signbit = param_num // client_num max_signbit -= 1 # not to exceed if max_signbit <= 0: - raise ValueError('not able to add feature based watermark, param_num is {}, client num is {}, computed max bit is {} <=0'.format(param_num, client_num, max_signbit)) + raise ValueError( + 'not able to add feature based watermark, param_num is {}, client num is {}, computed max bit is {} <=0'.format( + param_num, client_num, max_signbit)) return max_signbit @@ -109,27 +112,41 @@ def compute_sign_bit(model, client_num): def verify_feature_based_signature(model, keys): - + model = model.cpu() sign_blocks = get_sign_blocks(model) return _verify_sign_blocks(sign_blocks, keys, cuda=False) - class FedIPRTrainer(FedAVGTrainer): - def 
__init__(self, epochs=10, noraml_dataset_batch_size=32, watermark_dataset_batch_size=2, - early_stop=None, tol=0.0001, secure_aggregate=True, weighted_aggregation=True, - aggregate_every_n_epoch=None, cuda=None, pin_memory=True, shuffle=True, - data_loader_worker=0, validation_freqs=None, checkpoint_save_freqs=None, + def __init__(self, epochs=10, noraml_dataset_batch_size=32, watermark_dataset_batch_size=2, + early_stop=None, tol=0.0001, secure_aggregate=True, weighted_aggregation=True, + aggregate_every_n_epoch=None, cuda=None, pin_memory=True, shuffle=True, + data_loader_worker=0, validation_freqs=None, checkpoint_save_freqs=None, task_type='auto', save_to_local_dir=False, collate_fn=None, collate_fn_params=None, alpha=0.01, verify_freqs=1, backdoor_verify_method: Literal['accuracy', 'loss'] = 'accuracy' ): - - super().__init__(epochs, noraml_dataset_batch_size, early_stop, tol, secure_aggregate, weighted_aggregation, - aggregate_every_n_epoch, cuda, pin_memory, shuffle, data_loader_worker, - validation_freqs, checkpoint_save_freqs, task_type, save_to_local_dir, collate_fn, collate_fn_params) - + + super().__init__( + epochs, + noraml_dataset_batch_size, + early_stop, + tol, + secure_aggregate, + weighted_aggregation, + aggregate_every_n_epoch, + cuda, + pin_memory, + shuffle, + data_loader_worker, + validation_freqs, + checkpoint_save_freqs, + task_type, + save_to_local_dir, + collate_fn, + collate_fn_params) + self.normal_train_set = None self.watermark_set = None self.data_loader = None @@ -177,7 +194,6 @@ def _handle_dataset(self, train_set, collate_fn): sampler=train_sampler ) - def _get_train_data_loader(self, train_set): collate_fn = self._get_collate_fn(train_set) @@ -210,7 +226,7 @@ def _get_device(self): return device else: return None - + def verify(self, sign_blocks: dict, keys: dict): return _verify_sign_blocks(sign_blocks, keys, self.cuda is not None, self._get_device()) @@ -232,24 +248,24 @@ def get_loss_from_pred(self, loss, pred, batch_label): raise ValueError( 'Trainer requires a loss function, but got None, please specify loss function in the' ' job configuration') - + return batch_loss - + def _get_keys(self, sign_blocks): - + if self._sign_keys is None: self._sign_keys = get_keys(sign_blocks, self._sign_bits) return self._sign_keys - + def _get_sign_blocks(self): if self._sign_blocks is None: sign_blocks = get_sign_blocks(self.model) self._sign_blocks = sign_blocks return self._sign_blocks - - def train(self, train_set: Dataset, validate_set: Dataset = None, optimizer = None, loss=None, extra_dict={}): - + + def train(self, train_set: Dataset, validate_set: Dataset = None, optimizer=None, loss=None, extra_dict={}): + if 'keys' in extra_dict: self._sign_keys = extra_dict['keys'] self._sign_bits = extra_dict['num_bits'] @@ -258,7 +274,6 @@ def train(self, train_set: Dataset, validate_set: Dataset = None, optimizer = No if self._client_num is None and self.party_id_list is not None: self._client_num = len(self.party_id_list) self._sign_bits = compute_sign_bit(self.model, self._client_num) - LOGGER.info('client num {}, party id list {}'.format(self._client_num, self.party_id_list)) LOGGER.info('will assign {} bits for feature based watermark'.format(self._sign_bits)) @@ -322,7 +337,7 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): optimizer.zero_grad() else: model.zero_grad() - + pred = model(batch_data) sign_loss = 0 @@ -337,8 +352,7 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): signature = 
self.to_cuda(signature, device) sign_loss += self.alpha * block.sign_loss(W, signature) - - batch_loss = self.get_loss_from_pred(loss_func, pred, batch_label) + batch_loss = self.get_loss_from_pred(loss_func, pred, batch_label) batch_loss += sign_loss if not self._enable_deepspeed: @@ -445,7 +459,7 @@ def predict(self, dataset: Dataset): if distributed_util.is_distributed() and not distributed_util.is_rank_0(): return - + if isinstance(dataset, WaterMarkDataset): normal_train_set = dataset.get_normal_dataset() if normal_train_set is None: @@ -460,7 +474,7 @@ def predict(self, dataset: Dataset): ids, ret_rs, ret_label, task_type=self.task_type) else: return ret_rs, ret_label - + def save( self, model=None, @@ -485,5 +499,3 @@ def local_save(self, extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} super().local_save(model, epoch_idx, optimizer, converge_status, loss_history, best_epoch, extra_data) - - diff --git a/python/federatedml/nn/homo/trainer/trainer_base.py b/python/federatedml/nn/homo/trainer/trainer_base.py index 2dc3c71529..cf47061e21 100644 --- a/python/federatedml/nn/homo/trainer/trainer_base.py +++ b/python/federatedml/nn/homo/trainer/trainer_base.py @@ -554,12 +554,13 @@ def unwrap_model(model): Load Trainer """ + def get_trainer_class(trainer_module_name: str): if trainer_module_name.endswith('.py'): trainer_module_name = trainer_module_name.replace('.py', '') - + std_fate_trainer_path = '{}.homo.trainer.{}'.format(ML_PATH, trainer_module_name) - + paths_to_check = [std_fate_trainer_path] errors = [] try: @@ -568,7 +569,7 @@ def get_trainer_class(trainer_module_name: str): paths_to_check.append(fate_llm_trainer_path) except Exception as e: pass - + trainers = [] ds_modules = None @@ -580,10 +581,12 @@ def get_trainer_class(trainer_module_name: str): errors.append(str(e)) if ds_modules is None: - raise ImportError('Could not import from any of the paths: {}, error details {}'.format(', '.join(paths_to_check), errors)) + raise ImportError( + 'Could not import from any of the paths: {}, error details {}'.format( + ', '.join(paths_to_check), errors)) for k, v in ds_modules.__dict__.items(): - + if isinstance(v, type): if issubclass(v, TrainerBase) and v is not TrainerBase: trainers.append(v) From a6f5a5575e43111ffcd1a221fb7d0547c7ea787f Mon Sep 17 00:00:00 2001 From: weijingchen Date: Tue, 22 Aug 2023 10:42:45 +0800 Subject: [PATCH 6/9] Fix log & update files Signed-off-by: weijingchen --- .../nn/homo/trainer/fedavg_trainer.py | 2 +- .../nn/homo/trainer/fedipr_trainer.py | 501 ------------------ python/federatedml/nn/model_zoo/sign_block.py | 187 ------- 3 files changed, 1 insertion(+), 689 deletions(-) delete mode 100644 python/federatedml/nn/homo/trainer/fedipr_trainer.py delete mode 100644 python/federatedml/nn/model_zoo/sign_block.py diff --git a/python/federatedml/nn/homo/trainer/fedavg_trainer.py b/python/federatedml/nn/homo/trainer/fedavg_trainer.py index 97722912a7..dc74ba3fb6 100644 --- a/python/federatedml/nn/homo/trainer/fedavg_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedavg_trainer.py @@ -277,7 +277,7 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): if self.fed_mode: if batch_idx % (total_batch_len // 100) == 0: percentage = (batch_idx / total_batch_len) * 100 - LOGGER.info(f"Training progress of epoch {epoch_idx}: {percentage:.1f}%") + LOGGER.debug(f"Training progress of epoch {epoch_idx}: {percentage:.1f}%") epoch_loss = epoch_loss / len(train_set) return epoch_loss diff --git 
a/python/federatedml/nn/homo/trainer/fedipr_trainer.py b/python/federatedml/nn/homo/trainer/fedipr_trainer.py deleted file mode 100644 index c39523d76a..0000000000 --- a/python/federatedml/nn/homo/trainer/fedipr_trainer.py +++ /dev/null @@ -1,501 +0,0 @@ -import torch as t -import tqdm -import numpy as np -import torch -from typing import Literal -from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer -from federatedml.nn.backend.utils import distributed_util -from torch.utils.data import DataLoader, DistributedSampler -import torch.distributed as dist -from federatedml.nn.dataset.watermark import WaterMarkImageDataset, WaterMarkDataset -from federatedml.util import LOGGER -from federatedml.nn.model_zoo.sign_block import generate_signature, is_sign_block -from federatedml.nn.model_zoo.sign_block import SignatureBlock -from sklearn.metrics import accuracy_score -from federatedml.nn.dataset.base import Dataset -from federatedml.util import consts - - -def get_sign_blocks(model: torch.nn.Module): - - record_sign_block = {} - for name, m in model.named_modules(): - if is_sign_block(m): - record_sign_block[name] = m - - return record_sign_block - - -def get_keys(sign_block_dict: dict, num_bits: int): - - key_pairs = {} - param_len = [] - sum_allocated_bits = 0 - # Iterate through each layer and compute the flattened parameter lengths - for k, v in sign_block_dict.items(): - param_len.append(len(v.embeded_param.flatten())) - total_param_len = sum(param_len) - - alloc_bits = {} - - for i, (k, v) in enumerate(sign_block_dict.items()): - allocated_bits = int((param_len[i] / total_param_len) * num_bits) - alloc_bits[k] = allocated_bits - sum_allocated_bits += allocated_bits - - rest_bits = num_bits - sum_allocated_bits - if rest_bits > 0: - alloc_bits[k] += rest_bits - - for k, v in sign_block_dict.items(): - key_pairs[k] = generate_signature(v, alloc_bits[k]) - - return key_pairs - - -""" -Verify Tools -""" - - -def to_cuda(var, device=0): - if hasattr(var, 'cuda'): - return var.cuda(device) - elif isinstance(var, tuple) or isinstance(var, list): - ret = tuple(to_cuda(i) for i in var) - return ret - elif isinstance(var, dict): - for k in var: - if hasattr(var[k], 'cuda'): - var[k] = var[k].cuda(device) - return var - else: - return var - - -def _verify_sign_blocks(sign_blocks, keys, cuda=False, device=None): - - signature_correct_count = 0 - total_bit = 0 - for name, block in sign_blocks.items(): - block: SignatureBlock = block - W, signature = keys[name] - if cuda: - W = to_cuda(W, device=device) - signature = to_cuda(signature, device=device) - extract_bits = block.extract_sign(W) - total_bit += len(extract_bits) - signature_correct_count += (extract_bits == signature).sum().detach().cpu().item() - - sign_acc = signature_correct_count / total_bit - return sign_acc - - -def _suggest_sign_bit(param_num, client_num): - max_signbit = param_num // client_num - max_signbit -= 1 # not to exceed - if max_signbit <= 0: - raise ValueError( - 'not able to add feature based watermark, param_num is {}, client num is {}, computed max bit is {} <=0'.format( - param_num, client_num, max_signbit)) - return max_signbit - - -def compute_sign_bit(model, client_num): - total_param_num = 0 - blocks = get_sign_blocks(model) - for k, v in blocks.items(): - total_param_num += v.embeded_param_num() - if total_param_num == 0: - return 0 - return _suggest_sign_bit(total_param_num, client_num) - - -def verify_feature_based_signature(model, keys): - - model = model.cpu() - sign_blocks = get_sign_blocks(model) - 
return _verify_sign_blocks(sign_blocks, keys, cuda=False) - - -class FedIPRTrainer(FedAVGTrainer): - - def __init__(self, epochs=10, noraml_dataset_batch_size=32, watermark_dataset_batch_size=2, - early_stop=None, tol=0.0001, secure_aggregate=True, weighted_aggregation=True, - aggregate_every_n_epoch=None, cuda=None, pin_memory=True, shuffle=True, - data_loader_worker=0, validation_freqs=None, checkpoint_save_freqs=None, - task_type='auto', save_to_local_dir=False, collate_fn=None, collate_fn_params=None, - alpha=0.01, verify_freqs=1, backdoor_verify_method: Literal['accuracy', 'loss'] = 'accuracy' - ): - - super().__init__( - epochs, - noraml_dataset_batch_size, - early_stop, - tol, - secure_aggregate, - weighted_aggregation, - aggregate_every_n_epoch, - cuda, - pin_memory, - shuffle, - data_loader_worker, - validation_freqs, - checkpoint_save_freqs, - task_type, - save_to_local_dir, - collate_fn, - collate_fn_params) - - self.normal_train_set = None - self.watermark_set = None - self.data_loader = None - self.normal_dataset_batch_size = noraml_dataset_batch_size - self.watermark_dataset_batch_size = watermark_dataset_batch_size - self.alpha = alpha - self.verify_freqs = verify_freqs - self.backdoor_verify_method = backdoor_verify_method - self._sign_keys = None - self._sign_blocks = None - self._client_num = None - self._sign_bits = None - - assert self.alpha > 0, 'alpha must be greater than 0' - assert self.verify_freqs > 0 and isinstance(self.verify_freqs, int), 'verify_freqs must be greater than 0' - assert self.backdoor_verify_method in ['accuracy', 'loss'], 'backdoor_verify_method must be accuracy or loss' - - def local_mode(self): - self.fed_mode = False - self._client_num = 1 - - def _handle_dataset(self, train_set, collate_fn): - - if not distributed_util.is_distributed() or distributed_util.get_num_workers() <= 1: - return DataLoader( - train_set, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - shuffle=self.shuffle, - num_workers=self.data_loader_worker, - collate_fn=collate_fn - ) - else: - train_sampler = DistributedSampler( - train_set, - num_replicas=dist.get_world_size(), - rank=dist.get_rank() - ) - return DataLoader( - train_set, - batch_size=self.batch_size, - pin_memory=self.pin_memory, - num_workers=self.data_loader_worker, - collate_fn=collate_fn, - sampler=train_sampler - ) - - def _get_train_data_loader(self, train_set): - - collate_fn = self._get_collate_fn(train_set) - - if isinstance(train_set, WaterMarkDataset): - LOGGER.info('detect watermark dataset, split watermark dataset and normal dataset') - normal_train_set = train_set.get_normal_dataset() - watermark_set = train_set.get_watermark_dataset() - if normal_train_set is None: - raise ValueError('normal dataset must not be None in FedIPR algo') - train_dataloder = self._handle_dataset(normal_train_set, collate_fn) - - if watermark_set is not None: - watermark_dataloader = self._handle_dataset(watermark_set, collate_fn) - else: - watermark_dataloader = None - self.normal_train_set = normal_train_set - self.watermark_set = watermark_set - dataloaders = {'train': train_dataloder, 'watermark': watermark_dataloader} - return dataloaders - else: - LOGGER.info('detect non-watermark dataset') - train_dataloder = self._handle_dataset(train_set, collate_fn) - dataloaders = {'train': train_dataloder, 'watermark': None} - return dataloaders - - def _get_device(self): - if self.cuda is not None or self._enable_deepspeed: - device = self.cuda_main_device if self.cuda_main_device is not None else 
self.model.device - return device - else: - return None - - def verify(self, sign_blocks: dict, keys: dict): - - return _verify_sign_blocks(sign_blocks, keys, self.cuda is not None, self._get_device()) - - def get_loss_from_pred(self, loss, pred, batch_label): - - if not loss and hasattr(pred, "loss"): - batch_loss = pred.loss - - elif loss is not None: - if batch_label is None: - raise ValueError( - "When loss is set, please provide label to calculate loss" - ) - if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): - pred = pred.logits - batch_loss = loss(pred, batch_label) - else: - raise ValueError( - 'Trainer requires a loss function, but got None, please specify loss function in the' - ' job configuration') - - return batch_loss - - def _get_keys(self, sign_blocks): - - if self._sign_keys is None: - self._sign_keys = get_keys(sign_blocks, self._sign_bits) - return self._sign_keys - - def _get_sign_blocks(self): - if self._sign_blocks is None: - sign_blocks = get_sign_blocks(self.model) - self._sign_blocks = sign_blocks - - return self._sign_blocks - - def train(self, train_set: Dataset, validate_set: Dataset = None, optimizer=None, loss=None, extra_dict={}): - - if 'keys' in extra_dict: - self._sign_keys = extra_dict['keys'] - self._sign_bits = extra_dict['num_bits'] - else: - LOGGER.info('computing feature based sign bits') - if self._client_num is None and self.party_id_list is not None: - self._client_num = len(self.party_id_list) - self._sign_bits = compute_sign_bit(self.model, self._client_num) - - LOGGER.info('client num {}, party id list {}'.format(self._client_num, self.party_id_list)) - LOGGER.info('will assign {} bits for feature based watermark'.format(self._sign_bits)) - return super().train(train_set, validate_set, optimizer, loss, extra_dict) - - def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): - - epoch_loss = 0.0 - batch_idx = 0 - acc_num = 0 - - sign_blocks = self._get_sign_blocks() - keys = self._get_keys(sign_blocks) - - dl, watermark_dl = self.data_loader['train'], self.data_loader['watermark'] - if isinstance(dl, DistributedSampler): - dl.sampler.set_epoch(epoch_idx) - if isinstance(watermark_dl, DistributedSampler): - watermark_dl.sampler.set_epoch(epoch_idx) - - if not self.fed_mode: - trainset_iterator = tqdm.tqdm(dl) - else: - trainset_iterator = dl - batch_label = None - - # collect watermark data and mix them into the training data - watermark_collect = [] - if watermark_dl is not None: - for watermark_batch in watermark_dl: - watermark_collect.append(watermark_batch) - - for _batch_iter in trainset_iterator: - - _batch_iter = self._decode(_batch_iter) - - if isinstance(_batch_iter, list) or isinstance(_batch_iter, tuple): - batch_data, batch_label = _batch_iter - else: - batch_data = _batch_iter - - if watermark_dl is not None: - # Mix the backdoor sample into the training data - wm_batch_idx = int(batch_idx % len(watermark_collect)) - wm_batch = watermark_collect[wm_batch_idx] - if isinstance(wm_batch, list): - wm_batch_data, wm_batch_label = wm_batch - batch_data = torch.cat([batch_data, wm_batch_data], dim=0) - batch_label = torch.cat([batch_label, wm_batch_label], dim=0) - else: - wm_batch_data = wm_batch - batch_data = torch.cat([batch_data, wm_batch_data], dim=0) - - if self.cuda is not None or self._enable_deepspeed: - device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device - batch_data = self.to_cuda(batch_data, device) - if batch_label is not None: - batch_label = 
self.to_cuda(batch_label, device) - - if not self._enable_deepspeed: - optimizer.zero_grad() - else: - model.zero_grad() - - pred = model(batch_data) - - sign_loss = 0 - # Get the sign loss of model - for name, block in sign_blocks.items(): - - block: SignatureBlock = block - W, signature = keys[name] - if self.cuda is not None: - device = self._get_device() - W = self.to_cuda(W, device) - signature = self.to_cuda(signature, device) - sign_loss += self.alpha * block.sign_loss(W, signature) - - batch_loss = self.get_loss_from_pred(loss_func, pred, batch_label) - batch_loss += sign_loss - - if not self._enable_deepspeed: - - batch_loss.backward() - optimizer.step() - batch_loss_np = np.array(batch_loss.detach().tolist()) if self.cuda is None \ - else np.array(batch_loss.cpu().detach().tolist()) - - if acc_num + self.batch_size > len(train_set): - batch_len = len(train_set) - acc_num - else: - batch_len = self.batch_size - - epoch_loss += batch_loss_np * batch_len - else: - batch_loss = model.backward(batch_loss) - batch_loss_np = np.array(batch_loss.cpu().detach().tolist()) - model.step() - batch_loss_np = self._sync_loss(batch_loss_np * self._get_batch_size(batch_data)) - if distributed_util.is_rank_0(): - epoch_loss += batch_loss_np - - batch_idx += 1 - - if self.fed_mode: - LOGGER.debug( - 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) - - epoch_loss = epoch_loss / len(train_set) - - # verify the sign of model during training - if epoch_idx % self.verify_freqs == 0: - # verify feature-based signature - sign_acc = self.verify(sign_blocks, keys) - LOGGER.info(f"epoch {epoch_idx} sign accuracy: {sign_acc}") - # verify backdoor signature - if self.watermark_set is not None: - _, pred, label = self._predict(self.watermark_set) - pred = pred.detach().cpu() - label = label.detach().cpu() - if self.backdoor_verify_method == 'accuracy': - if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): - pred = pred.logits - pred = pred.numpy().reshape((len(label), -1)) - label = label.numpy() - pred_label = np.argmax(pred, axis=1) - metric = accuracy_score(pred_label.flatten(), label.flatten()) - else: - metric = self.get_loss_from_pred(loss_func, pred, label) - - LOGGER.info(f"epoch {epoch_idx} backdoor {self.backdoor_verify_method}: {metric}") - - return epoch_loss - - def _predict(self, dataset: Dataset): - pred_result = [] - - # switch eval mode - dataset.eval() - model = self._select_model() - model.eval() - - if not dataset.has_sample_ids(): - dataset.init_sid_and_getfunc(prefix=dataset.get_type()) - - labels = [] - with torch.no_grad(): - for _batch_iter in DataLoader( - dataset, self.batch_size - ): - if isinstance(_batch_iter, list): - batch_data, batch_label = _batch_iter - else: - batch_label = _batch_iter.pop("labels") - batch_data = _batch_iter - - if self.cuda is not None or self._enable_deepspeed: - device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device - batch_data = self.to_cuda(batch_data, device) - - pred = model(batch_data) - - if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): - pred = pred.logits - - pred_result.append(pred) - labels.append(batch_label) - - ret_rs = torch.concat(pred_result, axis=0) - ret_label = torch.concat(labels, axis=0) - - # switch back to train mode - dataset.train() - model.train() - - return dataset.get_sample_ids(), ret_rs, ret_label - - def predict(self, dataset: Dataset): - - if self.task_type in [consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]: - LOGGER.warning(f"Not support prediction of 
task_types={[consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]}") - return - - if distributed_util.is_distributed() and not distributed_util.is_rank_0(): - return - - if isinstance(dataset, WaterMarkDataset): - normal_train_set = dataset.get_normal_dataset() - if normal_train_set is None: - raise ValueError('normal train set is None in FedIPR algo predict function') - else: - normal_train_set = normal_train_set - - ids, ret_rs, ret_label = self._predict(normal_train_set) - - if self.fed_mode: - return self.format_predict_result( - ids, ret_rs, ret_label, task_type=self.task_type) - else: - return ret_rs, ret_label - - def save( - self, - model=None, - epoch_idx=-1, - optimizer=None, - converge_status=False, - loss_history=None, - best_epoch=-1, - extra_data={}): - - extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} - super().save(model, epoch_idx, optimizer, converge_status, loss_history, best_epoch, extra_data) - - def local_save(self, - model=None, - epoch_idx=-1, - optimizer=None, - converge_status=False, - loss_history=None, - best_epoch=-1, - extra_data={}): - - extra_data = {'keys': self._sign_keys, 'num_bits': self._sign_bits} - super().local_save(model, epoch_idx, optimizer, converge_status, loss_history, best_epoch, extra_data) diff --git a/python/federatedml/nn/model_zoo/sign_block.py b/python/federatedml/nn/model_zoo/sign_block.py deleted file mode 100644 index 338afeefa7..0000000000 --- a/python/federatedml/nn/model_zoo/sign_block.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.init as init -from torch.nn import functional as F -from federatedml.util import LOGGER - -""" -Base -""" - - -class SignatureBlock(nn.Module): - - def __init__(self) -> None: - super().__init__() - - @property - def embeded_param(self): - return None - - def extract_sign(self, W): - pass - - def sign_loss(self, W, sign): - pass - - def embeded_param_num(self): - pass - - -def is_sign_block(block): - return issubclass(type(block), SignatureBlock) - - -class ConvBlock(nn.Module): - def __init__(self, i, o, ks=3, s=1, pd=1, relu=True): - super().__init__() - - self.conv = nn.Conv2d(i, o, ks, s, pd, bias= False) - - if relu: - self.relu = nn.ReLU(inplace=True) - else: - self.relu = None - - self.reset_parameters() - - def reset_parameters(self): - init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') - - def forward(self, x): - x = self.conv(x) - if self.relu is not None: - x = self.relu(x) - return x - - -def generate_signature(conv_block: SignatureBlock, num_bits): - - sign = torch.sign(torch.rand(num_bits) - 0.5) - W = torch.randn(len(conv_block.embeded_param.flatten()), num_bits) - - return (W, sign) - - -""" -Function & Class for Conv Layer -""" - - -class SignatureConv(SignatureBlock): - - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False): - super(SignatureConv, self).__init__() - - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) - self.weight = self.conv.weight - - self._embed_para_num = None - self.init_scale() - self.init_bias() - self.bn = nn.BatchNorm2d(out_channels, affine=False) - self.relu = nn.ReLU(inplace=True) - self.reset_parameters() - - def embeded_param_num(self): - return self._embed_para_num - - def init_bias(self): - self.bias = nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) - init.zeros_(self.bias) - - def init_scale(self): - self.scale = 
nn.Parameter(torch.Tensor(self.conv.out_channels).to(self.weight.device)) - init.ones_(self.scale) - self._embed_para_num = self.scale.shape[0] - - def reset_parameters(self): - init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu') - - @property - def embeded_param(self): - # embedded in the BatchNorm param, as the same in the paper - return self.scale - - def extract_sign(self, W): - # W is the linear weight for extracting signature - with torch.no_grad(): - return self.scale.view([1, -1]).mm(W).sign().flatten() - - def sign_loss(self, W, sign): - loss = F.relu(-self.scale.view([1, -1]).mm(W).mul(sign.view(-1))).sum() - return loss - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - x = x * self.scale[None, :, None, None] + self.bias[None, :, None, None] - x = self.relu(x) - return x - - -""" -Function & Class for LM -""" - - -def recursive_replace_layernorm(module, layer_name_set=None): - - """ - Recursively replaces the LayerNorm layers of a given module with SignatureLayerNorm layers. - - Parameters: - module (torch.nn.Module): The module in which LayerNorm layers should be replaced. - layer_name_set (set[str], optional): A set of layer names to be replaced. If None, - all LayerNorm layers in the module will be replaced. - """ - - for name, sub_module in module.named_children(): - if isinstance(sub_module, nn.LayerNorm): - if layer_name_set is not None and name not in layer_name_set: - continue - setattr(module, name, SignatureLayerNorm.from_layer_norm_layer(sub_module)) - LOGGER.debug(f"Replace {name} with SignatureLayerNorm") - recursive_replace_layernorm(sub_module, layer_name_set) - - -class SignatureLayerNorm(SignatureBlock): - - def __init__(self, normalized_shape=None, eps=1e-5, elementwise_affine=True, layer_norm_inst=None): - super(SignatureLayerNorm, self).__init__() - if layer_norm_inst is not None and isinstance(layer_norm_inst, nn.LayerNorm): - self.ln = layer_norm_inst - else: - self.ln = nn.LayerNorm(normalized_shape, eps, elementwise_affine) - - self._embed_param_num = self.ln.weight.numel() - - @property - def embeded_param(self): - return self.ln.weight - - def embeded_param_num(self): - return self._embed_param_num - - @staticmethod - def from_layer_norm_layer(layer_norm_layer: nn.LayerNorm): - return SignatureLayerNorm(layer_norm_inst=layer_norm_layer) - - def extract_sign(self, W): - # W is the linear weight for extracting signature - with torch.no_grad(): - return self.ln.weight.view([1, -1]).mm(W).sign().flatten() - - def sign_loss(self, W, sign): - loss = F.relu(-self.ln.weight.view([1, -1]).mm(W).mul(sign.view(-1))).sum() - return loss - - def forward(self, x): - return self.ln(x) - - -if __name__ == "__main__": - conv = SignatureConv(3, 384, 3, 1, 1) - layer_norm = SignatureLayerNorm((768, )) - layer_norm_2 = SignatureLayerNorm.from_layer_norm_layer(layer_norm.ln) - \ No newline at end of file From 33e8128eea133362203955b1f93287600633efb4 Mon Sep 17 00:00:00 2001 From: Chen Date: Tue, 22 Aug 2023 15:01:42 +0800 Subject: [PATCH 7/9] Update python/fate_client/pipeline/component/homo_nn.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Signed-off-by: Chen --- python/fate_client/pipeline/component/homo_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate_client/pipeline/component/homo_nn.py b/python/fate_client/pipeline/component/homo_nn.py index 281bec24d5..fa3ef87227 100644 --- a/python/fate_client/pipeline/component/homo_nn.py +++ 
b/python/fate_client/pipeline/component/homo_nn.py @@ -68,7 +68,7 @@ class HomoNN(FateComponent): optimizer, optimizer from fate_torch ds_config, config for deepspeed model, a fate torch sequential defining the model structure - server_init, whether to initialize the model, loss and optimizer on server, if configs are provided, they will be used. In + server_init, whether to initialize the model, loss and optimizer on server, if configs are provided, they will be used. In current version this option is specially designed for offsite-tuning """ From c795110886824b517bd1e21f36d42dc49e1ae67b Mon Sep 17 00:00:00 2001 From: Chen Date: Tue, 22 Aug 2023 15:01:51 +0800 Subject: [PATCH 8/9] Update python/fate_client/pipeline/component/homo_nn.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Signed-off-by: Chen --- python/fate_client/pipeline/component/homo_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/fate_client/pipeline/component/homo_nn.py b/python/fate_client/pipeline/component/homo_nn.py index fa3ef87227..701abfb918 100644 --- a/python/fate_client/pipeline/component/homo_nn.py +++ b/python/fate_client/pipeline/component/homo_nn.py @@ -86,7 +86,7 @@ def __init__(self, loss=None, optimizer: OptimizerType = None, ds_config: dict = None, - model: Sequential = None, + model: Sequential = None, server_init: bool = False, **kwargs): From 5bfc8979fb8c5c4a386cb984f4b693d137d72967 Mon Sep 17 00:00:00 2001 From: weijingchen Date: Tue, 22 Aug 2023 15:05:21 +0800 Subject: [PATCH 9/9] Update autopep8 Signed-off-by: weijingchen --- python/fate_client/pipeline/component/homo_nn.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/fate_client/pipeline/component/homo_nn.py b/python/fate_client/pipeline/component/homo_nn.py index 701abfb918..9f821b6566 100644 --- a/python/fate_client/pipeline/component/homo_nn.py +++ b/python/fate_client/pipeline/component/homo_nn.py @@ -100,8 +100,15 @@ def __init__(self, self.input = Input(self.name, data_type="multi") self.output = Output(self.name, data_type='single') self._module_name = "HomoNN" - self._updated = {'trainer': False, 'dataset': False, - 'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False, 'ds_config': False, 'server_init': False} + self._updated = { + 'trainer': False, + 'dataset': False, + 'torch_seed': False, + 'loss': False, + 'optimizer': False, + 'model': False, + 'ds_config': False, + 'server_init': False} self._set_param(kwargs["explict_parameters"]) self._check_parameters()
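For reviewers of this series, the sketch below illustrates the watermark-embedding rule that the SignatureBlock classes in these patches implement (generate_signature, extract_sign, sign_loss). It is a minimal stand-alone illustration, not FATE code: the parameter tensor, its shape, the bit count, and the name embeded_param are assumptions standing in for a block's embedded parameter (for example the BatchNorm scale of SignatureConv), and only torch is required to run it.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for a SignatureBlock's embeded_param (e.g. SignatureConv.scale); shape is assumed.
embeded_param = torch.nn.Parameter(torch.ones(384))
num_bits = 16

# As in generate_signature: a random projection matrix W and a random +/-1 signature vector.
signature = torch.sign(torch.rand(num_bits) - 0.5)
W = torch.randn(embeded_param.numel(), num_bits)


def extract_sign(param, W):
    # Project the flattened parameter with W and take the sign, as in SignatureBlock.extract_sign.
    with torch.no_grad():
        return param.view(1, -1).mm(W).sign().flatten()


def sign_loss(param, W, signature):
    # Hinge-style penalty pushing each projected coordinate toward its target sign,
    # mirroring SignatureConv.sign_loss / SignatureLayerNorm.sign_loss.
    return F.relu(-param.view(1, -1).mm(W).mul(signature.view(-1))).sum()


alpha = 0.01  # same default weighting as FedIPRTrainer.alpha
loss = alpha * sign_loss(embeded_param, W, signature)
loss.backward()  # during FedIPR training this term is added to the task loss before optimizer.step()

bit_acc = (extract_sign(embeded_param, W) == signature).float().mean().item()
print('signature bit accuracy:', bit_acc)

In the trainer, the verification step is the same comparison aggregated over all sign blocks: the fraction of extracted bits that match the stored signature is the sign accuracy reported each verify_freqs epochs.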