Skip to content

Commit

Permalink
Merge pull request #10 from FederatedAI/feature-1.0-add-feature-data-…
Browse files Browse the repository at this point in the history
…params

add feature data params
  • Loading branch information
dylan-fan authored Oct 28, 2019
2 parents 2eec324 + 90c3110 commit a610dc7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ public class Dict {
public static final String HIT_CACHE = "hitCache";

public static final String REQUEST_SEQNO="REQUEST_SEQNO";

public static final String GUEST_APP_ID = "guestAppId";
public static final String HOST_APP_ID = "hostAppId";
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.webank.ai.fate.serving.adapter.dataaccess;

import com.webank.ai.fate.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.bean.Context;
import com.webank.ai.fate.serving.utils.HttpClientPool;
import com.webank.ai.fate.core.utils.ObjectTransform;
import com.webank.ai.fate.serving.core.constant.InferenceRetCode;
Expand All @@ -33,7 +34,7 @@ public class DTest implements FeatureData {
private static final Logger LOGGER = LogManager.getLogger();

@Override
public ReturnResult getData(Map<String, Object> featureIds) {
public ReturnResult getData(Map<String, Object> featureIds, Context context) {
ReturnResult returnResult = new ReturnResult();
Map<String, Object> requestData = new HashMap<>();
requestData.putAll(featureIds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
package com.webank.ai.fate.serving.adapter.dataaccess;

import com.webank.ai.fate.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.bean.Context;

import java.util.Map;

public interface FeatureData {
ReturnResult getData(Map<String, Object> featureIds);
ReturnResult getData(Map<String, Object> featureIds, Context context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.webank.ai.fate.serving.adapter.dataaccess;

import com.webank.ai.fate.core.bean.ReturnResult;
import com.webank.ai.fate.serving.core.bean.Context;
import com.webank.ai.fate.serving.core.constant.InferenceRetCode;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
Expand All @@ -32,7 +33,7 @@ public class TestFile implements FeatureData {
private static final Logger LOGGER = LogManager.getLogger();

@Override
public ReturnResult getData(Map<String, Object> featureIds) {
public ReturnResult getData(Map<String, Object> featureIds, Context context) {
ReturnResult returnResult = new ReturnResult();
Map<String, Object> data = new HashMap<>();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ public static ReturnResult federatedInference(Context context,Map<String,Object
predictParams.put("federatedParams", federatedParams);

try {
ReturnResult getFeatureDataResult = getFeatureData(featureIds);
ReturnResult getFeatureDataResult = getFeatureData(featureIds, context);
if (getFeatureDataResult.getRetcode() == InferenceRetCode.OK) {
if (getFeatureDataResult.getData() == null || getFeatureDataResult.getData().size() < 1) {
returnResult.setRetcode(InferenceRetCode.GET_FEATURE_FAILED);
Expand Down Expand Up @@ -321,7 +321,7 @@ private static PostProcessingResult getPostProcessedResult(Context context ,Map
LOGGER.info("postprocess caseid {} cost time {}",context.getCaseId(),endTime-beginTime);
}
}
private static ReturnResult getFeatureData(Map<String, Object> featureIds) {
private static ReturnResult getFeatureData(Map<String, Object> featureIds, Context context) {
ReturnResult defaultReturnResult = new ReturnResult();
String classPath = FeatureData.class.getPackage().getName() + "." + Configuration.getProperty("OnlineDataAccessAdapter");
FeatureData featureData = (FeatureData) InferenceUtils.getClassByName(classPath);
Expand All @@ -330,7 +330,7 @@ private static ReturnResult getFeatureData(Map<String, Object> featureIds) {
return defaultReturnResult;
}
try {
return featureData.getData(featureIds);
return featureData.getData(featureIds, context);
} catch (Exception ex) {
defaultReturnResult.setRetcode(InferenceRetCode.GET_FEATURE_FAILED);
LOGGER.error("get feature data error:", ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public void unaryCall(Proxy.Packet req, StreamObserver<Proxy.Packet> responseObs
requestData = (Map<String, Object>) ObjectTransform.json2Bean(req.getBody().getValue().toStringUtf8(), HashMap.class);
context.setCaseId(requestData.get(Dict.CASEID)!=null?requestData.get(Dict.CASEID).toString():Dict.NONE);

FederatedParty party = (FederatedParty) ObjectTransform.json2Bean(requestData.get("local").toString(), FederatedParty.class);
FederatedParty partnerParty = (FederatedParty) ObjectTransform.json2Bean(requestData.get("partner_local").toString(), FederatedParty.class);

context.putData(Dict.GUEST_APP_ID, partnerParty.getPartyId());
context.putData(Dict.HOST_APP_ID, party.getPartyId());

switch (req.getHeader().getCommand().getName()) {
case "federatedInference":
responseResult = InferenceManager.federatedInference(context, requestData);
Expand All @@ -67,8 +73,6 @@ public void unaryCall(Proxy.Packet req, StreamObserver<Proxy.Packet> responseObs

Proxy.Metadata.Builder metaDataBuilder = Proxy.Metadata.newBuilder();
Proxy.Topic.Builder topicBuilder = Proxy.Topic.newBuilder();
FederatedParty partnerParty = (FederatedParty) ObjectTransform.json2Bean(requestData.get("partner_local").toString(), FederatedParty.class);
FederatedParty party = (FederatedParty) ObjectTransform.json2Bean(requestData.get("local").toString(), FederatedParty.class);

metaDataBuilder.setSrc(
topicBuilder.setPartyId(String.valueOf(party.getPartyId()))
Expand Down

0 comments on commit a610dc7

Please sign in to comment.