From 2d44bd83ab30b0656bbbfbeccd637042224d7b5f Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Wed, 29 May 2024 15:52:33 +0800 Subject: [PATCH] fix: openroom with entryid Signed-off-by: zwwhdls --- cmd/apps/apis/fsapi/v1/service.go | 31 +++++------ cmd/apps/apis/fsapi/v1/service_test.go | 6 --- go.mod | 2 +- go.sum | 2 + pkg/controller/controller.go | 1 + pkg/controller/dialogue.go | 26 ++++++++-- pkg/dialogue/dialogue.go | 15 +++++- pkg/metastore/instrumental.go | 8 +++ pkg/metastore/interface.go | 1 + pkg/metastore/sql.go | 16 ++++++ .../basenana/friday/pkg/friday/friday.go | 4 +- .../basenana/friday/pkg/friday/question.go | 4 +- .../basenana/friday/pkg/models/namespace.go | 51 +++++++++++++++++++ .../friday/pkg/vectorstore/db/entity.go | 12 +++++ .../pkg/vectorstore/postgres/postgres.go | 29 +++++------ vendor/modules.txt | 2 +- 16 files changed, 161 insertions(+), 49 deletions(-) create mode 100644 vendor/github.com/basenana/friday/pkg/models/namespace.go diff --git a/cmd/apps/apis/fsapi/v1/service.go b/cmd/apps/apis/fsapi/v1/service.go index 15901c5d..632deb2f 100644 --- a/cmd/apps/apis/fsapi/v1/service.go +++ b/cmd/apps/apis/fsapi/v1/service.go @@ -103,25 +103,26 @@ func (s *services) OpenRoom(ctx context.Context, request *OpenRoomRequest) (*Ope if request.EntryID == 0 { return nil, status.Error(codes.InvalidArgument, "entry id is empty") } - if request.RoomID == 0 { - // need create a new one - prompt := "" - if request.Option != nil { - prompt = request.Option.Prompt - } - room, err := s.ctrl.CreateRoom(ctx, request.EntryID, prompt) - if err != nil { - return nil, status.Error(common.FsApiError(err), "create room failed") - } - return &OpenRoomResponse{ - Room: &RoomInfo{Id: room.ID, EntryID: room.EntryId, Namespace: room.Namespace, Title: room.Title, Prompt: room.Prompt, CreatedAt: timestamppb.New(room.CreatedAt)}, - }, nil - } - room, err := s.ctrl.GetRoom(ctx, request.RoomID) + room, err := s.ctrl.FindRoom(ctx, request.EntryID) if err != nil { + if err == types.ErrNotFound { + // need create a new one + prompt := "" + if request.Option != nil { + prompt = request.Option.Prompt + } + room, err := s.ctrl.CreateRoom(ctx, request.EntryID, prompt) + if err != nil { + return nil, status.Error(common.FsApiError(err), "create room failed") + } + return &OpenRoomResponse{ + Room: &RoomInfo{Id: room.ID, EntryID: room.EntryId, Namespace: room.Namespace, Title: room.Title, Prompt: room.Prompt, CreatedAt: timestamppb.New(room.CreatedAt)}, + }, nil + } return nil, status.Error(common.FsApiError(err), "get room failed") } + msg := make([]*RoomMessage, 0, len(room.Messages)) for _, m := range room.Messages { msg = append(msg, &RoomMessage{ diff --git a/cmd/apps/apis/fsapi/v1/service_test.go b/cmd/apps/apis/fsapi/v1/service_test.go index 79b4ac2d..ef914596 100644 --- a/cmd/apps/apis/fsapi/v1/service_test.go +++ b/cmd/apps/apis/fsapi/v1/service_test.go @@ -128,12 +128,6 @@ var _ = Describe("testRoomService", func() { It("delete should be succeed", func() { _, err := serviceClient.DeleteRoom(ctx, &DeleteRoomRequest{RoomID: roomId}, grpc.UseCompressor(gzip.Name)) Expect(err).Should(BeNil()) - res, err := serviceClient.OpenRoom(ctx, &OpenRoomRequest{ - EntryID: 1, - RoomID: roomId, - }, grpc.UseCompressor(gzip.Name)) - Expect(err).ShouldNot(BeNil()) - Expect(res).Should(BeNil()) }) }) }) diff --git a/go.mod b/go.mod index e0460290..c930d192 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.13.26 github.com/aws/aws-sdk-go-v2/service/s3 v1.36.0 github.com/aws/smithy-go v1.13.5 - github.com/basenana/friday v0.0.0-20240514065549-962b40b3faf1 + github.com/basenana/friday v0.0.0-20240529034549-d8f6efb8f215 github.com/blevesearch/bleve/v2 v2.4.0 github.com/bluele/gcache v0.0.2 github.com/getsentry/sentry-go v0.22.0 diff --git a/go.sum b/go.sum index f3c76441..77d11b01 100644 --- a/go.sum +++ b/go.sum @@ -50,6 +50,8 @@ github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuP github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/basenana/friday v0.0.0-20240514065549-962b40b3faf1 h1:8vEnxa3CEbTtEqJIk44ZzhvZG+6DuJlLO+Mfa9psuAs= github.com/basenana/friday v0.0.0-20240514065549-962b40b3faf1/go.mod h1:izPqWGUN5Kxz6mb7Xwhz3Zoe5GdOa/O+Ht6VhIHmjaA= +github.com/basenana/friday v0.0.0-20240529034549-d8f6efb8f215 h1:kCuKG91mxPXnkbsmFiRLBlni6qtmwwFnhe0Bm1O2cD4= +github.com/basenana/friday v0.0.0-20240529034549-d8f6efb8f215/go.mod h1:izPqWGUN5Kxz6mb7Xwhz3Zoe5GdOa/O+Ht6VhIHmjaA= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index b0e0772e..0bbdde8a 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -87,6 +87,7 @@ type Controller interface { ListRooms(ctx context.Context, entryId int64) ([]*types.Room, error) CreateRoom(ctx context.Context, entryId int64, prompt string) (*types.Room, error) GetRoom(ctx context.Context, id int64) (*types.Room, error) + FindRoom(ctx context.Context, entryId int64) (*types.Room, error) UpdateRoom(ctx context.Context, roomId int64, prompt string) error DeleteRoom(ctx context.Context, id int64) error ClearRoom(ctx context.Context, id int64) error diff --git a/pkg/controller/dialogue.go b/pkg/controller/dialogue.go index 939430a1..983f2fb7 100644 --- a/pkg/controller/dialogue.go +++ b/pkg/controller/dialogue.go @@ -63,6 +63,15 @@ func (c *controller) GetRoom(ctx context.Context, id int64) (*types.Room, error) return result, nil } +func (c *controller) FindRoom(ctx context.Context, entryId int64) (*types.Room, error) { + result, err := c.dialogue.FindRoom(ctx, entryId) + if err != nil { + c.logger.Errorw("find room failed", "err", err) + return nil, err + } + return result, nil +} + func (c *controller) DeleteRoom(ctx context.Context, id int64) error { err := c.dialogue.DeleteRoom(ctx, id) if err != nil { @@ -88,7 +97,13 @@ func (c *controller) ClearRoom(ctx context.Context, id int64) error { } func (c *controller) CreateRoomMessage(ctx context.Context, roomID int64, sender, msg string, sendAt time.Time) (*types.RoomMessage, error) { + room, err := c.dialogue.GetRoom(ctx, roomID) + if err != nil { + c.logger.Errorw("get room failed", "err", err) + return nil, err + } result, err := c.dialogue.SaveMessage(ctx, &types.RoomMessage{ + Namespace: room.Namespace, RoomID: roomID, Sender: sender, Message: msg, @@ -183,11 +198,12 @@ func (c *controller) ChatInRoom(ctx context.Context, roomId int64, newMsg string // save model message response, err := c.dialogue.SaveMessage(ctx, &types.RoomMessage{ - ID: responseMsgId, - RoomID: roomId, - Sender: model, - Message: respMsg, - SendAt: time.Now(), + ID: responseMsgId, + Namespace: room.Namespace, + RoomID: roomId, + Sender: model, + Message: respMsg, + SendAt: time.Now(), }) if err != nil { c.logger.Errorw("save message failed", "err", err) diff --git a/pkg/dialogue/dialogue.go b/pkg/dialogue/dialogue.go index cedb335a..5fb4e59c 100644 --- a/pkg/dialogue/dialogue.go +++ b/pkg/dialogue/dialogue.go @@ -35,6 +35,7 @@ type Manager interface { CreateRoom(ctx context.Context, entryId int64, prompt string) (*types.Room, error) UpdateRoom(ctx context.Context, room *types.Room) error GetRoom(ctx context.Context, id int64) (*types.Room, error) + FindRoom(ctx context.Context, entryId int64) (*types.Room, error) DeleteRoom(ctx context.Context, id int64) error DeleteRoomMessages(ctx context.Context, roomId int64) error SaveMessage(ctx context.Context, roomMessage *types.RoomMessage) (*types.RoomMessage, error) @@ -109,6 +110,19 @@ func (m *manager) GetRoom(ctx context.Context, id int64) (*types.Room, error) { return room, nil } +func (m *manager) FindRoom(ctx context.Context, entryId int64) (*types.Room, error) { + room, err := m.recorder.FindRoom(ctx, entryId) + if err != nil { + return nil, err + } + msgs, err := m.recorder.ListRoomMessage(ctx, room.ID) + if err != nil { + return nil, err + } + room.Messages = msgs + return room, nil +} + func (m *manager) DeleteRoom(ctx context.Context, id int64) error { err := m.recorder.DeleteRoomMessages(ctx, id) if err != nil { @@ -132,7 +146,6 @@ func (m *manager) SaveMessage(ctx context.Context, roomMessage *types.RoomMessag if crtMsg == nil { roomMessage.CreatedAt = time.Now() - roomMessage.Namespace = types.GetNamespace(ctx).String() return roomMessage, m.recorder.SaveRoomMessage(ctx, roomMessage) } diff --git a/pkg/metastore/instrumental.go b/pkg/metastore/instrumental.go index 839dbb82..2ec9874c 100644 --- a/pkg/metastore/instrumental.go +++ b/pkg/metastore/instrumental.go @@ -530,6 +530,14 @@ func (i instrumentalStore) GetRoom(ctx context.Context, id int64) (*types.Room, return room, err } +func (i instrumentalStore) FindRoom(ctx context.Context, entryId int64) (*types.Room, error) { + const operation = "find_room" + defer logOperationLatency(operation, time.Now()) + room, err := i.store.FindRoom(ctx, entryId) + logOperationError(operation, err) + return room, err +} + func (i instrumentalStore) DeleteRoom(ctx context.Context, id int64) error { const operation = "delete_room" defer logOperationLatency(operation, time.Now()) diff --git a/pkg/metastore/interface.go b/pkg/metastore/interface.go index 4498bf12..8bfa9887 100644 --- a/pkg/metastore/interface.go +++ b/pkg/metastore/interface.go @@ -83,6 +83,7 @@ type DEntry interface { SaveRoom(ctx context.Context, room *types.Room) error GetRoom(ctx context.Context, id int64) (*types.Room, error) + FindRoom(ctx context.Context, entryId int64) (*types.Room, error) DeleteRoom(ctx context.Context, id int64) error ListRooms(ctx context.Context, entryId int64) ([]*types.Room, error) ListRoomMessage(ctx context.Context, roomId int64) ([]*types.RoomMessage, error) diff --git a/pkg/metastore/sql.go b/pkg/metastore/sql.go index af7c6443..6e6f2957 100644 --- a/pkg/metastore/sql.go +++ b/pkg/metastore/sql.go @@ -410,6 +410,12 @@ func (s *sqliteMetaStore) GetRoom(ctx context.Context, id int64) (*types.Room, e return s.dbStore.GetRoom(ctx, id) } +func (s *sqliteMetaStore) FindRoom(ctx context.Context, entryId int64) (*types.Room, error) { + s.mux.Lock() + defer s.mux.Unlock() + return s.dbStore.FindRoom(ctx, entryId) +} + func (s *sqliteMetaStore) DeleteRoom(ctx context.Context, id int64) error { s.mux.Lock() defer s.mux.Unlock() @@ -2066,6 +2072,16 @@ func (s *sqlMetaStore) GetRoom(ctx context.Context, id int64) (*types.Room, erro return room.To() } +func (s *sqlMetaStore) FindRoom(ctx context.Context, entryId int64) (*types.Room, error) { + defer trace.StartRegion(ctx, "metastore.sql.FindRoom").End() + room := &db.Room{} + res := s.WithNamespace(ctx).Where("entry_id = ?", entryId).First(room) + if res.Error != nil { + return nil, db.SqlError2Error(res.Error) + } + return room.To() +} + func (s *sqlMetaStore) DeleteRoom(ctx context.Context, id int64) error { defer trace.StartRegion(ctx, "metastore.sql.DeleteRoom").End() err := s.WithContext(ctx).Transaction(func(tx *gorm.DB) error { diff --git a/vendor/github.com/basenana/friday/pkg/friday/friday.go b/vendor/github.com/basenana/friday/pkg/friday/friday.go index c10db34d..2cafa9f1 100644 --- a/vendor/github.com/basenana/friday/pkg/friday/friday.go +++ b/vendor/github.com/basenana/friday/pkg/friday/friday.go @@ -117,7 +117,7 @@ func (f *Friday) WithContext(ctx context.Context) *Friday { return t } -func (f *Friday) Namespace(namespace string) *Friday { - f.statement.context = context.WithValue(f.statement.context, "namespace", namespace) +func (f *Friday) Namespace(namespace *models.Namespace) *Friday { + f.statement.context = models.WithNamespace(f.statement.context, namespace) return f } diff --git a/vendor/github.com/basenana/friday/pkg/friday/question.go b/vendor/github.com/basenana/friday/pkg/friday/question.go index 643c5be2..8f7e1c19 100644 --- a/vendor/github.com/basenana/friday/pkg/friday/question.go +++ b/vendor/github.com/basenana/friday/pkg/friday/question.go @@ -123,12 +123,12 @@ func (f *Friday) Chat(res *ChatState) *Friday { } func (f *Friday) generateSystemInfo() string { - systemTemplate := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n" + systemTemplate := "你是一位知识渊博的文字工作者,负责帮用户阅读文章,基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n" if f.statement.Summary != "" { systemTemplate += "\n这是文章简介: {{ .Summary }}\n" } if f.statement.Info != "" { - systemTemplate += "\n这是已知内容: {{ .Info }}\n" + systemTemplate += "\n这是相关的已知内容: {{ .Info }}\n" } if f.statement.HistorySummary != "" { systemTemplate += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n" diff --git a/vendor/github.com/basenana/friday/pkg/models/namespace.go b/vendor/github.com/basenana/friday/pkg/models/namespace.go new file mode 100644 index 00000000..61115bfd --- /dev/null +++ b/vendor/github.com/basenana/friday/pkg/models/namespace.go @@ -0,0 +1,51 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package models + +import "context" + +const ( + NamespaceKey = "namespace" + DefaultNamespaceValue = "global" // TODO: using 'public' + GlobalNamespaceValue = "global" +) + +type Namespace struct { + name string +} + +func NewNamespace(name string) *Namespace { + return &Namespace{name: name} +} + +func (n *Namespace) String() string { + return n.name +} + +func GetNamespace(ctx context.Context) (ns *Namespace) { + ns = &Namespace{ + name: DefaultNamespaceValue, + } + if ctx.Value(NamespaceKey) != nil { + ns.name = ctx.Value(NamespaceKey).(string) + } + return +} + +func WithNamespace(ctx context.Context, ns *Namespace) context.Context { + return context.WithValue(ctx, NamespaceKey, ns.String()) +} diff --git a/vendor/github.com/basenana/friday/pkg/vectorstore/db/entity.go b/vendor/github.com/basenana/friday/pkg/vectorstore/db/entity.go index cecddf5d..2401560f 100644 --- a/vendor/github.com/basenana/friday/pkg/vectorstore/db/entity.go +++ b/vendor/github.com/basenana/friday/pkg/vectorstore/db/entity.go @@ -17,7 +17,11 @@ package db import ( + "context" + "gorm.io/gorm" + + "github.com/basenana/friday/pkg/models" ) type Entity struct { @@ -31,3 +35,11 @@ func NewDbEntity(db *gorm.DB, migrate func(db *gorm.DB) error) (*Entity, error) } return ent, nil } + +func (e *Entity) WithNamespace(ctx context.Context) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return e.WithContext(ctx) + } + return e.WithContext(ctx).Where("namespace = ?", ns.String()) +} diff --git a/vendor/github.com/basenana/friday/pkg/vectorstore/postgres/postgres.go b/vendor/github.com/basenana/friday/pkg/vectorstore/postgres/postgres.go index 2e9dc6fb..86ec2c0f 100644 --- a/vendor/github.com/basenana/friday/pkg/vectorstore/postgres/postgres.go +++ b/vendor/github.com/basenana/friday/pkg/vectorstore/postgres/postgres.go @@ -128,20 +128,14 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext } func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vectors []float32, k int) ([]*models.Doc, error) { - namespace := ctx.Value("namespace") - if namespace == nil { - namespace = defaultNamespace - } vectors64 := make([]float64, 0) for _, v := range vectors { vectors64 = append(vectors64, float64(v)) } // query from db existIndexes := make([]Index, 0) - var res *gorm.DB - res = p.dEntity.WithContext(ctx) - res = res.Where("namespace = ?", namespace) + res := p.dEntity.WithNamespace(ctx) if query.ParentId != 0 { res = res.Where("parent_entry_id = ?", query.ParentId) } @@ -183,17 +177,12 @@ func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vect } func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { - namespace := ctx.Value("namespace") - if namespace == nil { - namespace = defaultNamespace - } vModel := &Index{} - var res *gorm.DB - if oid == 0 { - res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND idx_group = ?", namespace, name, group).First(vModel) - } else { - res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND oid = ? AND idx_group = ?", namespace, name, oid, group).First(vModel) + tx := p.dEntity.WithNamespace(ctx).Where("name = ? AND idx_group = ?", name, group) + if oid != 0 { + tx = tx.Where("oid = ?", oid) } + res := tx.First(vModel) if res.Error != nil { if res.Error == gorm.ErrRecordNotFound { return nil, nil @@ -233,3 +222,11 @@ func (d distances) Less(i, j int) bool { func (d distances) Swap(i, j int) { d[i], d[j] = d[j], d[i] } + +func namespaceQuery(ctx context.Context, tx *gorm.DB) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return tx + } + return tx.Where("namespace = ?", ns.String()) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 139ea1a0..311525a5 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -136,7 +136,7 @@ github.com/aws/smithy-go/waiter ## explicit github.com/aymerick/douceur/css github.com/aymerick/douceur/parser -# github.com/basenana/friday v0.0.0-20240514065549-962b40b3faf1 +# github.com/basenana/friday v0.0.0-20240529034549-d8f6efb8f215 ## explicit; go 1.20 github.com/basenana/friday/config github.com/basenana/friday/pkg/build/withvector