Skip to content

Commit

Permalink
Merge pull request #89 from FederatedAI/develop-2.0.2
Browse files Browse the repository at this point in the history
Compatible with the old version
  • Loading branch information
forgivedengkai authored Oct 26, 2020
2 parents 63775ef + 249744a commit 497331c
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ public Map<String, Object> singleLocalPredict(Context context, Map<String, Objec
if (logger.isDebugEnabled()) {
logger.debug("component {} is Returnable return data {}", component, result);
}
if (StringUtils.isBlank(context.getVersion()) || Long.parseLong(context.getVersion()) < 200) {
if (StringUtils.isBlank(context.getVersion()) || Double.parseDouble(context.getVersion()) < 200) {
result.putAll(componentResult);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ public void unaryCall(Proxy.Packet req, StreamObserver<Proxy.Packet> responseObs

public InboundPackage<Proxy.Packet> buildInboundPackage(Context context, Proxy.Packet req) {
context.setCaseId(Long.toString(System.currentTimeMillis()));
context.setVersion(req.getAuth().getVersion());
if (StringUtils.isEmpty(context.getVersion())) {
context.setVersion(Dict.DEFAULT_VERSION);
if (StringUtils.isNotBlank(req.getHeader().getOperator())) {
context.setVersion(req.getHeader().getOperator());
}
context.setGuestAppId(req.getHeader().getSrc().getPartyId());
context.setHostAppid(req.getHeader().getDst().getPartyId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,15 @@ private String getEnvironment(Context context, InboundPackage inboundPackage) {
// guest, proxy -> serving
return (String) inboundPackage.getHead().get(Dict.SERVICE_ID);
}
// default unaryCall
if (GrpcType.INTRA_GRPC == context.getGrpcType()) {
// guest, serving -> proxy
// return Dict.ONLINE_ENVIRONMENT;
return null;
} else {

if (Dict.UNARYCALL.equals(context.getServiceName()) && context.getGrpcType() == GrpcType.INTER_GRPC) {
// host, proxy -> serving
Proxy.Packet sourcePacket = (Proxy.Packet) inboundPackage.getBody();
if (MetaInfo.PROPERTY_COORDINATOR.equals(sourcePacket.getHeader().getDst().getPartyId())) {
// host, proxy -> serving
return FederatedModelUtils.getModelRouteKey(sourcePacket);
} else {
// exchange, proxy -> proxy
// return Dict.ONLINE_ENVIRONMENT;
return null;
}
return FederatedModelUtils.getModelRouteKey(context, sourcePacket);
}

// default unaryCall proxy -> proxy
return null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,48 @@
package com.webank.ai.fate.serving.proxy.utils;

import com.webank.ai.fate.api.networking.proxy.Proxy;
import com.webank.ai.fate.serving.core.bean.Context;
import com.webank.ai.fate.serving.core.bean.EncryptMethod;
import com.webank.ai.fate.serving.core.utils.EncryptUtils;
import com.webank.ai.fate.serving.core.utils.JsonUtil;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Map;

public class FederatedModelUtils {

private static final Logger logger = LoggerFactory.getLogger(FederatedModelUtils.class);

private static final String MODEL_KEY_SEPARATOR = "&";

public static String genModelKey(String tableName, String namespace) {
return StringUtils.join(Arrays.asList(tableName, namespace), MODEL_KEY_SEPARATOR);
}

public static String getModelRouteKey(Proxy.Packet packet) {
String data = packet.getBody().getValue().toStringUtf8();
Proxy.Model model = packet.getHeader().getTask().getModel();
String key = genModelKey(model.getTableName(), model.getNamespace());
String md5Key = EncryptUtils.encrypt(key, EncryptMethod.MD5);
return md5Key;
public static String getModelRouteKey(Context context, Proxy.Packet packet) {
String namespace;
String tableName;
if (StringUtils.isBlank(context.getVersion()) || Double.parseDouble(context.getVersion()) < 200) {
// version 1.x
String data = packet.getBody().getValue().toStringUtf8();
Map hostFederatedParams = JsonUtil.json2Object(data, Map.class);
Map partnerModelInfo = (Map) hostFederatedParams.get("partnerModelInfo");
namespace = partnerModelInfo.get("namespace").toString();
tableName = partnerModelInfo.get("name").toString();
} else {
// version 2.0.0+
Proxy.Model model = packet.getHeader().getTask().getModel();
namespace = model.getNamespace();
tableName = model.getTableName();
}

String key = genModelKey(tableName, namespace);
logger.info("get model route key by version: {} namespace: {} tablename: {}, key: {}", context.getVersion(), namespace, tableName, key);

return EncryptUtils.encrypt(key, EncryptMethod.MD5);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void unaryCall(Proxy.Packet req, StreamObserver<Proxy.Packet> responseObs
String tableName = req.getHeader().getTask().getModel().getTableName();
context.setActionType(actionType);
context.setVersion(req.getHeader().getOperator());
if (StringUtils.isBlank(context.getVersion()) || Long.parseLong(context.getVersion()) < 200) {
if (StringUtils.isBlank(context.getVersion()) || Double.parseDouble(context.getVersion()) < 200) {
// 1.x
Map hostFederatedParams = JsonUtil.json2Object(req.getBody().getValue().toStringUtf8(), Map.class);
Map partnerModelInfo = (Map) hostFederatedParams.get("partnerModelInfo");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void doPreProcess(Context context, InboundPackage inboundPackage, Outboun
String tableName = servingServerContext.getModelTableName();
String nameSpace = servingServerContext.getModelNamesapce();
Model model;
if (StringUtils.isBlank(context.getVersion()) || Long.parseLong(context.getVersion()) < 200) {
if (StringUtils.isBlank(context.getVersion()) || Double.parseDouble(context.getVersion()) < 200) {
model = modelManager.getPartnerModel(tableName, nameSpace);
} else {
model = modelManager.getModelByTableNameAndNamespace(tableName, nameSpace);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void doPreProcess(Context context, InboundPackage inboundPackage, Outboun
inboundPackage.setBody(params);
} else {
InferenceRequest inferenceRequest = JsonUtil.json2Object(reqBody, InferenceRequest.class);
if (StringUtils.isBlank(context.getVersion()) || Long.parseLong(context.getVersion()) < 200) {
if (StringUtils.isBlank(context.getVersion()) || Double.parseDouble(context.getVersion()) < 200) {
Map hostParams = JsonUtil.json2Object(reqBody, Map.class);
Preconditions.checkArgument(hostParams != null, "parse inference params error");
Preconditions.checkArgument(hostParams.get("featureIdMap") != null, "parse inference params featureIdMap error");
Expand Down

0 comments on commit 497331c

Please sign in to comment.