Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature 1.13 llm update #5063

Merged
merged 9 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions python/fate_client/pipeline/component/homo_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
'loss': None,
'optimizer': None,
'nn_define': None,
'ds_config': None
'ds_config': None,
'server_init': False
}
except Exception as e:
print(e)
Expand All @@ -65,7 +66,10 @@ class HomoNN(FateComponent):
torch_seed, global random seed
loss, loss function from fate_torch
optimizer, optimizer from fate_torch
ds_config, config for deepspeed
model, a fate torch sequential defining the model structure
server_init, whether to initialize the model, loss and optimizer on server, if configs are provided, they will be used. In
current version this option is specially designed for offsite-tuning
"""

@extract_explicit_parameter
Expand All @@ -82,7 +86,9 @@ def __init__(self,
loss=None,
optimizer: OptimizerType = None,
ds_config: dict = None,
model: Sequential = None, **kwargs):
model: Sequential = None,
server_init: bool = False,
**kwargs):

explicit_parameters = copy.deepcopy(DEFAULT_PARAM_DICT)
if 'name' not in kwargs["explict_parameters"]:
Expand All @@ -94,8 +100,15 @@ def __init__(self,
self.input = Input(self.name, data_type="multi")
self.output = Output(self.name, data_type='single')
self._module_name = "HomoNN"
self._updated = {'trainer': False, 'dataset': False,
'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False}
self._updated = {
'trainer': False,
'dataset': False,
'torch_seed': False,
'loss': False,
'optimizer': False,
'model': False,
'ds_config': False,
'server_init': False}
self._set_param(kwargs["explict_parameters"])
self._check_parameters()

Expand Down
35 changes: 25 additions & 10 deletions python/fate_client/pipeline/component/nn/backend/torch/cust.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from pipeline.component.nn.backend.torch.base import FateTorchLayer, FateTorchLoss
import difflib

ML_PATH = 'federatedml.nn'
LLM_PATH = "fate_llm"

MODEL_PATH = None
LOSS_PATH = None
LLM_MODEL_PATH = '{}.model_zoo'.format(LLM_PATH)
MODEL_PATH = '{}.model_zoo'.format(ML_PATH)
LOSS_PATH = '{}.loss'.format(ML_PATH)


def str_simi(str_a, str_b):
Expand Down Expand Up @@ -45,9 +48,14 @@ class CustModel(FateTorchLayer, nn.Module):

def __init__(self, module_name, class_name, **kwargs):
super(CustModel, self).__init__()
assert isinstance(module_name, str), 'name must be a str, specify the module in the model_zoo'
assert isinstance(class_name, str), 'class name must be a str, specify the class in the module'
self.param_dict = {'module_name': module_name, 'class_name': class_name, 'param': kwargs}
assert isinstance(
module_name, str), 'name must be a str, specify the module in the model_zoo'
assert isinstance(
class_name, str), 'class name must be a str, specify the class in the module'
self.param_dict = {
'module_name': module_name,
'class_name': class_name,
'param': kwargs}
self._model = None

def init_model(self):
Expand All @@ -62,11 +70,18 @@ def forward(self, x):
def get_pytorch_model(self, module_path=None):

if module_path is None:
return get_class(
self.param_dict['module_name'],
self.param_dict['class_name'],
self.param_dict['param'],
MODEL_PATH)
try:
return get_class(
self.param_dict['module_name'],
self.param_dict['class_name'],
self.param_dict['param'],
MODEL_PATH)
except BaseException:
return get_class(
self.param_dict['module_name'],
self.param_dict['class_name'],
self.param_dict['param'],
LLM_MODEL_PATH)
else:
return get_class(
self.param_dict['module_name'],
Expand Down
12 changes: 8 additions & 4 deletions python/federatedml/framework/homo/aggregator/aggregator_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from federatedml.framework.homo.blocks import ServerCommunicator, ClientCommunicator
from federatedml.util import consts


class AutoSuffix(object):
Expand All @@ -19,7 +20,10 @@ def __call__(self):

class AggregatorBaseClient(object):

def __init__(self, communicate_match_suffix: str = None):
def __init__(
self, communicate_match_suffix: str = None, server=(
consts.ARBITER,), clients=(
consts.GUEST, consts.HOST)):
"""Base class of client aggregator

Parameters
Expand All @@ -28,7 +32,7 @@ def __init__(self, communicate_match_suffix: str = None):
To make sure that client and server can communicate correctly,
the server-side and client-side aggregators need to have the same suffix
"""
self.communicator = ClientCommunicator(prefix=communicate_match_suffix)
self.communicator = ClientCommunicator(prefix=communicate_match_suffix, server=server, clients=clients)
self.suffix = {}

def _get_suffix(self, var_name, user_suffix=tuple()):
Expand All @@ -52,7 +56,7 @@ def get(self, suffix):

class AggregatorBaseServer(object):

def __init__(self, communicate_match_suffix=None):
def __init__(self, communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)):
"""Base class of server aggregator

Parameters
Expand All @@ -61,7 +65,7 @@ def __init__(self, communicate_match_suffix=None):
To make sure that client and server can communicate correctly,
the server-side and client-side aggregators need to have the same suffix
"""
self.communicator = ServerCommunicator(prefix=communicate_match_suffix)
self.communicator = ServerCommunicator(prefix=communicate_match_suffix, server=server, clients=clients)
self.suffix = {}

def _get_suffix(self, var_name, user_suffix=tuple()):
Expand Down
24 changes: 18 additions & 6 deletions python/federatedml/framework/homo/aggregator/secure_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
class SecureAggregatorClient(AggregatorBaseClient):

def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggregate_weight=1.0,
communicate_match_suffix=None):
communicate_match_suffix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)):

super(SecureAggregatorClient, self).__init__(
communicate_match_suffix=communicate_match_suffix)
communicate_match_suffix=communicate_match_suffix, clients=clients, server=server)
self.secure_aggregate = secure_aggregate
self.suffix = {
"local_loss": AutoSuffix("local_loss"),
Expand All @@ -31,7 +31,10 @@ def __init__(self, secure_aggregate=True, aggregate_type='weighted_mean', aggreg
# init secure aggregate random padding:
if self.secure_aggregate:
self._random_padding_cipher: PadsCipher = RandomPaddingCipherClient(
trans_var=RandomPaddingCipherTransVar(prefix=communicate_match_suffix)).create_cipher()
trans_var=RandomPaddingCipherTransVar(
prefix=communicate_match_suffix,
clients=clients,
server=server)).create_cipher()
LOGGER.info('initialize secure aggregator done')

# compute weight
Expand Down Expand Up @@ -186,9 +189,18 @@ def loss_aggregation(self, loss, suffix=tuple()):

class SecureAggregatorServer(AggregatorBaseServer):

def __init__(self, secure_aggregate=True, communicate_match_suffix=None):
def __init__(
self,
secure_aggregate=True,
communicate_match_suffix=None,
server=(
consts.ARBITER,
),
clients=(
consts.GUEST,
consts.HOST)):
super(SecureAggregatorServer, self).__init__(
communicate_match_suffix=communicate_match_suffix)
communicate_match_suffix=communicate_match_suffix, clients=clients, server=server)
self.suffix = {
"local_loss": AutoSuffix("local_loss"),
"agg_loss": AutoSuffix("agg_loss"),
Expand All @@ -199,7 +211,7 @@ def __init__(self, secure_aggregate=True, communicate_match_suffix=None):
self.secure_aggregate = secure_aggregate
if self.secure_aggregate:
RandomPaddingCipherServer(trans_var=RandomPaddingCipherTransVar(
prefix=communicate_match_suffix)).exchange_secret_keys()
prefix=communicate_match_suffix, clients=clients, server=server)).exchange_secret_keys()
LOGGER.info('initialize secure aggregator done')

agg_weights = self.collect(suffix=('agg_weight', ))
Expand Down
8 changes: 4 additions & 4 deletions python/federatedml/framework/homo/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)

class ServerCommunicator(object):

def __init__(self, prefix=None):
self.trans_var = CommunicatorTransVar(prefix=prefix)
def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)):
self.trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients)
self._client_parties = self.trans_var.client_parties

def get_parties(self, party_idx):
Expand All @@ -85,8 +85,8 @@ def broadcast_obj(self, obj, suffix=tuple(), party_idx=-1):

class ClientCommunicator(object):

def __init__(self, prefix=None):
trans_var = CommunicatorTransVar(prefix=prefix)
def __init__(self, prefix=None, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST)):
trans_var = CommunicatorTransVar(prefix=prefix, server=server, clients=clients)
self.trans_var = trans_var
self._server_parties = trans_var.server_parties

Expand Down
55 changes: 0 additions & 55 deletions python/federatedml/nn/backend/torch/cust_model.py

This file was deleted.

Loading
Loading