Skip to content

Commit

Permalink
User-scope everything
Browse files Browse the repository at this point in the history
  • Loading branch information
tedspare committed Oct 11, 2024
1 parent 686a932 commit 2e049b5
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- [2024-10-11] [User-scope everything](https://github.com/RubricLab/memory/commit/621f063252eff5d8999173494b4c17f223f0d903)
- [2024-10-11] [Add getAll route](https://github.com/RubricLab/memory/commit/71eb1f5d7afbb26390d2b2617081b46301e9022c)
- [2024-10-10] [Expand on unique tagging](https://github.com/RubricLab/memory/commit/3c550db0b9cce3ce9f54fc56ecc6716d8475b397)
- [2024-10-10] [Add typesafe SQL setup](https://github.com/RubricLab/memory/commit/9acd10a3be5e0335970ca8f551a03847e04fcdda)
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "@rubriclab/memory",
"module": "src/index.ts",
"main": "src/index.ts",
"version": "0.0.20",
"version": "0.0.21",
"private": false,
"type": "module",
"devDependencies": {
Expand Down
14 changes: 12 additions & 2 deletions prisma/schema/facts.prisma
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
model fact {
id String @id @default(nanoid())
body String @unique
body String
tags relationship[]
userId String
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@@unique([userId, body])
}

model tag {
id String @id @default(nanoid())
body String @unique
body String
vector Unsupported("vector(768)")?
userId String
facts relationship[]
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@@unique([userId, body])
}

model relationship {
Expand All @@ -31,6 +39,8 @@ model relationship {
fact fact @relation(fields: [factId], references: [id], onDelete: Cascade)
tag tag @relation(fields: [tagId], references: [id], onDelete: Cascade)
userId String
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
}
7 changes: 4 additions & 3 deletions prisma/sql/insertVector.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
-- @param {String} $1:id
-- @param {String} $2:body
-- @param $3:vector
insert into tag (id, body, vector)
values ($1, $2, $3)
on conflict (body) do nothing
-- @param {String} $4:userId
insert into tag (id, body, vector, "userId")
values ($1, $2, $3, $4)
on conflict (body, "userId") do nothing
returning id
2 changes: 2 additions & 0 deletions prisma/sql/searchVector.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
-- @param $1:vector
-- @param {Float} $2:threshold
-- @param {Int} $3:limit
-- @param {String} $4:userId
select id, body, similarity
from (
select id, body, 1 - (vector <=> $1::vector) as similarity
from tag
where "userId" = $4
) as subquery
where similarity > $2
order by similarity desc
Expand Down
90 changes: 50 additions & 40 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { openai } from '@ai-sdk/openai'
import { PrismaClient } from '@prisma/client'
import {
createVectorExtension,
createVectorIndex,
Expand All @@ -8,6 +9,7 @@ import {
setWorkerMemory
} from '@prisma/client/sql'
import { generateObject } from 'ai'
import chalk from 'chalk'
import { z } from 'zod'
import type { Database, LLM } from './types'
import { clean } from './utils/string'
Expand All @@ -16,8 +18,9 @@ import { uid } from './utils/uid'
export class Memory {
model: LLM
db: Database
EMBEDDINGS_MODEL = 'text-embedding-3-small'
EMBEDDINGS_DIMENSIONS = 768
userId: string
embeddingsModel: string
embeddingsDimension: number

async initVectorIndex() {
await this.db.$queryRawTyped(createVectorExtension())
Expand All @@ -30,13 +33,22 @@ export class Memory {

constructor({
model = 'gpt-4o-mini',
db
db = new PrismaClient(),
userId = uid(),
embeddingsModel = 'text-embedding-3-small',
embeddingsDimension = 768
}: {
model?: LLM
db: Database
}) {
db?: Database
userId?: string
embeddingsModel?: string
embeddingsDimension?: number
} = {}) {
this.model = model
this.db = db
this.userId = userId
this.embeddingsModel = embeddingsModel
this.embeddingsDimension = embeddingsDimension

// TODO: figure out how to do this once during setup
// this.initVectorIndex()
Expand All @@ -55,42 +67,36 @@ export class Memory {
)
}),
prompt: clean`Please extract all facts from the following passage.
Portray the first-person as "user".
In case of contradiction, try to capture the most up-to-date state of affairs in present tense.
Passage:
"${content}"`
In case of first-person statements, portray the first-person as "user".
In case of contradiction, try to capture the most up-to-date state of affairs in present tense.
Passage:
"${content}"`
})

return { facts: facts.map(({ body }) => body) }
}

async extractTags({ content }: { content: string }): Promise<{ tags: string[] }> {
const {
object: { entities }
object: { entities: tags }
} = await generateObject({
model: openai(this.model),
schema: z.object({
entities: z.array(
z.object({
body: z.string()
})
)
entities: z.array(z.string())
}),
prompt: clean`Please extract all entities (subjects, objects, and general metaphysical concepts) from the following passage.
Portray the first-person as "user".
In case of first-person statements, portray the first-person as "user".
Passage:
"${content}"`
})

const tags = entities?.map(({ body }) => body)

return { tags }
}

async embed(text: string): Promise<number[]> {
const res = await openai
.embedding(this.EMBEDDINGS_MODEL, {
dimensions: this.EMBEDDINGS_DIMENSIONS
.embedding(this.embeddingsModel, {
dimensions: this.embeddingsDimension
})
.doEmbed({ values: [text] })

Expand All @@ -103,32 +109,30 @@ export class Memory {
return embeddings[0]
}

async insert(body: string) {
const vector = await this.embed(body)

async insert(body: string, { userId = this.userId }: { userId?: string }) {
const id = uid()
const vector = await this.embed(body)

// biome-ignore lint/suspicious/noExplicitAny: the vector param is declared without a type in insertVector.sql, so the embedding is cast to any
const query = insertVector(id, body, vector as any)
const inserted: insertVector.Result[] = await this.db.$queryRawTyped(query)
const sql = insertVector(id, body, vector as any, userId || this.userId)
const inserted: insertVector.Result[] = await this.db.$queryRawTyped(sql)

return inserted
}

async search(
query: string,
{ threshold, limit } = { threshold: 0.5, limit: 10 }
{ threshold = 0.5, limit = 10, userId = this.userId }
): Promise<searchVector.Result[]> {
const vector = await this.embed(query)

// biome-ignore lint/suspicious/noExplicitAny: the vector param is declared without a type in searchVector.sql, so the embedding is cast to any
const res = await this.db.$queryRawTyped(searchVector(vector as any, threshold, limit))
const sql = searchVector(vector as any, threshold, limit, userId)
const res = await this.db.$queryRawTyped(sql)
return res
}

async getAll(args?: { where: { tags: string[] } }) {
const { where } = args || {}

async getAll({ where, userId = this.userId }: { where?: { tags: string[] }; userId?: string }) {
const relations = await this.db.relationship.findMany({
where: {
...(where?.tags
Expand All @@ -139,7 +143,8 @@ export class Memory {
}
}
}
: {})
: {}),
userId
},
select: {
fact: {
Expand All @@ -158,23 +163,22 @@ export class Memory {
return relations
}

async ingest({ content }: { content: string }) {
async ingest({ content, userId = this.userId }: { content: string; userId?: string }) {
const [{ tags }, { facts }] = await Promise.all([
this.extractTags({ content }),
this.extractFacts({ content })
])

const uniqueTags = [...new Set(tags)]

// TODO: pass these back to AI to update existing facts
const similarTags = await Promise.all(uniqueTags.map(tag => this.search(tag)))
const similarTags = await Promise.all(uniqueTags.map(tag => this.search(tag, { userId })))
const similarTagIds = similarTags.flatMap(t => t[0]?.id || [])
console.log('similarTags', similarTags)

const netNewTags = uniqueTags.filter(t => !similarTags.some(s => s[0]?.body === t))
console.log({ netNewTags })

const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag)))
const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag, { userId })))
const netNewTagIds = tagsInserted.flatMap(t => t[0]?.id || [])

if (!netNewTagIds?.[0]) throw 'Failed to insert tags'
Expand All @@ -188,26 +192,32 @@ export class Memory {
fact: {
connectOrCreate: {
where: {
body: fact
userId_body: {
userId: userId,
body: fact
}
},
create: {
body: fact
body: fact,
userId: userId
}
}
},
tag: {
connect: {
id: tagID
id: tagID,
userId: userId
}
}
},
userId: userId
}
})
)
})

const created = await this.db.$transaction(creates)

console.log(`Added ${created?.length} facts`)
console.log(chalk.green(`Added ${created?.length} facts`))

// const relatedFacts = await this.db.relationship.findMany({
// where: {
Expand Down

0 comments on commit 2e049b5

Please sign in to comment.