diff --git a/CHANGELOG.md b/CHANGELOG.md index b588aaa..88a1c8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +- [2024-10-11] [User-scope everything](https://github.com/RubricLab/memory/commit/621f063252eff5d8999173494b4c17f223f0d903) - [2024-10-11] [Add getAll route](https://github.com/RubricLab/memory/commit/71eb1f5d7afbb26390d2b2617081b46301e9022c) - [2024-10-10] [Expand on unique tagging](https://github.com/RubricLab/memory/commit/3c550db0b9cce3ce9f54fc56ecc6716d8475b397) - [2024-10-10] [Add typesafe SQL setup](https://github.com/RubricLab/memory/commit/9acd10a3be5e0335970ca8f551a03847e04fcdda) diff --git a/package.json b/package.json index 8d65c2c..f4655ab 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@rubriclab/memory", "module": "src/index.ts", "main": "src/index.ts", - "version": "0.0.20", + "version": "0.0.21", "private": false, "type": "module", "devDependencies": { diff --git a/prisma/schema/facts.prisma b/prisma/schema/facts.prisma index 0337aa5..7b58751 100644 --- a/prisma/schema/facts.prisma +++ b/prisma/schema/facts.prisma @@ -1,25 +1,33 @@ model fact { id String @id @default(nanoid()) - body String @unique + body String tags relationship[] + userId String + createdAt DateTime @default(now()) updatedAt DateTime @default(now()) @updatedAt + + @@unique([userId, body]) } model tag { id String @id @default(nanoid()) - body String @unique + body String vector Unsupported("vector(768)")? + userId String + facts relationship[] createdAt DateTime @default(now()) updatedAt DateTime @default(now()) @updatedAt + + @@unique([userId, body]) } model relationship { @@ -31,6 +39,8 @@ model relationship { fact fact @relation(fields: [factId], references: [id], onDelete: Cascade) tag tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + userId String + createdAt DateTime @default(now()) updatedAt DateTime @default(now()) @updatedAt } diff --git a/prisma/sql/insertVector.sql b/prisma/sql/insertVector.sql index a877440..fd78d40 100644 --- a/prisma/sql/insertVector.sql +++ b/prisma/sql/insertVector.sql @@ -1,7 +1,8 @@ -- @param {String} $1:id -- @param {String} $2:body -- @param $3:vector -insert into tag (id, body, vector) -values ($1, $2, $3) -on conflict (body) do nothing +-- @param {String} $4:userId +insert into tag (id, body, vector, "userId") +values ($1, $2, $3, $4) +on conflict (body, "userId") do nothing returning id \ No newline at end of file diff --git a/prisma/sql/searchVector.sql b/prisma/sql/searchVector.sql index 8516a0d..d5f7508 100644 --- a/prisma/sql/searchVector.sql +++ b/prisma/sql/searchVector.sql @@ -1,10 +1,12 @@ -- @param $1:vector -- @param {Int} $2:threshold -- @param {Float} $3:limit +-- @param {String} $4:userId select id, body, similarity from ( select id, body, 1 - (vector <=> $1::vector) as similarity from tag + where "userId" = $4 ) as subquery where similarity > $2 order by similarity desc diff --git a/src/index.ts b/src/index.ts index 0fb6988..697ee1a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,5 @@ import { openai } from '@ai-sdk/openai' +import { PrismaClient } from '@prisma/client' import { createVectorExtension, createVectorIndex, @@ -8,6 +9,7 @@ import { setWorkerMemory } from '@prisma/client/sql' import { generateObject } from 'ai' +import chalk from 'chalk' import { z } from 'zod' import type { Database, LLM } from './types' import { clean } from './utils/string' @@ -16,8 +18,9 @@ import { uid } from './utils/uid' export class Memory { model: LLM db: Database - EMBEDDINGS_MODEL = 'text-embedding-3-small' - EMBEDDINGS_DIMENSIONS = 768 + userId: string + embeddingsModel: string + embeddingsDimension: number async initVectorIndex() { await this.db.$queryRawTyped(createVectorExtension()) @@ -30,13 +33,22 @@ export class Memory { constructor({ model = 'gpt-4o-mini', - db + db = new PrismaClient(), + userId = uid(), + embeddingsModel = 'text-embedding-3-small', + embeddingsDimension = 768 }: { model?: LLM - db: Database - }) { + db?: Database + userId?: string + embeddingsModel?: string + embeddingsDimension?: number + } = {}) { this.model = model this.db = db + this.userId = userId + this.embeddingsModel = embeddingsModel + this.embeddingsDimension = embeddingsDimension // TODO: figure out how to do this once during setup // this.initVectorIndex() @@ -55,10 +67,10 @@ export class Memory { ) }), prompt: clean`Please extract all facts from the following passage. - Portray the first-person as "user". - In case of contradiction, try to capture the most up-to-date state of affairs in present tense. - Passage: - "${content}"` + In case of first-person statements, portray the first-person as "user". + In case of contradiction, try to capture the most up-to-date state of affairs in present tense. + Passage: + "${content}"` }) return { facts: facts.map(({ body }) => body) } @@ -66,31 +78,25 @@ export class Memory { async extractTags({ content }: { content: string }): Promise<{ tags: string[] }> { const { - object: { entities } + object: { entities: tags } } = await generateObject({ model: openai(this.model), schema: z.object({ - entities: z.array( - z.object({ - body: z.string() - }) - ) + entities: z.array(z.string()) }), prompt: clean`Please extract all entities (subjects, objects, and general metaphysical concepts) from the following passage. - Portray the first-person as "user". + In case of first-person statements, portray the first-person as "user". Passage: "${content}"` }) - const tags = entities?.map(({ body }) => body) - return { tags } } async embed(text: string): Promise { const res = await openai - .embedding(this.EMBEDDINGS_MODEL, { - dimensions: this.EMBEDDINGS_DIMENSIONS + .embedding(this.embeddingsModel, { + dimensions: this.embeddingsDimension }) .doEmbed({ values: [text] }) @@ -103,32 +109,30 @@ export class Memory { return embeddings[0] } - async insert(body: string) { - const vector = await this.embed(body) - + async insert(body: string, { userId = this.userId }: { userId?: string }) { const id = uid() + const vector = await this.embed(body) // biome-ignore lint/suspicious/noExplicitAny: - const query = insertVector(id, body, vector as any) - const inserted: insertVector.Result[] = await this.db.$queryRawTyped(query) + const sql = insertVector(id, body, vector as any, userId || this.userId) + const inserted: insertVector.Result[] = await this.db.$queryRawTyped(sql) return inserted } async search( query: string, - { threshold, limit } = { threshold: 0.5, limit: 10 } + { threshold = 0.5, limit = 10, userId = this.userId } ): Promise { const vector = await this.embed(query) // biome-ignore lint/suspicious/noExplicitAny: - const res = await this.db.$queryRawTyped(searchVector(vector as any, threshold, limit)) + const sql = searchVector(vector as any, threshold, limit, userId) + const res = await this.db.$queryRawTyped(sql) return res } - async getAll(args?: { where: { tags: string[] } }) { - const { where } = args || {} - + async getAll({ where, userId = this.userId }: { where?: { tags: string[] }; userId?: string }) { const relations = await this.db.relationship.findMany({ where: { ...(where?.tags @@ -139,7 +143,8 @@ export class Memory { } } } - : {}) + : {}), + userId }, select: { fact: { @@ -158,7 +163,7 @@ export class Memory { return relations } - async ingest({ content }: { content: string }) { + async ingest({ content, userId = this.userId }: { content: string; userId?: string }) { const [{ tags }, { facts }] = await Promise.all([ this.extractTags({ content }), this.extractFacts({ content }) @@ -166,15 +171,14 @@ export class Memory { const uniqueTags = [...new Set(tags)] - // TODO: pass these back to AI to update existing facts - const similarTags = await Promise.all(uniqueTags.map(tag => this.search(tag))) + const similarTags = await Promise.all(uniqueTags.map(tag => this.search(tag, { userId }))) const similarTagIds = similarTags.flatMap(t => t[0]?.id || []) console.log('similarTags', similarTags) const netNewTags = uniqueTags.filter(t => !similarTags.some(s => s[0]?.body === t)) console.log({ netNewTags }) - const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag))) + const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag, { userId }))) const netNewTagIds = tagsInserted.flatMap(t => t[0]?.id || []) if (!netNewTagIds?.[0]) throw 'Failed to insert tags' @@ -188,18 +192,24 @@ export class Memory { fact: { connectOrCreate: { where: { - body: fact + userId_body: { + userId: userId, + body: fact + } }, create: { - body: fact + body: fact, + userId: userId } } }, tag: { connect: { - id: tagID + id: tagID, + userId: userId } - } + }, + userId: userId } }) ) @@ -207,7 +217,7 @@ export class Memory { const created = await this.db.$transaction(creates) - console.log(`Added ${created?.length} facts`) + console.log(chalk.green(`Added ${created?.length} facts`)) // const relatedFacts = await this.db.relationship.findMany({ // where: {