From a18fa2203d53e670112b22ebab581117b7759a01 Mon Sep 17 00:00:00 2001 From: Dmitry Rubtsov Date: Wed, 9 Aug 2023 17:27:19 +0600 Subject: [PATCH] init --- .github/workflows/docker.yml | 49 +++++++++++++++++ .gitignore | 1 + Dockerfile | 18 ++++++ README.md | 50 +++++++++++++++++ cmd/matrix-gpt/cmd.go | 37 +++++++++++++ cmd/matrix-gpt/log.go | 41 ++++++++++++++ cmd/matrix-gpt/main.go | 103 +++++++++++++++++++++++++++++++++++ go.mod | 29 ++++++++++ go.sum | 52 ++++++++++++++++++ internal/bot/actions.go | 46 ++++++++++++++++ internal/bot/bot.go | 77 ++++++++++++++++++++++++++ internal/bot/handlers.go | 79 +++++++++++++++++++++++++++ internal/gpt/completions.go | 50 +++++++++++++++++ internal/gpt/gpt.go | 55 +++++++++++++++++++ internal/gpt/history.go | 63 +++++++++++++++++++++ internal/gpt/users.go | 18 ++++++ internal/text/text.go | 49 +++++++++++++++++ internal/text/text_test.go | 90 ++++++++++++++++++++++++++++++ 18 files changed, 907 insertions(+) create mode 100644 .github/workflows/docker.yml create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 cmd/matrix-gpt/cmd.go create mode 100644 cmd/matrix-gpt/log.go create mode 100644 cmd/matrix-gpt/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/bot/actions.go create mode 100644 internal/bot/bot.go create mode 100644 internal/bot/handlers.go create mode 100644 internal/gpt/completions.go create mode 100644 internal/gpt/gpt.go create mode 100644 internal/gpt/history.go create mode 100644 internal/gpt/users.go create mode 100644 internal/text/text.go create mode 100644 internal/text/text_test.go diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000..3fd2c87 --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,49 @@ +name: Docker Publish + +on: + push: + branches: + - master + tags: + - '*' + +jobs: + docker: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Get Git Tag + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + run: | + echo "GIT_TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV + + - name: Build and push Docker image with tag + if: env.GIT_TAG != '' + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}:${{ env.GIT_TAG }} + build-args: | + BUILD_TAG=${{ env.GIT_TAG }} + + - name: Build and push Docker image with latest + if: github.ref == 'refs/heads/master' + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ghcr.io/${{ github.repository_owner }}/${{ github.event.repository.name }}:latest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b11aa24 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +FROM golang:1.20 as build + +ENV CGO_ENABLED 1 +RUN apt-get update && apt-get install -y libolm-dev && \ + rm -rf /var/lib/apt/lists/* + +COPY . /app +RUN cd /app && \ + go build -ldflags="-s -w" -trimpath -o /matrix-gpt ./cmd/matrix-gpt + +FROM ubuntu:22.04 +RUN apt-get update && \ + apt-get install -y libolm3 ca-certificates tzdata && \ + rm -rf /var/lib/apt/lists/* + +COPY --from=build /matrix-gpt /matrix-gpt +USER 1337 +CMD ["/matrix-gpt"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..0d92b5b --- /dev/null +++ b/README.md @@ -0,0 +1,50 @@ +# Matrix GPT + +Matrix GPT is a Matrix chatbot that uses OpenAI for real-time chatting. +## Installation + +### Docker + +Run the Docker container: + +```bash +docker run -d --name matrix-gpt \ + -p 8080:8080 \ + -e MATRIX_PASSWORD="matrix password" \ + -e MATRIX_ID="matrix id" \ + -e MATRIX_URL="matrix server url" \ + -e OPENAI_TOKEN="openai token" \ + -e SQLITE_PATH="persistent path for sqlite database" + -e USER_IDS="allowed user ids" + ghcr.io/mazzz1y/matrix-gpt:latest + +``` +## Configuration + +You can configure GPT Matrix using the following environment variables: + +- `SERVER_URL`: The URL to the Matrix homeserver. +- `USER_ID`: Your Matrix user ID for the bot. +- `PASSWORD`: The password for your Matrix bot's account. +- `SQLITE_PATH`: Path to SQLite database for end-to-end encryption. +- `HISTORY_EXPIRE`: Duration after which chat history expires. +- `GPT_MODEL`: The OpenAI GPT model being used. +- `GPT_HISTORY_LIMIT`: Limit for number of chat messages retained in history. +- `GPT_TIMEOUT`: Duration for OpenAI API timeout. +- `GPT_MAX_ATTEMPTS`: Maximum number of attempts for GPT API retries. +- `GPT_USER_IDS`: List of authorized user IDs for the bot. + +Alternatively, you can set these options using command-line flags. Run `./matrix-gpt --help` for more +information. + +## Usage + +1. Begin by adding the bot to your contact list. + Once added, you can start interacting with it. Simply send your questions or commands, and the bot will respond. +2. If at any point you wish to reset the context of the conversation, use the `!reset` command. + Send `!reset` as a message to the bot, and it will clear the existing context, allowing you to start a fresh + conversation. + +## TODO + +* Add image generation using DALL·E model \ No newline at end of file diff --git a/cmd/matrix-gpt/cmd.go b/cmd/matrix-gpt/cmd.go new file mode 100644 index 0000000..fc02c47 --- /dev/null +++ b/cmd/matrix-gpt/cmd.go @@ -0,0 +1,37 @@ +package main + +import ( + "github.com/mazzz1y/matrix-gpt/internal/bot" + "github.com/mazzz1y/matrix-gpt/internal/gpt" + "github.com/urfave/cli/v2" +) + +func run(c *cli.Context) error { + mPassword := c.String("matrix-password") + mUserId := c.String("matrix-id") + mUrl := c.String("matrix-url") + mRoom := c.String("matrix-room") + sqlitePath := c.String("sqlite-path") + + gptModel := c.String("gpt-model") + gptTimeout := c.Int("gpt-timeout") + openaiToken := c.String("openai-token") + maxAttempts := c.Int("max-attempts") + + historyExpire := c.Int("history-expire") + historyLimit := c.Int("history-limit") + userIDs := c.StringSlice("user-ids") + + logLevel := c.String("log-level") + logType := c.String("log-type") + + setLogLevel(logLevel, logType) + + g := gpt.New(openaiToken, gptModel, historyLimit, gptTimeout, maxAttempts, userIDs) + m, err := bot.NewBot(mUrl, mUserId, mPassword, sqlitePath, mRoom, historyExpire, g) + if err != nil { + return err + } + + return m.StartHandler() +} diff --git a/cmd/matrix-gpt/log.go b/cmd/matrix-gpt/log.go new file mode 100644 index 0000000..4317601 --- /dev/null +++ b/cmd/matrix-gpt/log.go @@ -0,0 +1,41 @@ +package main + +import ( + "os" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func setLogLevel(logLevel string, logType string) { + switch logType { + case "json": + case "pretty": + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + default: + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Info().Msgf("invalid log type: %s. using 'pretty' as default", logType) + } + + switch logLevel { + case "trace": + zerolog.SetGlobalLevel(zerolog.TraceLevel) + case "debug": + zerolog.SetGlobalLevel(zerolog.InfoLevel) + case "info": + zerolog.SetGlobalLevel(zerolog.InfoLevel) + case "warn": + zerolog.SetGlobalLevel(zerolog.WarnLevel) + case "error": + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + case "fatal": + zerolog.SetGlobalLevel(zerolog.FatalLevel) + case "panic": + zerolog.SetGlobalLevel(zerolog.PanicLevel) + case "no": + zerolog.SetGlobalLevel(zerolog.NoLevel) + default: + zerolog.SetGlobalLevel(zerolog.InfoLevel) + log.Info().Msgf("invalid log level: %s. using 'info' as default", logLevel) + } +} diff --git a/cmd/matrix-gpt/main.go b/cmd/matrix-gpt/main.go new file mode 100644 index 0000000..bc485b1 --- /dev/null +++ b/cmd/matrix-gpt/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "github.com/sashabaranov/go-openai" + "github.com/urfave/cli/v2" + "os" +) + +var version = "git" + +func main() { + app := &cli.App{ + Name: "matrix-gpt", + Version: version, + Usage: "GPT Matrix Bot", + Action: run, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "matrix-password", + Usage: "Matrix password", + EnvVars: []string{"MATRIX_PASSWORD"}, + Required: true, + }, + &cli.StringFlag{ + Name: "matrix-id", + Usage: "Matrix user ID", + EnvVars: []string{"MATRIX_ID"}, + Required: true, + }, + &cli.StringFlag{ + Name: "matrix-url", + Usage: "Matrix server URL", + EnvVars: []string{"MATRIX_URL"}, + Required: true, + }, + &cli.StringFlag{ + Name: "openai-token", + Usage: "OpenAI API token", + EnvVars: []string{"OPENAI_TOKEN"}, + Required: true, + }, + &cli.StringFlag{ + Name: "sqlite-path", + Usage: "Path to SQLite database", + EnvVars: []string{"SQLITE_PATH"}, + Required: true, + }, + &cli.IntFlag{ + Name: "history-limit", + Usage: "Maximum number of history entries", + EnvVars: []string{"HISTORY_LIMIT"}, + Value: 5, + }, + &cli.IntFlag{ + Name: "history-expire", + Usage: "Time after which history entries expire (in hours)", + EnvVars: []string{"HISTORY_EXPIRE"}, + Value: 3, + }, + &cli.StringFlag{ + Name: "gpt-model", + Usage: "GPT model name/version", + EnvVars: []string{"GPT_MODEL"}, + Value: openai.GPT3Dot5Turbo, + }, + &cli.IntFlag{ + Name: "gpt-timeout", + Usage: "Time to wait for a GPT response (in seconds)", + EnvVars: []string{"GPT_TIMEOUT"}, + Value: 45, + }, + &cli.IntFlag{ + Name: "max-attempts", + Usage: "Maximum number of retry attempts for GPT", + EnvVars: []string{"MAX_ATTEMPTS"}, + Value: 3, + }, + &cli.StringSliceFlag{ + Name: "user-ids", + Usage: "List of allowed Matrix user IDs", + EnvVars: []string{"USER_IDS"}, + Required: true, + }, + &cli.StringFlag{ + Name: "log-level", + Value: "info", + Usage: "Logging level (e.g. debug, info, warn, error, fatal, panic, no)", + EnvVars: []string{"LOG_LEVEL"}, + }, + &cli.StringFlag{ + Name: "log-type", + Value: "pretty", + Usage: "Logging format/type (e.g. pretty, json)", + EnvVars: []string{"LOG_TYPE"}, + }, + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Printf("\n" + err.Error()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..79a36ee --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module github.com/mazzz1y/matrix-gpt + +go 1.20 + +require ( + github.com/mattn/go-sqlite3 v1.14.17 + github.com/rs/zerolog v1.30.0 + github.com/sashabaranov/go-openai v1.14.1 + github.com/urfave/cli/v2 v2.25.7 + maunium.net/go/mautrix v0.15.4 +) + +require ( + github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect + github.com/yuin/goldmark v1.5.4 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + maunium.net/go/maulogger/v2 v2.4.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0437749 --- /dev/null +++ b/go.sum @@ -0,0 +1,52 @@ +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= +github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA= +github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/urfave/cli/v2 v2.25.7 h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs= +github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= +github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= +github.com/yuin/goldmark v1.5.4 h1:2uY/xC0roWy8IBEGLgB1ywIoEJFGmRrX21YQcvGZzjU= +github.com/yuin/goldmark v1.5.4/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= +maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= +maunium.net/go/mautrix v0.15.4 h1:Ug3n2Mo+9Yb94AjZTWJQSNHmShaksEzZi85EPl3S3P0= +maunium.net/go/mautrix v0.15.4/go.mod h1:dBaDmsnOOBM4a+gKcgefXH73pHGXm+MCJzCs1dXFgrw= diff --git a/internal/bot/actions.go b/internal/bot/actions.go new file mode 100644 index 0000000..7ce01b2 --- /dev/null +++ b/internal/bot/actions.go @@ -0,0 +1,46 @@ +package bot + +import ( + "github.com/mazzz1y/matrix-gpt/internal/gpt" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" + "time" +) + +// sendAnswer responds to a user message with a GPT-based completion. +func (b *Bot) sendAnswer(u *gpt.User, evt *event.Event) error { + if err := b.client.MarkRead(evt.RoomID, evt.ID); err != nil { + return err + } + + b.startTyping(evt.RoomID) + defer b.stopTyping(evt.RoomID) + + msg := evt.Content.AsMessage().Body + answer, err := b.gptClient.GetCompletion(u, msg) + if err != nil { + return err + } + + formattedMsg := format.RenderMarkdown(answer, true, false) + _, err = b.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedMsg) + return err +} + +// sendReaction sends a reaction to a message. +func (b *Bot) sendReaction(evt *event.Event, emoji string) error { + _, err := b.client.SendReaction(evt.RoomID, evt.ID, emoji) + return err +} + +// startTyping notifies the room that the bot is typing. +func (b *Bot) startTyping(roomID id.RoomID) { + timeout := time.Duration(b.gptClient.GetTimeout()) * time.Second + _, _ = b.client.UserTyping(roomID, true, timeout) +} + +// stopTyping notifies the room that the bot has stopped typing. +func (b *Bot) stopTyping(roomID id.RoomID) { + _, _ = b.client.UserTyping(roomID, false, 0) +} diff --git a/internal/bot/bot.go b/internal/bot/bot.go new file mode 100644 index 0000000..e3222c5 --- /dev/null +++ b/internal/bot/bot.go @@ -0,0 +1,77 @@ +package bot + +import ( + _ "github.com/mattn/go-sqlite3" + "github.com/mazzz1y/matrix-gpt/internal/gpt" + "github.com/rs/zerolog/log" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/cryptohelper" + "sync" +) + +type Bot struct { + sync.Mutex + client *mautrix.Client + gptClient *gpt.Gpt + room string + selfProfile mautrix.RespUserProfile + replaceFile string + historyExpire int +} + +// NewBot initializes a new Matrix bot instance. +func NewBot(serverUrl, userID, password, sqlitePath, scheduleRoom string, historyExpire int, gpt *gpt.Gpt) (*Bot, error) { + client, err := mautrix.NewClient(serverUrl, "", "") + if err != nil { + return nil, err + } + + crypto, err := cryptohelper.NewCryptoHelper(client, []byte("1337"), sqlitePath) + if err != nil { + return nil, err + } + + crypto.LoginAs = &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: userID}, + Password: password, + } + + if err := crypto.Init(); err != nil { + return nil, err + } + + client.Crypto = crypto + + profile, err := client.GetProfile(client.UserID) + if err != nil { + return nil, err + } + + log.Info(). + Str("matrix-username", profile.DisplayName). + Str("gpt-model", gpt.GetModel()). + Int("gpt-timeout", gpt.GetTimeout()). + Int("history-limit", gpt.GetHistoryLimit()). + Int("history-expire", historyExpire). + Msg("connected to matrix") + + return &Bot{ + client: client, + gptClient: gpt, + selfProfile: *profile, + room: scheduleRoom, + historyExpire: historyExpire, + }, nil +} + +// StartHandler initializes bot event handlers and starts the matrix client sync. +func (b *Bot) StartHandler() error { + logger := log.With().Str("component", "handler").Logger() + + b.joinRoomHandler() + b.messageHandler() + + logger.Info().Msg("started handler") + return b.client.Sync() +} diff --git a/internal/bot/handlers.go b/internal/bot/handlers.go new file mode 100644 index 0000000..07ff7ef --- /dev/null +++ b/internal/bot/handlers.go @@ -0,0 +1,79 @@ +package bot + +import ( + "time" + + "github.com/mazzz1y/matrix-gpt/internal/gpt" + "github.com/mazzz1y/matrix-gpt/internal/text" + "github.com/rs/zerolog/log" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +// messageHandler sets up the handler for incoming messages. +func (b *Bot) messageHandler() { + syncer := b.client.Syncer.(*mautrix.DefaultSyncer) + syncer.OnEventType(event.EventMessage, b.msgEvtDispatcher) +} + +// joinRoomHandler sets up the handler for joining rooms. +func (b *Bot) joinRoomHandler() { + syncer := b.client.Syncer.(*mautrix.DefaultSyncer) + syncer.OnEventType(event.StateMember, func(source mautrix.EventSource, evt *event.Event) { + if evt.RoomID.String() == b.room && + evt.GetStateKey() == b.client.UserID.String() && + evt.Content.AsMember().Membership == event.MembershipInvite { + _, err := b.client.JoinRoomByID(evt.RoomID) + if err != nil { + return + } + } + }) +} + +// historyResetHandler checks for the reset command and resets history if found. +func (b *Bot) historyResetHandler(user *gpt.User, evt *event.Event) (ok bool) { + if text.HasPrefixIgnoreCase(evt.Content.AsMessage().Body, "!reset") { + user.History.ResetHistory() + _ = b.sendReaction(evt, "✅") + return true + } + return false +} + +// historyExpireHandler checks if the history for a user has expired and resets if necessary. +func (b *Bot) historyExpireHandler(user *gpt.User) { + if user.LastMsg.Add(time.Duration(b.historyExpire) * time.Hour).Before(time.Now()) { + user.History.ResetHistory() + } +} + +// msgEvtDispatcher dispatches incoming messages to their appropriate handlers. +func (b *Bot) msgEvtDispatcher(source mautrix.EventSource, evt *event.Event) { + // Ignore messages sent by the bot itself + if b.client.UserID.String() == evt.Sender.String() { + return + } + + l := log.With(). + Str("component", "handler"). + Str("user_id", evt.Sender.String()). + Logger() + + user, ok := b.gptClient.GetUser(evt.Sender.String()) + if !ok { + l.Info().Msg("forbidden") + return + } + + if b.historyResetHandler(user, evt) { + l.Info().Msg("reset history") + } + + err := b.sendAnswer(user, evt) + if err != nil { + l.Err(err).Msg("failed to send message") + } else { + l.Info().Msg("sending answer") + } +} diff --git a/internal/gpt/completions.go b/internal/gpt/completions.go new file mode 100644 index 0000000..400350e --- /dev/null +++ b/internal/gpt/completions.go @@ -0,0 +1,50 @@ +package gpt + +import ( + "context" + "github.com/sashabaranov/go-openai" + "time" +) + +// GetCompletion retrieves a completion from GPT using the given user's message. +func (g *Gpt) GetCompletion(u *User, userMsg string) (string, error) { + // Append the user's message to the existing history. + messageHistory := append(u.History.GetHistory(), openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: userMsg, + }) + + var ( + response openai.ChatCompletionResponse + err error + ) + + // Try creating a completion up to the maximum number of allowed attempts. + for i := 0; i < g.maxAttempts; i++ { + response, err = g.createCompletionWithTimeout(messageHistory) + if err == nil { + break + } + time.Sleep(5 * time.Second) + } + + // Update the user's history with both the user message and the assistant's response. + u.History.AddMessage(openai.ChatMessageRoleUser, userMsg) + u.History.AddMessage(openai.ChatMessageRoleAssistant, response.Choices[0].Message.Content) + + return response.Choices[0].Message.Content, err +} + +// createCompletionWithTimeout makes a request to get a GPT completion with a specified timeout. +func (g *Gpt) createCompletionWithTimeout(msg []openai.ChatCompletionMessage) (openai.ChatCompletionResponse, error) { + ctx, cancel := context.WithTimeout(g.ctx, time.Duration(g.gptTimeout)*time.Second) + defer cancel() + + return g.client.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: g.model, + Messages: msg, + }, + ) +} diff --git a/internal/gpt/gpt.go b/internal/gpt/gpt.go new file mode 100644 index 0000000..db6a71a --- /dev/null +++ b/internal/gpt/gpt.go @@ -0,0 +1,55 @@ +package gpt + +import ( + "context" + "github.com/sashabaranov/go-openai" +) + +type Gpt struct { + client *openai.Client + ctx context.Context + model string + historyLimit int + gptTimeout int + maxAttempts int + users map[string]*User +} + +// New initializes a Gpt instance with the provided configurations. +func New(token, gptModel string, historyLimit, gptTimeout, maxAttempts int, userIDs []string) *Gpt { + users := make(map[string]*User) + for _, id := range userIDs { + users[id] = NewGptUser(historyLimit) + } + + return &Gpt{ + client: openai.NewClient(token), + ctx: context.Background(), + model: gptModel, + historyLimit: historyLimit, + gptTimeout: gptTimeout, + users: users, + maxAttempts: maxAttempts, + } +} + +// GetUser retrieves the User instance associated with the given ID. +func (g *Gpt) GetUser(id string) (u *User, ok bool) { + u, ok = g.users[id] + return +} + +// GetModel returns the GPT model string. +func (g *Gpt) GetModel() string { + return g.model +} + +// GetTimeout returns the timeout value for the GPT client. +func (g *Gpt) GetTimeout() int { + return g.gptTimeout +} + +// GetHistoryLimit returns the history limit value. +func (g *Gpt) GetHistoryLimit() int { + return g.historyLimit +} diff --git a/internal/gpt/history.go b/internal/gpt/history.go new file mode 100644 index 0000000..8b41d29 --- /dev/null +++ b/internal/gpt/history.go @@ -0,0 +1,63 @@ +package gpt + +import ( + "github.com/sashabaranov/go-openai" + "sync" +) + +// HistoryManager manages chat histories for GPT interactions. +type HistoryManager struct { + sync.RWMutex + Storage []openai.ChatCompletionMessage + Size int +} + +// NewHistoryManager initializes a HistoryManager instance with the provided size. +func NewHistoryManager(size int) *HistoryManager { + return &HistoryManager{ + Storage: make([]openai.ChatCompletionMessage, 0), + Size: size, + } +} + +// ResetHistory clears the current chat history. +// Returns true if history was cleared, false otherwise. +func (m *HistoryManager) ResetHistory() bool { + m.Lock() + defer m.Unlock() + + if len(m.Storage) > 0 { + m.Storage = make([]openai.ChatCompletionMessage, 0) + return true + } + + return false +} + +// AddMessage appends a new message to the chat history. +func (m *HistoryManager) AddMessage(msgType, msgContent string) { + m.Lock() + defer m.Unlock() + + message := openai.ChatCompletionMessage{ + Role: msgType, + Content: msgContent, + } + m.Storage = append(m.Storage, message) + m.trimHistory() +} + +// GetHistory retrieves the current chat history. +func (m *HistoryManager) GetHistory() []openai.ChatCompletionMessage { + m.RLock() + defer m.RUnlock() + + return m.Storage +} + +// trimHistory ensures the chat history doesn't exceed its size limit. +func (m *HistoryManager) trimHistory() { + if len(m.Storage) > m.Size { + m.Storage = m.Storage[len(m.Storage)-m.Size:] + } +} diff --git a/internal/gpt/users.go b/internal/gpt/users.go new file mode 100644 index 0000000..5df8600 --- /dev/null +++ b/internal/gpt/users.go @@ -0,0 +1,18 @@ +package gpt + +import ( + "time" +) + +// User represents a GPT user with a chat history and last message timestamp. +type User struct { + History *HistoryManager + LastMsg time.Time +} + +// NewGptUser creates a new GPT user instance with a given history size. +func NewGptUser(historySize int) *User { + return &User{ + History: NewHistoryManager(historySize), + } +} diff --git a/internal/text/text.go b/internal/text/text.go new file mode 100644 index 0000000..4da3d05 --- /dev/null +++ b/internal/text/text.go @@ -0,0 +1,49 @@ +package text + +import ( + "strings" +) + +// HasPrefixIgnoreCase checks if string `s` has the given `prefix` irrespective of their case. +func HasPrefixIgnoreCase(s, prefix string) bool { + return strings.HasPrefix(strings.ToLower(s), strings.ToLower(prefix)) +} + +// ReplaceIgnoreCase replaces occurrences of `old` in `s` with `new` irrespective of their case for a specified count `n`. +func ReplaceIgnoreCase(s, old, new string, n int) string { + if old == "" { + return s + } + + msgLower := strings.ToLower(s) + oldLower := strings.ToLower(old) + var result strings.Builder + count := 0 + + for { + if count == n && n >= 0 { + result.WriteString(s) + break + } + + idx := strings.Index(msgLower, oldLower) + if idx == -1 { + result.WriteString(s) + break + } + + result.WriteString(s[:idx]) + result.WriteString(new) + + s = s[idx+len(old):] + msgLower = msgLower[idx+len(old):] + count++ + } + + return result.String() +} + +// ReplaceAllIgnoreCase replaces all occurrences of `old` in `s` with `new` irrespective of their case. +func ReplaceAllIgnoreCase(s, old, new string) string { + return ReplaceIgnoreCase(s, old, new, -1) +} diff --git a/internal/text/text_test.go b/internal/text/text_test.go new file mode 100644 index 0000000..2bb3178 --- /dev/null +++ b/internal/text/text_test.go @@ -0,0 +1,90 @@ +package text + +import ( + "testing" +) + +func TestReplaceIgnoreCase(t *testing.T) { + cases := []struct { + name string + s string + old string + new string + n int + expected string + }{ + { + name: "old substring exists", + s: "Hello World", + old: "World", + new: "Everyone", + n: -1, + expected: "Hello Everyone", + }, + { + name: "old substring does not exist", + s: "Hello World", + old: "Universe", + new: "Everyone", + n: -1, + expected: "Hello World", + }, + { + name: "old substring exists but different case", + s: "Hello World", + old: "WORLD", + new: "Everyone", + n: -1, + expected: "Hello Everyone", + }, + { + name: "limited replacements", + s: "Hello World World", + old: "World", + new: "Everyone", + n: 1, + expected: "Hello Everyone World", + }, + { + name: "more replacements than occurrences", + s: "Hello World", + old: "World", + new: "Everyone", + n: 3, + expected: "Hello Everyone", + }, + { + name: "empty string", + s: "", + old: "World", + new: "Everyone", + n: -1, + expected: "", + }, + { + name: "empty old substring", + s: "Hello World", + old: "", + new: "Everyone", + n: -1, + expected: "Hello World", + }, + { + name: "empty new substring", + s: "Hello World", + old: "World", + new: "", + n: -1, + expected: "Hello ", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := ReplaceIgnoreCase(tc.s, tc.old, tc.new, tc.n) + if result != tc.expected { + t.Errorf("ReplaceIgnoreCase(%q, %q, %q, %d) = %q; want %q", tc.s, tc.old, tc.new, tc.n, result, tc.expected) + } + }) + } +}