diff --git a/python/fate_client/pipeline/component/homo_nn.py b/python/fate_client/pipeline/component/homo_nn.py index c6b68adf6a..9f821b6566 100644 --- a/python/fate_client/pipeline/component/homo_nn.py +++ b/python/fate_client/pipeline/component/homo_nn.py @@ -44,7 +44,8 @@ 'loss': None, 'optimizer': None, 'nn_define': None, - 'ds_config': None + 'ds_config': None, + 'server_init': False } except Exception as e: print(e) @@ -65,7 +66,10 @@ class HomoNN(FateComponent): torch_seed, global random seed loss, loss function from fate_torch optimizer, optimizer from fate_torch + ds_config, config for deepspeed model, a fate torch sequential defining the model structure + server_init, whether to initialize the model, loss and optimizer on server, if configs are provided, they will be used. In + current version this option is specially designed for offsite-tuning """ @extract_explicit_parameter @@ -82,7 +86,9 @@ def __init__(self, loss=None, optimizer: OptimizerType = None, ds_config: dict = None, - model: Sequential = None, **kwargs): + model: Sequential = None, + server_init: bool = False, + **kwargs): explicit_parameters = copy.deepcopy(DEFAULT_PARAM_DICT) if 'name' not in kwargs["explict_parameters"]: @@ -94,8 +100,15 @@ def __init__(self, self.input = Input(self.name, data_type="multi") self.output = Output(self.name, data_type='single') self._module_name = "HomoNN" - self._updated = {'trainer': False, 'dataset': False, - 'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False} + self._updated = { + 'trainer': False, + 'dataset': False, + 'torch_seed': False, + 'loss': False, + 'optimizer': False, + 'model': False, + 'ds_config': False, + 'server_init': False} self._set_param(kwargs["explict_parameters"]) self._check_parameters() diff --git a/python/fate_client/pipeline/component/nn/backend/torch/cust.py b/python/fate_client/pipeline/component/nn/backend/torch/cust.py index 4eba0c54c6..de60736081 100644 --- a/python/fate_client/pipeline/component/nn/backend/torch/cust.py +++ b/python/fate_client/pipeline/component/nn/backend/torch/cust.py @@ -3,9 +3,12 @@ from pipeline.component.nn.backend.torch.base import FateTorchLayer, FateTorchLoss import difflib +ML_PATH = 'federatedml.nn' +LLM_PATH = "fate_llm" -MODEL_PATH = None -LOSS_PATH = None +LLM_MODEL_PATH = '{}.model_zoo'.format(LLM_PATH) +MODEL_PATH = '{}.model_zoo'.format(ML_PATH) +LOSS_PATH = '{}.loss'.format(ML_PATH) def str_simi(str_a, str_b): @@ -45,9 +48,14 @@ class CustModel(FateTorchLayer, nn.Module): def __init__(self, module_name, class_name, **kwargs): super(CustModel, self).__init__() - assert isinstance(module_name, str), 'name must be a str, specify the module in the model_zoo' - assert isinstance(class_name, str), 'class name must be a str, specify the class in the module' - self.param_dict = {'module_name': module_name, 'class_name': class_name, 'param': kwargs} + assert isinstance( + module_name, str), 'name must be a str, specify the module in the model_zoo' + assert isinstance( + class_name, str), 'class name must be a str, specify the class in the module' + self.param_dict = { + 'module_name': module_name, + 'class_name': class_name, + 'param': kwargs} self._model = None def init_model(self): @@ -62,11 +70,18 @@ def forward(self, x): def get_pytorch_model(self, module_path=None): if module_path is None: - return get_class( - self.param_dict['module_name'], - self.param_dict['class_name'], - self.param_dict['param'], - MODEL_PATH) + try: + return get_class( + self.param_dict['module_name'], + 
self.param_dict['class_name'], + self.param_dict['param'], + MODEL_PATH) + except BaseException: + return get_class( + self.param_dict['module_name'], + self.param_dict['class_name'], + self.param_dict['param'], + LLM_MODEL_PATH) else: return get_class( self.param_dict['module_name'], diff --git a/python/federatedml/framework/homo/aggregator/aggregator_base.py b/python/federatedml/framework/homo/aggregator/aggregator_base.py index d00f8d52f0..d46e96f195 100644 --- a/python/federatedml/framework/homo/aggregator/aggregator_base.py +++ b/python/federatedml/framework/homo/aggregator/aggregator_base.py @@ -1,4 +1,5 @@ from federatedml.framework.homo.blocks import ServerCommunicator, ClientCommunicator +from federatedml.util import consts class AutoSuffix(object): @@ -19,7 +20,10 @@ def __call__(self): class AggregatorBaseClient(object): - def __init__(self, communicate_match_suffix: str = None): + def __init__( + self, communicate_match_suffix: str = None, server=( + consts.ARBITER,), clients=( + consts.GUEST, consts.HOST)): """Base class of client aggregator Parameters @@ -28,7 +32,7 @@ def __init__(self, communicate_match_suffix: str = None): To make sure that client and server can communicate correctly, the server-side and client-side aggregators need to have the same suffix """ - self.communicator = ClientCommunicator(prefix=communicate_match_suffix) + self.communicator = ClientCommunicator(prefix=communicate_match_suffix, server=server, clients=clients) self.suffix = {} def _get_suffix(self, var_name, user_suffix=tuple()): @@ -52,7 +56,7 @@ def get(self, suffix): class AggregatorBaseServer(object): - def __init__(self, communicate_match_suffix=None): + def __init__(self, communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): """Base class of server aggregator Parameters @@ -61,7 +65,7 @@ def __init__(self, communicate_match_suffix=None): To make sure that client and server can communicate correctly, the server-side and client-side aggregators need to have the same suffix """ - self.communicator = ServerCommunicator(prefix=communicate_match_suffix) + self.communicator = ServerCommunicator(prefix=communicate_match_suffix, server=server, clients=clients) self.suffix = {} def _get_suffix(self, var_name, user_suffix=tuple()): diff --git a/python/federatedml/framework/homo/aggregator/secure_aggregator.py b/python/federatedml/framework/homo/aggregator/secure_aggregator.py index ce6a7ec545..4ed8303989 100644 --- a/python/federatedml/framework/homo/aggregator/secure_aggregator.py +++ b/python/federatedml/framework/homo/aggregator/secure_aggregator.py @@ -15,10 +15,10 @@ class SecureAggregatorClient(AggregatorBaseClient): def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggregate_weight=1.0, - communicate_match_suffix=None): + communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): super(SecureAggregatorClient, self).__init__( - communicate_match_suffix=communicate_match_suffix) + communicate_match_suffix=communicate_match_suffix, clients=clients, server=server) self.secure_aggregate = secure_aggregate self.suffix = { "local_loss": AutoSuffix("local_loss"), @@ -31,7 +31,10 @@ def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggreg # init secure aggregate random padding: if self.secure_aggregate: self._random_padding_cipher: PadsCipher = RandomPaddingCipherClient( - trans_var=RandomPaddingCipherTransVar(prefix=communicate_match_suffix)).create_cipher() + 
trans_var=RandomPaddingCipherTransVar( + prefix=communicate_match_suffix, + clients=clients, + server=server)).create_cipher() LOGGER.info('initialize secure aggregator done') # compute weight @@ -186,9 +189,18 @@ def loss_aggregation(self, loss, suffix=tuple()): class SecureAggregatorServer(AggregatorBaseServer): - def __init__(self, secure_aggregate=True, communicate_match_suffix=None): + def __init__( + self, + secure_aggregate=True, + communicate_match_suffix=None, + server=( + consts.ARBITER, + ), + clients=( + consts.GUEST, + consts.HOST)): super(SecureAggregatorServer, self).__init__( - communicate_match_suffix=communicate_match_suffix) + communicate_match_suffix=communicate_match_suffix, clients=clients, server=server) self.suffix = { "local_loss": AutoSuffix("local_loss"), "agg_loss": AutoSuffix("agg_loss"), @@ -199,7 +211,7 @@ def __init__(self, secure_aggregate=True, communicate_match_suffix=None): self.secure_aggregate = secure_aggregate if self.secure_aggregate: RandomPaddingCipherServer(trans_var=RandomPaddingCipherTransVar( - prefix=communicate_match_suffix)).exchange_secret_keys() + prefix=communicate_match_suffix, clients=clients, server=server)).exchange_secret_keys() LOGGER.info('initialize secure aggregator done') agg_weights = self.collect(suffix=('agg_weight', )) diff --git a/python/federatedml/framework/homo/blocks.py b/python/federatedml/framework/homo/blocks.py index 7d2d730e4c..c894172819 100644 --- a/python/federatedml/framework/homo/blocks.py +++ b/python/federatedml/framework/homo/blocks.py @@ -60,8 +60,8 @@ def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST) class ServerCommunicator(object): - def __init__(self, prefix=None): - self.trans_var = CommunicatorTransVar(prefix=prefix) + def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + self.trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients) self._client_parties = self.trans_var.client_parties def get_parties(self, party_idx): @@ -85,8 +85,8 @@ def broadcast_obj(self, obj, suffix=tuple(), party_idx=-1): class ClientCommunicator(object): - def __init__(self, prefix=None): - trans_var = CommunicatorTransVar(prefix=prefix) + def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)): + trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients) self.trans_var = trans_var self._server_parties = trans_var.server_parties diff --git a/python/federatedml/nn/backend/torch/cust_model.py b/python/federatedml/nn/backend/torch/cust_model.py deleted file mode 100644 index e9fff34839..0000000000 --- a/python/federatedml/nn/backend/torch/cust_model.py +++ /dev/null @@ -1,55 +0,0 @@ -import importlib - -from torch import nn - -from federatedml.nn.backend.torch.base import FateTorchLayer -from federatedml.nn.backend.utils.common import ML_PATH - -PATH = '{}.model_zoo'.format(ML_PATH) - - -class CustModel(FateTorchLayer, nn.Module): - - def __init__(self, module_name, class_name, **kwargs): - super(CustModel, self).__init__() - assert isinstance( - module_name, str), 'name must be a str, specify the module in the model_zoo' - assert isinstance( - class_name, str), 'class name must be a str, specify the class in the module' - self.param_dict = { - 'module_name': module_name, - 'class_name': class_name, - 'param': kwargs} - self._model = None - - def init_model(self): - if self._model is None: - self._model = self.get_pytorch_model() - - def forward(self, x): - if self._model 
is None: - raise ValueError('model not init, call init_model() function') - return self._model(x) - - def get_pytorch_model(self): - - module_name: str = self.param_dict['module_name'] - class_name = self.param_dict['class_name'] - module_param: dict = self.param_dict['param'] - if module_name.endswith('.py'): - module_name = module_name.replace('.py', '') - nn_modules = importlib.import_module('{}.{}'.format(PATH, module_name)) - try: - for k, v in nn_modules.__dict__.items(): - if isinstance(v, type): - if issubclass( - v, nn.Module) and v is not nn.Module and v.__name__ == class_name: - return v(**module_param) - raise ValueError( - 'Did not find any class in {}.py that is pytorch nn.Module and named {}'. format( - module_name, class_name)) - except ValueError as e: - raise e - - def __repr__(self): - return 'CustModel({})'.format(str(self.param_dict)) diff --git a/python/federatedml/nn/homo/_init.py b/python/federatedml/nn/homo/_init.py new file mode 100644 index 0000000000..aa3fc96bc7 --- /dev/null +++ b/python/federatedml/nn/homo/_init.py @@ -0,0 +1,149 @@ +import json +import torch +import inspect +from federatedml.nn.homo.trainer.trainer_base import get_trainer_class, TrainerBase +from federatedml.util import LOGGER +from federatedml.nn.backend.torch import serialization as s +from federatedml.nn.backend.torch.base import FateTorchOptimizer +from federatedml.nn.backend.utils.common import recover_model_bytes +from federatedml.nn.backend.utils import deepspeed_util + + +def init( + trainer, + trainer_param, + nn_define, + config_optimizer, + config_loss, + torch_seed, + model_loaded_flag, + loaded_model, + ds_config): + + warm_start_iter = None + + if ds_config: + deepspeed_util.init_deepspeed_env(ds_config) + + # load trainer class + if trainer is None: + raise ValueError( + 'Trainer is not specified, please specify your trainer') + + trainer_class = get_trainer_class(trainer) + LOGGER.info('trainer class is {}'.format(trainer_class)) + + # recover model from model config / or recover from saved model param + loaded_model_dict = None + + # if has model protobuf, load model config from protobuf + load_opt_state_dict = False + + if model_loaded_flag: + param, meta = get_homo_param_meta(loaded_model) + LOGGER.info('save path is {}'.format(param.local_save_path)) + if param.local_save_path == '': + LOGGER.info('Load model from model protobuf') + warm_start_iter = param.epoch_idx + if param is None or meta is None: + raise ValueError( + 'model protobuf is None, make sure' + 'that your trainer calls export_model() function to save models') + + if meta.nn_define[0] is None: + raise ValueError( + 'nn_define is None, model protobuf has no nn-define, make sure' + 'that your trainer calls export_model() function to save models') + + nn_define = json.loads(meta.nn_define[0]) + loss = json.loads(meta.loss_func_define[0]) + optimizer = json.loads(meta.optimizer_define[0]) + loaded_model_dict = recover_model_bytes(param.model_bytes) + extra_data = recover_model_bytes(param.extra_data_bytes) + + else: + LOGGER.info('Load model from local save path') + save_dict = torch.load(open(param.local_save_path, 'rb')) + warm_start_iter = save_dict['epoch_idx'] + nn_define = save_dict['model_define'] + loss = save_dict['loss_define'] + optimizer = save_dict['optimizer_define'] + loaded_model_dict = save_dict + extra_data = save_dict['extra_data'] + + if config_optimizer is not None and optimizer != config_optimizer: + LOGGER.info('optimizer updated') + else: + config_optimizer = optimizer + 
load_opt_state_dict = True + + if config_loss is not None and config_loss != loss: + LOGGER.info('loss updated') + else: + config_loss = loss + else: + extra_data = {} + + # check key param + if nn_define is None: + raise ValueError( + 'Model structure is not defined, nn_define is None, please check your param') + + # get model from nn define + model = s.recover_sequential_from_dict(nn_define) + if loaded_model_dict: + model.load_state_dict(loaded_model_dict['model']) + LOGGER.info('load model state dict from check point') + + LOGGER.info('model structure is {}'.format(model)) + # init optimizer + if config_optimizer is not None and not ds_config: + optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict( + config_optimizer) + # pass model parameters to optimizer + optimizer = optimizer_.to_torch_instance(model.parameters()) + if load_opt_state_dict: + LOGGER.info('load optimizer state dict') + optimizer.load_state_dict(loaded_model_dict['optimizer']) + LOGGER.info('optimizer is {}'.format(optimizer)) + else: + optimizer = None + LOGGER.info('optimizer is not specified') + + # init loss + if config_loss is not None: + loss_fn = s.recover_loss_fn_from_dict(config_loss) + LOGGER.info('loss function is {}'.format(loss_fn)) + else: + loss_fn = None + LOGGER.info('loss function is not specified') + + # init trainer + trainer_inst: TrainerBase = trainer_class(**trainer_param) + LOGGER.info('trainer class is {}'.format(trainer_class)) + + trainer_train_args = inspect.getfullargspec(trainer_inst.train).args + args_format = [ + 'self', + 'train_set', + 'validate_set', + 'optimizer', + 'loss', + 'extra_data' + ] + if len(trainer_train_args) < 6: + raise ValueError( + 'Train function of trainer should take 6 arguments :{}, but current trainer.train ' + 'only takes {} arguments: {}'.format( + args_format, len(trainer_train_args), trainer_train_args)) + + trainer_inst.set_nn_config(nn_define, config_optimizer, config_loss) + trainer_inst.fed_mode = True + + if ds_config: + model, optimizer = deepspeed_util.deepspeed_init(model, ds_config) + trainer_inst.enable_deepspeed(is_zero_3=deepspeed_util.is_zero3(ds_config)) + if deepspeed_util.is_zero3(ds_config): + model.train() + + return trainer_inst, model, optimizer, loss_fn, extra_data, config_optimizer, config_loss, warm_start_iter diff --git a/python/federatedml/nn/homo/client.py b/python/federatedml/nn/homo/client.py index f705d145d6..c2d7d83fcb 100644 --- a/python/federatedml/nn/homo/client.py +++ b/python/federatedml/nn/homo/client.py @@ -23,6 +23,7 @@ from federatedml.nn.backend.utils.data import add_match_id from federatedml.protobuf.generated.homo_nn_model_param_pb2 import HomoNNParam as HomoNNParamPB from federatedml.protobuf.generated.homo_nn_model_meta_pb2 import HomoNNMeta as HomoNNMetaPB +from federatedml.nn.homo._init import init class NNModelExporter(ExporterBase): @@ -85,6 +86,12 @@ def export_model_dict( return get_homo_model_dict(param, meta) +def default_client_post_process(trainer): + model = trainer.get_cached_model() + summary = trainer.get_summary() + return model, summary + + class HomoNNClient(ModelBase): def __init__(self): @@ -136,136 +143,6 @@ def _init_model(self, param: HomoNNParam): self.optimizer = param.optimizer self.ds_config = param.ds_config - def init(self): - - # set random seed - global_seed(self.torch_seed) - - if self.ds_config: - deepspeed_util.init_deepspeed_env(self.ds_config) - - # load trainer class - if self.trainer is None: - raise ValueError( - 'Trainer is not specified, please specify your 
trainer') - - trainer_class = get_trainer_class(self.trainer) - LOGGER.info('trainer class is {}'.format(trainer_class)) - - # recover model from model config / or recover from saved model param - loaded_model_dict = None - - # if has model protobuf, load model config from protobuf - load_opt_state_dict = False - if self.model_loaded: - - param, meta = get_homo_param_meta(self.model) - LOGGER.info('save path is {}'.format(param.local_save_path)) - if param.local_save_path == '': - LOGGER.info('Load model from model protobuf') - self.warm_start_iter = param.epoch_idx - if param is None or meta is None: - raise ValueError( - 'model protobuf is None, make sure' - 'that your trainer calls export_model() function to save models') - - if meta.nn_define[0] is None: - raise ValueError( - 'nn_define is None, model protobuf has no nn-define, make sure' - 'that your trainer calls export_model() function to save models') - - self.nn_define = json.loads(meta.nn_define[0]) - loss = json.loads(meta.loss_func_define[0]) - optimizer = json.loads(meta.optimizer_define[0]) - loaded_model_dict = recover_model_bytes(param.model_bytes) - extra_data = recover_model_bytes(param.extra_data_bytes) - else: - LOGGER.info('Load model from local save path') - save_dict = torch.load(open(param.local_save_path, 'rb')) - self.warm_start_iter = save_dict['epoch_idx'] - self.nn_define = save_dict['model_define'] - loss = save_dict['loss_define'] - optimizer = save_dict['optimizer_define'] - loaded_model_dict = save_dict - extra_data = save_dict['extra_data'] - - if self.optimizer is not None and optimizer != self.optimizer: - LOGGER.info('optimizer updated') - else: - self.optimizer = optimizer - load_opt_state_dict = True - - if self.loss is not None and self.loss != loss: - LOGGER.info('loss updated') - else: - self.loss = loss - else: - extra_data = {} - - # check key param - if self.nn_define is None: - raise ValueError( - 'Model structure is not defined, nn_define is None, please check your param') - - # get model from nn define - model = s.recover_sequential_from_dict(self.nn_define) - if loaded_model_dict: - model.load_state_dict(loaded_model_dict['model']) - LOGGER.info('load model state dict from check point') - - LOGGER.info('model structure is {}'.format(model)) - # init optimizer - if self.optimizer is not None and not self.ds_config: - optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict( - self.optimizer) - # pass model parameters to optimizer - optimizer = optimizer_.to_torch_instance(model.parameters()) - if load_opt_state_dict: - LOGGER.info('load optimizer state dict') - optimizer.load_state_dict(loaded_model_dict['optimizer']) - LOGGER.info('optimizer is {}'.format(optimizer)) - else: - optimizer = None - LOGGER.info('optimizer is not specified') - - # init loss - if self.loss is not None: - loss_fn = s.recover_loss_fn_from_dict(self.loss) - LOGGER.info('loss function is {}'.format(loss_fn)) - else: - loss_fn = None - LOGGER.info('loss function is not specified') - - # init trainer - trainer_inst: TrainerBase = trainer_class(**self.trainer_param) - LOGGER.info('trainer class is {}'.format(trainer_class)) - - trainer_train_args = inspect.getfullargspec(trainer_inst.train).args - args_format = [ - 'self', - 'train_set', - 'validate_set', - 'optimizer', - 'loss', - 'extra_data' - ] - if len(trainer_train_args) < 6: - raise ValueError( - 'Train function of trainer should take 6 arguments :{}, but current trainer.train ' - 'only takes {} arguments: {}'.format( - args_format, len(trainer_train_args), 
trainer_train_args)) - - trainer_inst.set_nn_config(self.nn_define, self.optimizer, self.loss) - trainer_inst.fed_mode = True - - if self.ds_config: - model, optimizer = deepspeed_util.deepspeed_init(model, self.ds_config) - trainer_inst.enable_deepspeed(is_zero_3=deepspeed_util.is_zero3(self.ds_config)) - if deepspeed_util.is_zero3(self.ds_config): - model.train() - - return trainer_inst, model, optimizer, loss_fn, extra_data - def fit(self, train_input, validate_input=None): LOGGER.debug('train input is {}'.format(train_input)) @@ -294,10 +171,22 @@ def fit(self, train_input, validate_input=None): # set random seed global_seed(self.torch_seed) - self.trainer_inst, model, optimizer, loss_fn, extra_data = self.init() + # init + self.trainer_inst, model, optimizer, loss_fn, extra_data, self.optimizer, self.loss, self.warm_start_iter = init( + trainer=self.trainer, trainer_param=self.trainer_param, nn_define=self.nn_define, + config_optimizer=self.optimizer, config_loss=self.loss, torch_seed=self.torch_seed, model_loaded_flag=self.model_loaded, + loaded_model=self.model, ds_config=self.ds_config + ) + + # prepare to train self.trainer_inst.set_model(model) self.trainer_inst.set_tracker(self.tracker) self.trainer_inst.set_model_exporter(self.exporter) + party_id_list = [self.component_properties.guest_partyid] + if self.component_properties.host_party_idlist is not None: + for i in self.component_properties.host_party_idlist: + party_id_list.append(i) + self.trainer_inst.set_party_id_list(party_id_list) # load dataset class dataset_inst = load_dataset( @@ -339,8 +228,8 @@ def fit(self, train_input, validate_input=None): ) # training is done, get exported model - self.model = self.trainer_inst.get_cached_model() - self.set_summary(self.trainer_inst.get_summary()) + self.model, summary = default_client_post_process(self.trainer_inst) + self.set_summary(summary) def predict(self, cpn_input): diff --git a/python/federatedml/nn/homo/server.py b/python/federatedml/nn/homo/server.py index 106ab0757f..27fb226e32 100644 --- a/python/federatedml/nn/homo/server.py +++ b/python/federatedml/nn/homo/server.py @@ -6,6 +6,9 @@ from federatedml.nn.homo.client import NNModelExporter from federatedml.callbacks.model_checkpoint import ModelCheckpoint from federatedml.nn.backend.utils.common import get_homo_param_meta, recover_model_bytes +from federatedml.nn.homo._init import init +from federatedml.util import consts +from federatedml.nn.backend.utils.common import global_seed class HomoNNServer(ModelBase): @@ -13,7 +16,7 @@ class HomoNNServer(ModelBase): def __init__(self): super(HomoNNServer, self).__init__() self.model_param = HomoNNParam() - self.trainer = None + self.trainer = consts.FEDAVG_TRAINER self.trainer_param = None # arbiter side models @@ -24,13 +27,25 @@ def __init__(self): self.exporter = NNModelExporter() self.extra_data = {} # warm start + self.model_loaded = False self.warm_start_iter = None + # server init: if arbiter need to load model, loss, optimizer from config + self.server_init = False + + self.dataset_module = None + self.dataset = None + self.dataset_param = {} + self.torch_seed = None + self.loss = None + self.optimizer = None + self.nn_define = None + self.ds_config = None def export_model(self): if self.model is None: LOGGER.debug('export an empty model') - return self.exporter.export_model_dict() # return an exporter + return self.exporter.export_model_dict() # return an empyty model return self.model @@ -43,13 +58,24 @@ def load_model(self, model_dict): # load extra data 
self.extra_data = recover_model_bytes(param.extra_data_bytes) self.warm_start_iter = param.epoch_idx + self.model_loaded = True def _init_model(self, param: HomoNNParam()): + train_param = param.trainer.to_dict() + dataset_param = param.dataset.to_dict() self.trainer = train_param['trainer_name'] + self.dataset = dataset_param['dataset_name'] self.trainer_param = train_param['param'] + self.torch_seed = param.torch_seed + self.nn_define = param.nn_define + self.loss = param.loss + self.optimizer = param.optimizer + self.ds_config = param.ds_config + LOGGER.debug('trainer and trainer param {} {}'.format( self.trainer, self.trainer_param)) + self.server_init = param.server_init def fit(self, data_instance=None, validate_data=None): @@ -63,17 +89,37 @@ def fit(self, data_instance=None, validate_data=None): if self.component_properties.is_warm_start: self.callback_warm_start_init_iter(self.warm_start_iter) - # initialize trainer - trainer_class = get_trainer_class(self.trainer) - LOGGER.info('trainer class is {}'.format(trainer_class)) - # init trainer - trainer_inst = trainer_class(**self.trainer_param) + if self.server_init: + LOGGER.info('server try to load model, loss, optimizer from config') + # init + global_seed(self.torch_seed) + + trainer_inst, model, optimizer, loss_fn, extra_data, optimizer, loss, self.warm_start_iter = init( + trainer=self.trainer, trainer_param=self.trainer_param, nn_define=self.nn_define, + config_optimizer=self.optimizer, config_loss=self.loss, torch_seed=self.torch_seed, model_loaded_flag=self.model_loaded, + loaded_model=self.model, ds_config=self.ds_config + ) + trainer_inst.set_model(model) + + else: + # initialize trainer only + trainer_class = get_trainer_class(self.trainer) + trainer_inst = trainer_class(**self.trainer_param) + LOGGER.info('trainer class is {}'.format(trainer_class)) + # set tracker for fateboard callback trainer_inst.set_tracker(self.tracker) # set exporter trainer_inst.set_model_exporter(self.exporter) + # set party info + party_id_list = [self.component_properties.guest_partyid] + if self.component_properties.host_party_idlist is not None: + for i in self.component_properties.host_party_idlist: + party_id_list.append(i) + trainer_inst.set_party_id_list(party_id_list) # set chceckpoint trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1)) + # run trainer server procedure trainer_inst.server_aggregate_procedure(self.extra_data) diff --git a/python/federatedml/nn/homo/trainer/fedavg_trainer.py b/python/federatedml/nn/homo/trainer/fedavg_trainer.py index 9b2597f17f..dc74ba3fb6 100644 --- a/python/federatedml/nn/homo/trainer/fedavg_trainer.py +++ b/python/federatedml/nn/homo/trainer/fedavg_trainer.py @@ -150,6 +150,11 @@ def __init__(self, epochs=10, batch_size=512, # training parameter self.check_trainer_param( [self.tol], ['tol'], self.is_float, '{} is not a float') + # federation + self.client_agg = None + self.server_agg = None + self.aggregate_round = None + def _init_aggregator(self, train_set): # compute round to aggregate cur_agg_round = 0 @@ -191,7 +196,7 @@ def _select_model(self): else: return self.model - def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): + def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss_func): epoch_loss = 0.0 batch_idx = 0 @@ -202,6 +207,8 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): dl = self.data_loader + total_batch_len = len(dl) + if not self.fed_mode: to_iterate = tqdm.tqdm(dl) else: @@ -210,19 +217,10 @@ def 
train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): batch_label = None for _batch_iter in to_iterate: _batch_iter = self._decode(_batch_iter) - if isinstance(_batch_iter, list): + if isinstance(_batch_iter, list) or isinstance(_batch_iter, tuple): batch_data, batch_label = _batch_iter else: batch_data = _batch_iter - """ - if self.task_type in [consts.CAUSAL_LM, consts.SEQ_2_SEQ_LM]: - batch_data = _batch_iter - else: - batch_data, batch_label = _batch_iter - - batch_data = self._decode(batch_data) - batch_label = self._decode(batch_label) - """ if self.cuda is not None or self._enable_deepspeed: device = self.cuda_main_device if self.cuda_main_device is not None else self.model.device @@ -237,17 +235,17 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): pred = model(batch_data) - if not loss and hasattr(pred, "loss"): + if not loss_func and hasattr(pred, "loss"): batch_loss = pred.loss - elif loss is not None: + elif loss_func is not None: if batch_label is None: raise ValueError( "When loss is set, please provide label to calculate loss" ) if not isinstance(pred, torch.Tensor) and hasattr(pred, "logits"): pred = pred.logits - batch_loss = loss(pred, batch_label) + batch_loss = loss_func(pred, batch_label) else: raise ValueError( 'FedAVGTrainer requires a loss function, but got None, please specify loss function in the' @@ -276,13 +274,87 @@ def train_an_epoch(self, epoch_idx, model, train_set, optimizer, loss): batch_idx += 1 # LOGGER.info(f"finish epoch={epoch_idx}, batch={batch_idx}") - if self.fed_mode: - LOGGER.debug( - 'epoch {} batch {} finished'.format(epoch_idx, batch_idx)) + if self.fed_mode: + if batch_idx % (total_batch_len // 100) == 0: + percentage = (batch_idx / total_batch_len) * 100 + LOGGER.debug(f"Training progress of epoch {epoch_idx}: {percentage:.1f}%") epoch_loss = epoch_loss / len(train_set) return epoch_loss + def on_loop_begin_client(self, **kwargs): + pass + + def on_loop_end_client(self, **kwargs): + pass + + def on_loop_begin_server(self, **kwargs): + pass + + def on_loop_end_server(self, **kwargs): + pass + + def _client_sends_data(self, epoch_idx, epoch_loss, cur_agg_round): + need_stop = False + if self.client_agg is not None or distributed_util.is_distributed(): + if not (self.aggregate_every_n_epoch is not None and (epoch_idx + 1) % self.aggregate_every_n_epoch != 0): + + # model averaging, only aggregate trainable param + if self._deepspeed_zero_3: + deepspeed_util.gather_model(self.model) + + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + self.model = self.client_agg.model_aggregation(self.model) + if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: + self._share_model() + else: + self._share_model() + + # agg loss and get converge status + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + converge_status = self.client_agg.loss_aggregation(epoch_loss) + cur_agg_round += 1 + if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: + self._sync_converge_status(converge_status) + else: + converge_status = self._sync_converge_status() + + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): + LOGGER.info( + 'model averaging finished, aggregate round {}/{}'.format( + cur_agg_round, self.aggregate_round)) + + if converge_status: + LOGGER.info('early stop triggered, stop training') + need_stop = True + + return need_stop + + def _server_aggregates_data(self, epoch_idx, check_converge, 
converge_func): + + need_stop = False + if not (self.aggregate_every_n_epoch is not None and (epoch_idx + 1) % self.aggregate_every_n_epoch != 0): + + # model aggregate + self.server_agg.model_aggregation() + + # loss aggregate + agg_loss, converge_status = self.server_agg.loss_aggregation( + check_converge=check_converge, converge_func=converge_func) + self.callback_loss(agg_loss, epoch_idx) + + # save check point process + if self.save_freq is not None and ((epoch_idx + 1) % self.save_freq == 0): + self.checkpoint(epoch_idx=epoch_idx) + LOGGER.info('save checkpoint : epoch {}'.format(epoch_idx)) + + # check stop condition + if converge_status: + LOGGER.debug('stop triggered, stop aggregation') + need_stop = True + + return need_stop + def train( self, train_set: Dataset, @@ -293,7 +365,7 @@ def train( if optimizer is None: raise ValueError( - 'FedAVGTrainer requires an optimizer, but got None, please specify optimizer in the ' + 'An optimizer is required, but got None, please specify optimizer in the ' 'job configuration') if self.batch_size > len(train_set) or self.batch_size == -1: @@ -301,7 +373,7 @@ def train( # compute round to aggregate cur_agg_round = 0 - client_agg, aggregate_round = self._init_aggregator(train_set) + self.client_agg, self.aggregate_round = self._init_aggregator(train_set) # running var cur_epoch = 0 @@ -309,7 +381,10 @@ def train( need_stop = False evaluation_summary = {} - self._get_train_data_loader(train_set) + self.data_loader = self._get_train_data_loader(train_set) + + self.on_loop_begin_client() + # training process for i in range(self.epochs): @@ -323,37 +398,8 @@ def train( LOGGER.info('epoch loss is {}'.format(epoch_loss)) # federation process, if running local mode, cancel federation - if client_agg is not None or distributed_util.is_distributed(): - if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0): - - # model averaging, only aggregate trainable param - if self._deepspeed_zero_3: - deepspeed_util.gather_model(self.model) - - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - self.model = client_agg.model_aggregation(self.model) - if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: - self._share_model() - else: - self._share_model() - - # agg loss and get converge status - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - converge_status = client_agg.loss_aggregation(epoch_loss) - cur_agg_round += 1 - if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1: - self._sync_converge_status(converge_status) - else: - converge_status = self._sync_converge_status() - - if not distributed_util.is_distributed() or distributed_util.is_rank_0(): - LOGGER.info( - 'model averaging finished, aggregate round {}/{}'.format( - cur_agg_round, aggregate_round)) - - if converge_status: - LOGGER.info('early stop triggered, stop training') - need_stop = True + need_stop = self._client_sends_data(i, epoch_loss, cur_agg_round) + cur_agg_round += 1 # validation process if self.validation_freq and ((i + 1) % self.validation_freq == 0): @@ -400,6 +446,8 @@ def train( if self._deepspeed_zero_3: deepspeed_util.gather_model(self.model) + self.on_loop_end_client() + if not distributed_util.is_distributed() or distributed_util.is_rank_0(): best_epoch = int(np.array(loss_history).argmin()) @@ -490,31 +538,23 @@ def server_aggregate_procedure(self, extra_data={}): 'check early stop, converge func is {}'.format(converge_func)) 
LOGGER.info('server running aggregate procedure') - server_agg = SecureAggServer(self.secure_aggregate, communicate_match_suffix=self.comm_suffix) + self.server_agg = SecureAggServer(self.secure_aggregate, communicate_match_suffix=self.comm_suffix) + self.on_loop_begin_server() # aggregate and broadcast models for i in range(self.epochs): - if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0): - - # model aggregate - server_agg.model_aggregation() - - # loss aggregate - agg_loss, converge_status = server_agg.loss_aggregation( - check_converge=check_converge, converge_func=converge_func) - self.callback_loss(agg_loss, i) - - # save check point process - if self.save_freq is not None and ((i + 1) % self.save_freq == 0): - self.checkpoint(epoch_idx=i) - LOGGER.info('save checkpoint : epoch {}'.format(i)) - # check stop condition - if converge_status: - LOGGER.debug('stop triggered, stop aggregation') - break + need_stop = self._server_aggregates_data(i, check_converge, converge_func) + if need_stop: + break - LOGGER.info('server aggregation process done') + self.on_loop_end_server() + if self.model is not None: + if self.save_to_local_dir: + self.local_save(model=self.model, epoch_idx=i, converge_status=need_stop) + else: + self.save(model=self.model, epoch_idx=i, converge_status=need_stop) + LOGGER.info('sever side model saved') def _decode(self, data): if isinstance(data, transformers.tokenization_utils_base.BatchEncoding): @@ -550,7 +590,7 @@ def _get_train_data_loader(self, train_set): collate_fn = self._get_collate_fn(train_set) if not distributed_util.is_distributed() or distributed_util.get_num_workers() <= 1: - self.data_loader = DataLoader( + data_loader = DataLoader( train_set, batch_size=self.batch_size, pin_memory=self.pin_memory, @@ -564,7 +604,7 @@ def _get_train_data_loader(self, train_set): num_replicas=dist.get_world_size(), rank=dist.get_rank() ) - self.data_loader = DataLoader( + data_loader = DataLoader( train_set, batch_size=self.batch_size, pin_memory=self.pin_memory, @@ -573,6 +613,8 @@ def _get_train_data_loader(self, train_set): sampler=train_sampler ) + return data_loader + def _share_model(self): if distributed_util.is_rank_0(): for p in self.model.parameters(): diff --git a/python/federatedml/nn/homo/trainer/trainer_base.py b/python/federatedml/nn/homo/trainer/trainer_base.py index 3af07dd355..cf47061e21 100644 --- a/python/federatedml/nn/homo/trainer/trainer_base.py +++ b/python/federatedml/nn/homo/trainer/trainer_base.py @@ -9,7 +9,7 @@ from federatedml.util import consts from federatedml.util import LOGGER from federatedml.model_base import serialize_models -from federatedml.nn.backend.utils.common import ML_PATH +from federatedml.nn.backend.utils.common import ML_PATH, LLM_PATH from federatedml.feature.instance import Instance from federatedml.evaluation.evaluation import Evaluation from federatedml.model_base import Metric, MetricMeta @@ -52,6 +52,7 @@ def __init__(self, **kwargs): self._model_checkpoint = None self._exporter = None self._evaluation_summary = {} + self._client_num = None # running status self._set_model_checkpoint_epoch = set() @@ -225,10 +226,17 @@ def _local_save( if hasattr(model, "enable_save_pretrained") and model.enable_save_pretrained: unwrap_model.save_pretrained(save_path) else: - model_state_dict = model.state_dict() + if model is None: + model_state_dict = None + else: + model_state_dict = model.state_dict() + if optimizer is None: + optimizer_state_dict = None + else: + 
optimizer_state_dict = optimizer.state_dict() model_dict = { 'model': model_state_dict, - 'optimizer': optimizer.state_dict(), + 'optimizer': optimizer_state_dict, 'model_define': self.nn_define, 'optimizer_define': self.opt_define, 'loss_define': self.loss_define, @@ -548,22 +556,43 @@ def unwrap_model(model): def get_trainer_class(trainer_module_name: str): - if trainer_module_name.endswith('.py'): trainer_module_name = trainer_module_name.replace('.py', '') - ds_modules = importlib.import_module( - '{}.homo.trainer.{}'.format( - ML_PATH, trainer_module_name)) + + std_fate_trainer_path = '{}.homo.trainer.{}'.format(ML_PATH, trainer_module_name) + + paths_to_check = [std_fate_trainer_path] + errors = [] try: - trainers = [] - for k, v in ds_modules.__dict__.items(): - if isinstance(v, type): - if issubclass(v, TrainerBase) and v is not TrainerBase: - trainers.append(v) - if len(trainers) == 0: - raise ValueError('Did not find any class in {}.py that is the subclass of Trainer class'. - format(trainer_module_name)) - else: - return trainers[-1] # return the last defined trainer - except ValueError as e: - raise e + importlib.import_module(LLM_PATH) + fate_llm_trainer_path = '{}.trainer.{}'.format(LLM_PATH, trainer_module_name) + paths_to_check.append(fate_llm_trainer_path) + except Exception as e: + pass + + trainers = [] + ds_modules = None + + for path in paths_to_check: + try: + ds_modules = importlib.import_module(path) + break + except Exception as e: + errors.append(str(e)) + + if ds_modules is None: + raise ImportError( + 'Could not import from any of the paths: {}, error details {}'.format( + ', '.join(paths_to_check), errors)) + + for k, v in ds_modules.__dict__.items(): + + if isinstance(v, type): + if issubclass(v, TrainerBase) and v is not TrainerBase: + trainers.append(v) + + if len(trainers) == 0: + raise ValueError('Did not find any class in {}.py that is the subclass of Trainer class'. + format(trainer_module_name)) + else: + return trainers[-1] # return the last defined trainer diff --git a/python/federatedml/param/homo_nn_param.py b/python/federatedml/param/homo_nn_param.py index 297ba12c72..64650a5762 100644 --- a/python/federatedml/param/homo_nn_param.py +++ b/python/federatedml/param/homo_nn_param.py @@ -42,7 +42,8 @@ def __init__(self, nn_define: dict = None, loss: dict = None, optimizer: dict = None, - ds_config: dict = None + ds_config: dict = None, + server_init: bool = False ): super(HomoNNParam, self).__init__() @@ -53,6 +54,7 @@ def __init__(self, self.loss = loss self.optimizer = optimizer self.ds_config = ds_config + self.server_init = server_init def check(self):
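
Usage note (illustrative, not part of the patch): the new `server_init` flag added to the pipeline `HomoNN` component and to `HomoNNParam` lets the arbiter build the model, loss and optimizer from the submitted config, which the docstring says is aimed at offsite-tuning. A minimal pipeline-side sketch follows; it assumes the standard `fate_torch_hook` workflow and uses hypothetical model_zoo/dataset names (`my_model`, `MyModel`, `table`) purely for illustration.

```python
# Sketch only: constructor arguments come from the HomoNN signature in this diff;
# the surrounding fate_torch_hook / TrainerParam boilerplate is assumed.
import torch as t
from pipeline import fate_torch_hook
from pipeline.component.homo_nn import HomoNN
from pipeline.component.nn import TrainerParam, DatasetParam

fate_torch_hook(t)  # patch torch so t.nn / t.optim objects become FATE-serializable

model = t.nn.Sequential(
    t.nn.CustModel(module_name='my_model', class_name='MyModel')  # hypothetical model_zoo entry
)

homo_nn_0 = HomoNN(
    name='homo_nn_0',
    model=model,
    loss=t.nn.CrossEntropyLoss(),
    optimizer=t.optim.Adam(lr=0.001),
    trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=5, batch_size=256),
    dataset=DatasetParam(dataset_name='table'),   # hypothetical dataset name
    torch_seed=100,
    server_init=True,  # new flag: arbiter also builds model/loss/optimizer from this config
)
```

With `server_init=True`, `HomoNNServer.fit` goes through the shared `_init.init()` path instead of only instantiating the trainer class, so the server side ends up with the same model/optimizer/loss objects as the clients.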
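
The refactor of `FedAVGTrainer` also extracts the federation steps into `_client_sends_data` / `_server_aggregates_data` and adds empty `on_loop_begin_*` / `on_loop_end_*` hooks, which makes the trainer easier to subclass. A minimal subclass sketch, assuming nothing beyond the hook signatures shown in the diff (the hooks are currently invoked with no arguments, so `**kwargs` is enough):

```python
from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
from federatedml.util import LOGGER


class MyTunedTrainer(FedAVGTrainer):
    """Hypothetical trainer that plugs custom logic into the new loop hooks."""

    def on_loop_begin_client(self, **kwargs):
        # e.g. freeze part of the model before federated training starts (placeholder)
        LOGGER.info('client loop starting, model set: {}'.format(self.model is not None))

    def on_loop_end_client(self, **kwargs):
        # e.g. run a final local step after the last aggregation round (placeholder)
        LOGGER.info('client loop finished')

    def on_loop_begin_server(self, **kwargs):
        LOGGER.info('server loop starting')

    def on_loop_end_server(self, **kwargs):
        LOGGER.info('server loop finished')
```

Because `get_trainer_class` now also searches the fate_llm trainer package (the exact package path depends on the `LLM_PATH` constant in `federatedml.nn.backend.utils.common`), the same `TrainerParam(trainer_name=...)` string can reference either a federatedml trainer module or one shipped with fate_llm.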
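
Finally, `SecureAggregatorClient`/`SecureAggregatorServer` and the communicator classes beneath them now accept explicit `server` and `clients` role tuples instead of hard-coding arbiter/guest/host. A sketch with the default roles spelled out, using only the constructor signatures shown above; both sides must run inside an active federated job and share the same suffix and role tuples to pair up.

```python
from federatedml.framework.homo.aggregator.secure_aggregator import (
    SecureAggregatorClient, SecureAggregatorServer)
from federatedml.util import consts

# client side (guest/host): weighted-mean aggregation with secure random padding
client_agg = SecureAggregatorClient(
    secure_aggregate=True,
    aggregate_type='weighted_mean',
    aggregate_weight=1.0,
    communicate_match_suffix='my_suffix',
    server=(consts.ARBITER,),
    clients=(consts.GUEST, consts.HOST),
)

# server side (arbiter): same suffix and roles as the clients
server_agg = SecureAggregatorServer(
    secure_aggregate=True,
    communicate_match_suffix='my_suffix',
    server=(consts.ARBITER,),
    clients=(consts.GUEST, consts.HOST),
)
```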