Skip to content

Commit

Permalink
Merge branch 'main' into refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xjd committed Feb 1, 2024
2 parents ab97587 + 88e41ff commit 25f9e40
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 53 deletions.
2 changes: 1 addition & 1 deletion common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ var StreamRequestOutTime = os.Getenv("STREAM_REQUEST_OUT_TIME")

var DebugEnabled = os.Getenv("DEBUG") == "true"

var Version = "v2.1.0" // this hard coding will be replaced automatically when building, no need to manually change
var Version = "v2.1.1" // this hard coding will be replaced automatically when building, no need to manually change

const (
RequestIdKey = "X-Request-Id"
Expand Down
62 changes: 15 additions & 47 deletions controller/chat.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package controller

import (
"context"
"coze-discord-proxy/common"
"coze-discord-proxy/discord"
"coze-discord-proxy/model"
Expand Down Expand Up @@ -31,25 +30,17 @@ func Chat(c *gin.Context) {
var chatModel model.ChatReq
err := json.NewDecoder(c.Request.Body).Decode(&chatModel)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, gin.H{
"message": "无效的参数",
"success": false,
})
return
}

if runeCount := len([]rune(chatModel.Content)); runeCount > 2000 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("prompt最大为2000字符 [%v]", runeCount),
})
return
}

sendChannelId, calledCozeBotId, err := getSendChannelIdAndCozeBotId(c, false, chatModel)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "配置异常",
Expand All @@ -60,12 +51,11 @@ func Chat(c *gin.Context) {
return
}

sentMsg, err := discord.SendMessage(sendChannelId, calledCozeBotId, chatModel.Content)
sentMsg, err := discord.SendMessage(c, sendChannelId, calledCozeBotId, chatModel.Content)
if err != nil {
common.LogError(context.Background(), err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "discord发送消息异常",
"message": err.Error(),
})
return
}
Expand All @@ -80,7 +70,7 @@ func Chat(c *gin.Context) {

timer, err := setTimerWithHeader(c, chatModel.Stream, common.RequestOutTimeDuration)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "超时时间设置异常",
Expand Down Expand Up @@ -148,7 +138,7 @@ func ChatForOpenAI(c *gin.Context) {
var request model.OpenAIChatCompletionRequest
err := json.NewDecoder(c.Request.Body).Decode(&request)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "无效的参数",
Expand All @@ -167,15 +157,7 @@ func ChatForOpenAI(c *gin.Context) {
if message.Role == "user" {
switch contentObj := message.Content.(type) {
case string:
if runeCount := len([]rune(contentObj)); runeCount > 2000 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("prompt最大为2000字符 [%v]", runeCount),
})
return
}
content = contentObj

case []interface{}:
content, err = buildOpenAIGPT4VForImageContent(contentObj)
if err != nil {
Expand All @@ -202,7 +184,7 @@ func ChatForOpenAI(c *gin.Context) {

sendChannelId, calledCozeBotId, err := getSendChannelIdAndCozeBotId(c, true, request)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "配置异常",
Expand All @@ -213,12 +195,11 @@ func ChatForOpenAI(c *gin.Context) {
return
}

sentMsg, err := discord.SendMessage(sendChannelId, calledCozeBotId, content)
sentMsg, err := discord.SendMessage(c, sendChannelId, calledCozeBotId, content)
if err != nil {
common.LogError(context.Background(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "discord发送消息异常",
Message: err.Error(),
Type: "invalid_request_error",
Code: "discord_request_err",
},
Expand All @@ -236,7 +217,7 @@ func ChatForOpenAI(c *gin.Context) {

timer, err := setTimerWithHeader(c, request.Stream, common.RequestOutTimeDuration)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "超时时间设置异常",
Expand Down Expand Up @@ -361,7 +342,7 @@ func ImagesForOpenAI(c *gin.Context) {
var request model.OpenAIImagesGenerationRequest
err := json.NewDecoder(c.Request.Body).Decode(&request)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "无效的参数",
Expand All @@ -372,18 +353,6 @@ func ImagesForOpenAI(c *gin.Context) {
return
}

if err != nil {
common.LogError(context.Background(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "配置异常",
Type: "invalid_request_error",
Code: "discord_request_err",
},
})
return
}

if runeCount := len([]rune(request.Prompt)); runeCount > 2000 {
c.JSON(http.StatusOK, gin.H{
"success": false,
Expand All @@ -394,7 +363,7 @@ func ImagesForOpenAI(c *gin.Context) {

sendChannelId, calledCozeBotId, err := getSendChannelIdAndCozeBotId(c, true, request)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "配置异常",
Expand All @@ -405,12 +374,11 @@ func ImagesForOpenAI(c *gin.Context) {
return
}

sentMsg, err := discord.SendMessage(sendChannelId, calledCozeBotId, request.Prompt)
sentMsg, err := discord.SendMessage(c, sendChannelId, calledCozeBotId, request.Prompt)
if err != nil {
common.LogError(context.Background(), err.Error())
c.JSON(http.StatusOK, model.OpenAIErrorResponse{
OpenAIError: model.OpenAIError{
Message: "discord发送消息异常",
Message: err.Error(),
Type: "invalid_request_error",
Code: "discord_request_err",
},
Expand All @@ -428,7 +396,7 @@ func ImagesForOpenAI(c *gin.Context) {

timer, err := setTimerWithHeader(c, false, common.RequestOutTimeDuration)
if err != nil {
common.LogError(context.Background(), err.Error())
common.LogError(c.Request.Context(), err.Error())
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "超时时间设置异常",
Expand Down
27 changes: 22 additions & 5 deletions discord/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/json"
"fmt"
"github.com/bwmarrin/discordgo"
"github.com/gin-gonic/gin"
"github.com/h2non/filetype"
"golang.org/x/net/proxy"
"log"
Expand Down Expand Up @@ -275,15 +276,31 @@ func processMessageForOpenAIImage(m *discordgo.MessageUpdate) model.OpenAIImages
}
}

func SendMessage(channelID, cozeBotId, message string) (*discordgo.Message, error) {
func SendMessage(c *gin.Context, channelID, cozeBotId, message string) (*discordgo.Message, error) {
var ctx context.Context
if c == nil {
ctx = context.Background()
} else {
ctx = c.Request.Context()
}

if Session == nil {
return nil, fmt.Errorf("Discord session not initialized")
common.LogError(ctx, "discord session is nil")
return nil, fmt.Errorf("discord session not initialized")
}

content := fmt.Sprintf("<@%s> %s", cozeBotId, message)

if runeCount := len([]rune(content)); runeCount > 2000 {
common.LogError(ctx, fmt.Sprintf("prompt已超过限制,请分段发送 [%v] %s", runeCount, content))
return nil, fmt.Errorf("prompt已超过限制,请分段发送 [%v]", runeCount)
}

// 添加@机器人逻辑
sentMsg, err := Session.ChannelMessageSend(channelID, fmt.Sprintf("<@%s> %s", cozeBotId, message))
sentMsg, err := Session.ChannelMessageSend(channelID, content)
if err != nil {
return nil, fmt.Errorf("error sending message: %s", err)
common.LogError(ctx, fmt.Sprintf("error sending message: %s", err))
return nil, fmt.Errorf("error sending message")
}
return sentMsg, nil
}
Expand Down Expand Up @@ -390,7 +407,7 @@ func scheduleDailyMessage() {
botConfigs := model.FilterUniqueBotChannel(BotConfigList)
for _, config := range botConfigs {

_, err := SendMessage(config.ChannelId, config.CozeBotId, "Hi!")
_, err := SendMessage(nil, config.ChannelId, config.CozeBotId, "Hi!")
if err != nil {
common.LogWarn(context.Background(), fmt.Sprintf("ChannelId{%s} BotId{%s} 活跃机器人任务消息发送异常!", config.ChannelId, config.CozeBotId))
} else {
Expand Down

0 comments on commit 25f9e40

Please sign in to comment.