Skip to content

Commit

Permalink
Support citrix cloud gov environment
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuolun-citrix committed Jan 12, 2024
1 parent b868036 commit bf12d42
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 25 deletions.
91 changes: 75 additions & 16 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package citrixclient

import (
"bytes"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -30,6 +31,15 @@ type TrustAuthResponse struct {
ExpiresAt string `json:"ExpiresAt"`
}

// AuthResponse -
type CCTrustAuthResponse struct {
Token string `json:"Token"`
Principal string `json:"Principal"`
UserId string `json:"UserId"`
CustomerId string `json:"CustomerId"`
ExpiresAt int `json:"ExpiresAt"`
}

// SignIn - Get a new token for user
func (c *CitrixDaasClient) SignIn() (string, *http.Response, error) {
if c.AuthConfig.ClientId == "" || c.AuthConfig.ClientSecret == "" {
Expand All @@ -48,12 +58,11 @@ func (c *CitrixDaasClient) SignIn() (string, *http.Response, error) {
var expiresAt string

if c.AuthConfig.OnPremises {
// tr := &http.Transport{
// TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// }
// client := &http.Client{Transport: tr}
client := c.ApiClient.GetConfig().HTTPClient
req, _ := http.NewRequest(http.MethodPost, c.AuthConfig.AuthUrl, nil)
req, err := http.NewRequest(http.MethodPost, c.AuthConfig.AuthUrl, nil)
if err != nil {
return "", nil, err
}
req.SetBasicAuth(c.AuthConfig.ClientId, c.AuthConfig.ClientSecret)
resp, err := client.Do(req)
if err != nil {
Expand All @@ -78,8 +87,8 @@ func (c *CitrixDaasClient) SignIn() (string, *http.Response, error) {

token = fmt.Sprintf("CWSAuth bearer=%s", ar.Token)
expiresAt = ar.ExpiresAt

} else {
} else if !c.AuthConfig.IsGov {
// If Auth is not for Citrix Cloud Gov, use oauth2
grant_type := "client_credentials"
operation := func() (CCAuthResponse, *http.Response, error) {
return performCCAuth(c.AuthConfig.AuthUrl, url.Values{"grant_type": {grant_type}, "client_id": {c.AuthConfig.ClientId}, "client_secret": {c.AuthConfig.ClientSecret}})
Expand All @@ -95,6 +104,23 @@ func (c *CitrixDaasClient) SignIn() (string, *http.Response, error) {
expiryInSeconds, _ := strconv.Atoi(ccAuthResponse.Expiration)
expiresAtTime := time.Now().UTC().Add(time.Second * time.Duration(expiryInSeconds))
expiresAt = expiresAtTime.Format(time.RFC3339)
} else {
// If Auth is for Citrix Cloud Gov, use trust service
client := c.ApiClient.GetConfig().HTTPClient

operation := func() (CCTrustAuthResponse, *http.Response, error) {
return performCCTrustAuth(client, c.AuthConfig.AuthUrl, []byte(fmt.Sprintf(`{"ClientId":"%s", "ClientSecret":"%s"}`, c.AuthConfig.ClientId, c.AuthConfig.ClientSecret)))
}

ccTrustAuthResponse, resp, err := RetryOperationWithExponentialBackOff(operation, 10, 3)

if err != nil {
return "", resp, err
}

token = fmt.Sprintf("CWSAuth bearer=%s", ccTrustAuthResponse.Token)
expiresAtTime := time.Now().UTC().Add(time.Second * time.Duration(ccTrustAuthResponse.ExpiresAt))
expiresAt = expiresAtTime.Format(time.RFC3339)
}

authTokenModel := AuthTokenModel{}
Expand All @@ -106,29 +132,62 @@ func (c *CitrixDaasClient) SignIn() (string, *http.Response, error) {
}

func performCCAuth(authUrl string, urlValues url.Values) (CCAuthResponse, *http.Response, error) {
ccAuthResonse := CCAuthResponse{}
ccAuthResponse := CCAuthResponse{}
httpResp, err := http.PostForm(authUrl, urlValues)
if err != nil {
return ccAuthResonse, httpResp, err
return ccAuthResponse, httpResp, err
}
defer httpResp.Body.Close()
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return ccAuthResonse, httpResp, err
return ccAuthResponse, httpResp, err
}

if httpResp.StatusCode > 200 {
return ccAuthResonse, httpResp, fmt.Errorf("could not sign into Citrix Cloud, %s", string(body))
return ccAuthResponse, httpResp, fmt.Errorf("could not sign into Citrix Cloud, %s", string(body))
}

err = json.Unmarshal(body, &ccAuthResponse)
if err != nil {
return ccAuthResponse, httpResp, err
}

if ccAuthResponse.Error != "" {
return ccAuthResponse, httpResp, fmt.Errorf("could not sign into Citrix Cloud, %s: %s", ccAuthResponse.Error, ccAuthResponse.ErrorDescription)
}

return ccAuthResponse, httpResp, err
}

func performCCTrustAuth(client *http.Client, authUrl string, jsonStr []byte) (CCTrustAuthResponse, *http.Response, error) {
ccTrustResponse := CCTrustAuthResponse{}

req, err := http.NewRequest(http.MethodPost, authUrl, bytes.NewBuffer(jsonStr))
if err != nil {
return ccTrustResponse, nil, err
}
req.Header.Set("X-Custom-Header", "myvalue")
req.Header.Set("Content-Type", "application/json")

resp, err := client.Do(req)
if err != nil {
return ccTrustResponse, nil, err
}

err = json.Unmarshal(body, &ccAuthResonse)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return ccAuthResonse, httpResp, err
return ccTrustResponse, nil, err
}

if ccAuthResonse.Error != "" {
return ccAuthResonse, httpResp, fmt.Errorf("could not sign into Citrix Cloud, %s: %s", ccAuthResonse.Error, ccAuthResonse.ErrorDescription)
if resp.StatusCode > 200 {
return ccTrustResponse, resp, fmt.Errorf("could not sign into DDC, %s", string(body))
}

err = json.Unmarshal(body, &ccTrustResponse)
if err != nil {
return ccTrustResponse, nil, err
}

return ccAuthResonse, httpResp, err
return ccTrustResponse, resp, err
}
37 changes: 29 additions & 8 deletions client/citrixdaas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"os"
"reflect"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -39,7 +40,7 @@ func getMiddlewareWithClient(authClient *CitrixDaasClient, middlewareAuthFunc Mi
}
}

func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, clientId, clientSecret string, onPremises bool, disableSslVerification bool, userAgent *string, middlewareFunc MiddlewareAuthFunction) (*CitrixDaasClient, *http.Response, error) {
func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, clientId, clientSecret string, onPremises bool, apiGateway bool, isGov bool, disableSslVerification bool, userAgent *string, middlewareFunc MiddlewareAuthFunction) (*CitrixDaasClient, *http.Response, error) {
daasClient := &CitrixDaasClient{}

/* ------ Setup API Client ------ */
Expand All @@ -54,6 +55,12 @@ func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, cli
client := &http.Client{Transport: tr}
localCfg.HTTPClient = client

localCfg.Servers = citrixorchestration.ServerConfigurations{
{
URL: localCfg.Scheme + "://" + hostname + "/citrix/orchestration/api/techpreview",
},
}
} else if !apiGateway {
localCfg.Servers = citrixorchestration.ServerConfigurations{
{
URL: localCfg.Scheme + "://" + hostname + "/citrix/orchestration/api/techpreview",
Expand All @@ -70,6 +77,8 @@ func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, cli
localAuthCfg.ClientId = clientId
localAuthCfg.ClientSecret = clientSecret
localAuthCfg.OnPremises = onPremises
localAuthCfg.ApiGateway = apiGateway
localAuthCfg.IsGov = isGov

daasClient.AuthConfig = localAuthCfg

Expand All @@ -88,8 +97,8 @@ func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, cli

localClientCfg := &ClientConfiguration{}
localClientCfg.CustomerId = customerId
if len(resp.Customers) == 0 || len(resp.Customers[0].Sites) == 0 {
return nil, httpResp, fmt.Errorf("Customer does not exist or does not have a valid site.")
if resp == nil || len(resp.Customers) == 0 || len(resp.Customers[0].Sites) == 0 {
return nil, httpResp, fmt.Errorf("customer does not exist or does not have a valid site")
}
localClientCfg.SiteId = resp.Customers[0].Sites[0].Id

Expand All @@ -101,7 +110,7 @@ func NewCitrixDaasClient(ctx context.Context, authUrl, hostname, customerId, cli

daasClient.ClientConfig = localClientCfg

if onPremises {
if onPremises || !apiGateway {
// add CustomerId and SiteId to base path for on-prem. The Me Api will no longer work.
localCfg.Servers[0].URL += "/" + localClientCfg.CustomerId + "/" + localClientCfg.SiteId
}
Expand Down Expand Up @@ -240,6 +249,17 @@ func GetTransactionIdFromHttpResponse(httpResponse *http.Response) string {
return httpResponse.Header.Get("Citrix-TransactionId")
}

func (c *CitrixDaasClient) GetBatchRequestItemRelativeUrl(relativeUrl string) string {
if !c.AuthConfig.OnPremises && c.AuthConfig.ApiGateway {
// In case of Cloud customer going through API Gateway, use path as is
return relativeUrl
}

// In case of on-premises customer or Cloud customer bypassing API Gateway, append customerId and siteId to the path
path := fmt.Sprintf("%s/%s/%s", c.ClientConfig.CustomerId, c.ClientConfig.SiteId, relativeUrl)
return path
}

func PerformBatchOperation(ctx context.Context, client *CitrixDaasClient, batchRequestModel citrixorchestration.BatchRequestModel) (int, string, error) {

successCount := 0
Expand All @@ -263,13 +283,13 @@ func PerformBatchOperation(ctx context.Context, client *CitrixDaasClient, batchR
return successCount, txnId, fmt.Errorf(jobResponseModel.GetErrorString())
}

return successCount, txnId, fmt.Errorf("An unexpected error occurred.")
return successCount, txnId, fmt.Errorf("an unexpected error occurred")
}

getJobResultsRequest := client.ApiClient.JobsAPIsDAAS.JobsGetJobResults(ctx, jobId)
jobResultsResponseContents, _, _ := ExecuteWithRetry[*os.File](getJobResultsRequest, client)
if jobResultsResponseContents == nil {
return successCount, txnId, fmt.Errorf("An unexpected error occurred.")
return successCount, txnId, fmt.Errorf("an unexpected error occurred")
}

data, err := io.ReadAll(jobResultsResponseContents)
Expand All @@ -284,15 +304,16 @@ func PerformBatchOperation(ctx context.Context, client *CitrixDaasClient, batchR
subJob := subJobs[i]
subJobCode := subJob.GetCode()
if subJobCode == 202 {
locationValues := strings.Split(subJob.Headers[0].GetValue(), "/")
locationHeader := slices.IndexFunc(subJob.Headers, func(pair citrixorchestration.NameValueStringPairModel) bool { return pair.GetName() == "Location" })
locationValues := strings.Split(subJob.Headers[locationHeader].GetValue(), "/")
jobResponseModel, err := client.WaitForJob(ctx, locationValues[len(locationValues)-1], 60)

if err != nil || jobResponseModel == nil || jobResponseModel.GetStatus() != citrixorchestration.JOBSTATUS_COMPLETE {
continue
}

successCount++
} else if subJobCode == 204 {
} else if subJobCode == 204 || subJobCode == 200 {
successCount++
}
}
Expand Down
4 changes: 3 additions & 1 deletion client/citrixdaas_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

package citrixclient

// AuthenticationConfiguration provides authentication settings for CC Athena / on-prem trust service
// AuthenticationConfiguration provides authentication settings for CC Athena / on-premises trust service
type AuthenticationConfiguration struct {
AuthUrl string `json:"auth_url"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
OnPremises bool `json:"on_premises"`
ApiGateway bool `json:"api_gateway"`
IsGov bool `json:"is_gov"`
}

// ClientConfiguration provides Citrix DaaS customer context
Expand Down

0 comments on commit bf12d42

Please sign in to comment.