diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ebeed8..206b931 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +- [2024-10-10] [Expand on unique tagging](https://github.com/RubricLab/memory/commit/3c550db0b9cce3ce9f54fc56ecc6716d8475b397) - [2024-10-10] [Add typesafe SQL setup](https://github.com/RubricLab/memory/commit/9acd10a3be5e0335970ca8f551a03847e04fcdda) - [2024-10-10] [Use relative import](https://github.com/RubricLab/memory/commit/2161e8cc919ac63da4b092b61ff811566417bc08) - [2024-10-10] [Fix post-insert return](https://github.com/RubricLab/memory/commit/a79a936cff2baadbd5411caec354579b47d8c846) diff --git a/package.json b/package.json index e391b99..ec77492 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@rubriclab/memory", "module": "src/index.ts", "main": "src/index.ts", - "version": "0.0.18", + "version": "0.0.19", "private": false, "type": "module", "devDependencies": { diff --git a/prisma/schema/facts.prisma b/prisma/schema/facts.prisma index 68e3e13..0337aa5 100644 --- a/prisma/schema/facts.prisma +++ b/prisma/schema/facts.prisma @@ -1,7 +1,7 @@ model fact { id String @id @default(nanoid()) - body String + body String @unique tags relationship[] diff --git a/src/evals/multi-turn/index.ts b/src/evals/multi-turn/index.ts index b34d768..8612614 100644 --- a/src/evals/multi-turn/index.ts +++ b/src/evals/multi-turn/index.ts @@ -12,6 +12,10 @@ export const runMultiTurnExamples = async ({ db, model }: { model: LLM; db: Data let totalAttempts = 0 for await (const eg of EXAMPLES.slice(0, 1)) { + await db.tag.deleteMany() + await db.fact.deleteMany() + await db.relationship.deleteMany() + for await (const message of eg.messages) { totalFacts += message.facts.length @@ -19,7 +23,6 @@ export const runMultiTurnExamples = async ({ db, model }: { model: LLM; db: Data // Clean up DB in between conversations const omitted: number[] = [] - // await db.fact.deleteMany() const { facts: attempts } = await memory.ingest({ content: message.content diff --git a/src/index.ts b/src/index.ts index d230fea..7f0ab32 100644 --- a/src/index.ts +++ b/src/index.ts @@ -115,7 +115,10 @@ export class Memory { return inserted } - async search(query: string, { threshold, limit } = { threshold: 0.5, limit: 10 }) { + async search( + query: string, + { threshold, limit } = { threshold: 0.5, limit: 10 } + ): Promise { const vector = await this.embed(query) // biome-ignore lint/suspicious/noExplicitAny: @@ -130,23 +133,70 @@ export class Memory { const uniqueTags = [...new Set(tags)] - await Promise.all(uniqueTags.map(tag => this.insert(tag))) - const similarTags = await Promise.all(uniqueTags.map(tag => this.search(tag))) + const similarTagIds = similarTags.flatMap(t => t[0]?.id || []) + console.log('similarTags', similarTags) + + const netNewTags = uniqueTags.filter(t => !similarTags.some(s => s[0]?.body === t)) + console.log({ netNewTags }) + + const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag))) + const netNewTagIds = tagsInserted.flatMap(t => t[0]?.id || []) + + if (!netNewTagIds?.[0]) throw 'Failed to insert tags' + + const allTagIds = [...similarTagIds, ...netNewTagIds] + + const creates = facts.flatMap(fact => { + return allTagIds.map(tagID => + this.db.relationship.create({ + data: { + fact: { + connectOrCreate: { + where: { + body: fact + }, + create: { + body: fact + } + } + }, + tag: { + connect: { + id: tagID + } + } + } + }) + ) + }) - console.log(similarTags) + const created = await this.db.$transaction(creates) + console.log({ created }) const relatedFacts = await this.db.relationship.findMany({ where: { tag: { - body: { - in: similarTags.map(tag => tag[0]?.body || '') + id: { + in: allTagIds + } + } + }, + select: { + fact: { + select: { + body: true + } + }, + tag: { + select: { + body: true } } } }) - console.log({ relatedFacts }) + console.log('relatedFacts', relatedFacts) return { tags, facts } }