Skip to content

Commit

Permalink
add DALL-E image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
mazzz1y committed Aug 10, 2023
1 parent 5e79487 commit de2d7a3
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 211 deletions.
22 changes: 15 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,20 @@ information.

## Usage

1. Begin by adding the bot to your contact list.
Once added, you can start interacting with it. Simply send your questions or commands, and the bot will respond.
2. If at any point you wish to reset the context of the conversation, use the `!reset` command.
Send `!reset` as a message to the bot, and it will clear the existing context, allowing you to start a fresh
conversation.
Follow these simple steps to interact with the bot:

## TODO
1. **Add the bot to your contact list.**

* Add image generation using DALL·E model
After the bot has been added, begin your interaction by sending your questions, commands or comments.

2. **Send commands.**

You can use the following commands to communicate with the bot:

- **Generate an Image:** `!image [text]` - This command creates an image based on the provided text.
- **Reset User History:** `!reset [text]` - This command resets the user's command history. If text is provided following the reset command, the bot will generate a GPT-based response based on this text.
- **Send a Text Message:** `[text]` - Send any text to the bot and it will generate a GPT-based response relevant to your text.

3. **Identify error responses.**

If there are any errors in processing your requests or commands, the bot will respond with a ❌ reaction.
3 changes: 1 addition & 2 deletions cmd/matrix-gpt/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ func run(c *cli.Context) error {
mPassword := c.String("matrix-password")
mUserId := c.String("matrix-id")
mUrl := c.String("matrix-url")
mRoom := c.String("matrix-room")
sqlitePath := c.String("sqlite-path")

gptModel := c.String("gpt-model")
Expand All @@ -28,7 +27,7 @@ func run(c *cli.Context) error {
setLogLevel(logLevel, logType)

g := gpt.New(openaiToken, gptModel, historyLimit, gptTimeout, maxAttempts, userIDs)
m, err := bot.NewBot(mUrl, mUserId, mPassword, sqlitePath, mRoom, historyExpire, g)
m, err := bot.NewBot(mUrl, mUserId, mPassword, sqlitePath, historyExpire, g)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/matrix-gpt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func main() {
Name: "gpt-timeout",
Usage: "Time to wait for a GPT response (in seconds)",
EnvVars: []string{"GPT_TIMEOUT"},
Value: 45,
Value: 180,
},
&cli.IntFlag{
Name: "max-attempts",
Expand Down
119 changes: 107 additions & 12 deletions internal/bot/actions.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,105 @@
package bot

import (
"bytes"
"github.com/mazzz1y/matrix-gpt/internal/gpt"
"image"
"image/png"
"io"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"net/http"
"time"
)

// sendAnswer responds to a user message with a GPT-based completion.
func (b *Bot) sendAnswer(u *gpt.User, evt *event.Event) error {
if err := b.client.MarkRead(evt.RoomID, evt.ID); err != nil {
// completionResponse responds to a user message with a GPT-based completion.
func (b *Bot) completionResponse(u *gpt.User, roomID id.RoomID, msg string) error {
answer, err := b.gptClient.CreateCompletion(u, msg)
if err != nil {
return err
}

return b.markdownResponse(roomID, answer)
}

// helpResponse responds with help message.
func (b *Bot) helpResponse(roomID id.RoomID) error {
return b.markdownResponse(roomID, helpMessage)
}

// imageResponse responds to the user message with a DALL-E created image.
func (b *Bot) imageResponse(roomID id.RoomID, msg string) error {
img, err := b.gptClient.CreateImage(msg)
if err != nil {
return err
}

imageBytes, err := getImageBytesFromURL(img.Data[0].URL)
if err != nil {
return err
}

cfg, err := png.DecodeConfig(bytes.NewReader(imageBytes))
if err != nil {
return err
}

b.startTyping(evt.RoomID)
defer b.stopTyping(evt.RoomID)
content := b.createImageMessageContent(imageBytes, cfg)

file := attachment.NewEncryptedFile()
file.EncryptInPlace(imageBytes)

req := mautrix.ReqUploadMedia{
ContentBytes: imageBytes,
ContentType: "application/octet-stream",
}

msg := evt.Content.AsMessage().Body
answer, err := b.gptClient.GetCompletion(u, msg)
upload, err := b.client.UploadMedia(req)
if err != nil {
return err
}

formattedMsg := format.RenderMarkdown(answer, true, false)
_, err = b.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedMsg)
content.File = &event.EncryptedFileInfo{
EncryptedFile: *file,
URL: upload.ContentURI.CUString(),
}

_, err = b.client.SendMessageEvent(roomID, event.EventMessage, content)
return err
}

// sendReaction sends a reaction to a message.
func (b *Bot) sendReaction(evt *event.Event, emoji string) error {
_, err := b.client.SendReaction(evt.RoomID, evt.ID, emoji)
// resetResponse clears the user's history. If a message is provided, it's processed as a new input.
// Otherwise, a reaction is sent to indicate successful history reset.
func (b *Bot) resetResponse(u *gpt.User, evt *event.Event, msg string) error {
u.History.ResetHistory()
if msg != "" {
return b.completionResponse(u, evt.RoomID, msg)
} else {
b.reactionResponse(evt, "✅")
}
return nil
}

// markdownResponse sends a message response in markdown format.
func (b *Bot) markdownResponse(roomID id.RoomID, msg string) error {
formattedMsg := format.RenderMarkdown(msg, true, false)
_, err := b.client.SendMessageEvent(roomID, event.EventMessage, &formattedMsg)
return err
}

// reactionResponse sends a reaction to a message.
func (b *Bot) reactionResponse(evt *event.Event, emoji string) {
_, _ = b.client.SendReaction(evt.RoomID, evt.ID, emoji)
}

// markRead marks the given event as read by the bot.
func (b *Bot) markRead(evt *event.Event) {
_ = b.client.MarkRead(evt.RoomID, evt.ID)
}

// startTyping notifies the room that the bot is typing.
func (b *Bot) startTyping(roomID id.RoomID) {
timeout := time.Duration(b.gptClient.GetTimeout()) * time.Second
Expand All @@ -44,3 +110,32 @@ func (b *Bot) startTyping(roomID id.RoomID) {
func (b *Bot) stopTyping(roomID id.RoomID) {
_, _ = b.client.UserTyping(roomID, false, 0)
}

// createImageMessageContent creates the which contains the image information and the reply references.
func (b *Bot) createImageMessageContent(imageBytes []byte, cfg image.Config) *event.MessageEventContent {
return &event.MessageEventContent{
MsgType: event.MsgImage,
Info: &event.FileInfo{
Height: cfg.Height,
MimeType: http.DetectContentType(imageBytes),
Width: cfg.Height,
Size: len(imageBytes),
},
}
}

// getImageBytesFromURL returns the byte data from the image URL.
func getImageBytesFromURL(url string) ([]byte, error) {
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()

var buf bytes.Buffer
if _, err := io.Copy(&buf, resp.Body); err != nil {
return nil, err
}

return buf.Bytes(), nil
}
13 changes: 6 additions & 7 deletions internal/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import (
"github.com/rs/zerolog/log"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/cryptohelper"
"time"
)

type Bot struct {
client *mautrix.Client
gptClient *gpt.Gpt
selfProfile mautrix.RespUserProfile
replaceFile string
historyExpire int
historyExpire time.Duration
}

// NewBot initializes a new Matrix bot instance.
func NewBot(serverUrl, userID, password, sqlitePath, scheduleRoom string, historyExpire int, gpt *gpt.Gpt) (*Bot, error) {
func NewBot(serverUrl, userID, password, sqlitePath string, historyExpire int, gpt *gpt.Gpt) (*Bot, error) {
client, err := mautrix.NewClient(serverUrl, "", "")
if err != nil {
return nil, err
Expand All @@ -39,7 +39,6 @@ func NewBot(serverUrl, userID, password, sqlitePath, scheduleRoom string, histor
}

client.Crypto = crypto

profile, err := client.GetProfile(client.UserID)
if err != nil {
return nil, err
Expand All @@ -57,16 +56,16 @@ func NewBot(serverUrl, userID, password, sqlitePath, scheduleRoom string, histor
client: client,
gptClient: gpt,
selfProfile: *profile,
historyExpire: historyExpire,
historyExpire: time.Duration(historyExpire) * time.Hour,
}, nil
}

// StartHandler initializes bot event handlers and starts the matrix client sync.
func (b *Bot) StartHandler() error {
logger := log.With().Str("component", "handler").Logger()

b.joinRoomHandler()
b.messageHandler()
b.setupJoinRoomEvent()
b.setupMessageEvent()

logger.Info().Msg("started handler")
return b.client.Sync()
Expand Down
27 changes: 27 additions & 0 deletions internal/bot/commands.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package bot

import "strings"

const (
GenerateImageCommand = "image"
HistoryResetCommand = "reset"
HelpCommand = "help"
)

func extractCommand(s string) (cmd string) {
if strings.HasPrefix(s, "!") && len(s) > 1 {
//Get the word after '!'
command := strings.Fields(s)[0][1:]
return command
}
return ""
}

func trimCommand(s string) string {
if strings.HasPrefix(s, "!") && len(s) > 1 {
//Remove command from s and clean up leading spaces
trimmed := strings.TrimPrefix(s, strings.Fields(s)[0])
return strings.TrimSpace(trimmed)
}
return s
}
77 changes: 36 additions & 41 deletions internal/bot/handlers.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
package bot

import (
"time"

"fmt"
"github.com/mazzz1y/matrix-gpt/internal/gpt"
"github.com/mazzz1y/matrix-gpt/internal/text"
"github.com/rs/zerolog/log"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
)

// messageHandler sets up the handler for incoming messages.
func (b *Bot) messageHandler() {
// setupMessageEvent sets up the handler for incoming messages.
func (b *Bot) setupMessageEvent() {
syncer := b.client.Syncer.(*mautrix.DefaultSyncer)
syncer.OnEventType(event.EventMessage, b.msgEvtDispatcher)
syncer.OnEventType(event.EventMessage, b.messageHandler)
}

// joinRoomHandler sets up the handler for joining rooms.
func (b *Bot) joinRoomHandler() {
// setupJoinRoomEvent sets up the handler for joining rooms.
func (b *Bot) setupJoinRoomEvent() {
syncer := b.client.Syncer.(*mautrix.DefaultSyncer)
syncer.OnEventType(event.StateMember, func(source mautrix.EventSource, evt *event.Event) {
_, ok := b.gptClient.GetUser(evt.Sender.String())
Expand All @@ -32,28 +30,8 @@ func (b *Bot) joinRoomHandler() {
})
}

// historyResetHandler checks for the reset command and resets history if found.
func (b *Bot) historyResetHandler(user *gpt.User, evt *event.Event) (ok bool) {
if text.HasPrefixIgnoreCase(evt.Content.AsMessage().Body, "!reset") {
user.History.ResetHistory()
_ = b.sendReaction(evt, "✅")
return true
}
return false
}

// historyExpireHandler checks if the history for a user has expired and resets if necessary.
func (b *Bot) historyExpireHandler(user *gpt.User) (ok bool) {
if user.GetLastMsgTime().Add(time.Duration(b.historyExpire) * time.Hour).Before(time.Now()) {
user.History.ResetHistory()
return true
}
return false
}

// msgEvtDispatcher dispatches incoming messages to their appropriate handlers.
func (b *Bot) msgEvtDispatcher(source mautrix.EventSource, evt *event.Event) {
// Ignore messages sent by the bot itself
// messageHandler handles incoming messages based on their type.
func (b *Bot) messageHandler(source mautrix.EventSource, evt *event.Event) {
if b.client.UserID.String() == evt.Sender.String() {
return
}
Expand All @@ -69,20 +47,37 @@ func (b *Bot) msgEvtDispatcher(source mautrix.EventSource, evt *event.Event) {
return
}

if b.historyResetHandler(user, evt) {
l.Info().Msg("history reset by user command")
return
}
if b.historyExpireHandler(user) {
l.Info().Msg("history has expired, resetting")
}

err := b.sendAnswer(user, evt)
err := b.sendResponse(user, evt)
if err != nil {
l.Err(err).Msg("failed to send message")
return
b.reactionResponse(evt, "❌")
l.Err(err).Msg("response error")
}

user.UpdateLastMsgTime()
l.Info().Msg("message sent")
}

// sendResponse responds to the user command.
func (b *Bot) sendResponse(user *gpt.User, evt *event.Event) (err error) {
b.markRead(evt)
b.startTyping(evt.RoomID)
defer b.stopTyping(evt.RoomID)

cmd := extractCommand(evt.Content.AsMessage().Body)
msg := trimCommand(evt.Content.AsMessage().Body)

switch cmd {
case HelpCommand:
err = b.helpResponse(evt.RoomID)
case GenerateImageCommand:
err = b.imageResponse(evt.RoomID, msg)
case HistoryResetCommand:
err = b.resetResponse(user, evt, msg)
case "":
err = b.completionResponse(user, evt.RoomID, msg)
default:
err = fmt.Errorf("command: %s does not exist", cmd)
}

return err
}
Loading

0 comments on commit de2d7a3

Please sign in to comment.