Skip to content

Commit

Permalink
Parallelize embedding, upsertion, and update for perf
Browse files Browse the repository at this point in the history
  • Loading branch information
tedspare committed Oct 15, 2024
1 parent 8828bca commit 1bccd2c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- [2024-10-15] [Parallelize embedding, upsertion, and update for perf](https://github.com/RubricLab/memory/commit/38f574d2ba7c00f74f9a97a8e21e05315b47f118)
- [2024-10-15] [Clean up](https://github.com/RubricLab/memory/commit/0d5f66305ae9c973f7667576b8b7a966c67acbd7)
- [2024-10-15] [Require userId in getAll](https://github.com/RubricLab/memory/commit/17a1f03ce3817a9aed9391380ff340db15214046)
- [2024-10-11] [Add ability to correct old facts](https://github.com/RubricLab/memory/commit/2fc23586e30ca4e5366cada37c2e04d8647db3e5)
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "@rubriclab/memory",
"module": "src/index.ts",
"main": "src/index.ts",
"version": "0.0.24",
"version": "0.0.25",
"private": false,
"type": "module",
"devDependencies": {
Expand Down
8 changes: 0 additions & 8 deletions prisma/sql/insertVector.sql

This file was deleted.

86 changes: 53 additions & 33 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import { openai } from '@ai-sdk/openai'
import { PrismaClient } from '@prisma/client'
import { Prisma, PrismaClient } from '@prisma/client'
import {
createVectorExtension,
createVectorIndex,
insertVector,
searchVector,
setMaxWorkers,
setWorkerMemory
Expand Down Expand Up @@ -93,27 +92,45 @@ export class Memory {
return { tags }
}

async embed(text: string): Promise<number[]> {
embed(props: { text: string }): Promise<number[]>
embed(props: { texts: string[] }): Promise<number[][]>
async embed(props: { text: string } | { texts: string[] }) {
const texts = 'texts' in props ? props.texts : [props.text]

const res = await openai
.embedding(this.embeddingsModel, {
dimensions: this.embeddingsDimension
})
.doEmbed({ values: [text] })
.doEmbed({ values: texts })

if (!res) throw 'Failed to reach embeddings API'

const { embeddings } = res

if (!embeddings[0]) throw 'No embedding found'
if (!embeddings || embeddings?.length === 0) throw 'No embedding found'

return embeddings[0]
return 'texts' in props ? embeddings : embeddings[0]
}

async insert(body: string, { userId = this.userId }: { userId?: string }) {
const id = uid()
const vector = await this.embed(body)
const sql = insertVector(id, body, vector as unknown as string, userId)
const inserted: { id: string }[] = await this.db.$queryRawTyped(sql)
async insert(
props: { tag: string } | { tags: string[] },
{ userId = this.userId }: { userId?: string }
) {
const tags = 'tags' in props ? props.tags : [props.tag]
const ids = tags.map(() => uid())
const vectors = await this.embed({ texts: tags })
const vectorStrings = vectors.map(v => JSON.stringify(v))
const rows = tags.map((t, i) => `('${ids[i]}', '${t}', '${vectorStrings[i]}', ${userId})`)

const template = clean`insert into tag (id, body, vector, "userId")
values
${rows.join(', ')}
on conflict (body, "userId") do nothing
returning id`

const query = Prisma.sql([template])

const inserted: { id: string }[] = await this.db.$queryRaw(query)

return inserted
}
Expand All @@ -122,7 +139,7 @@ export class Memory {
query: string,
{ threshold = 0.5, limit = 10, userId = this.userId }
): Promise<searchVector.Result[]> {
const vector = await this.embed(query)
const vector = await this.embed({ text: query })
const sql = searchVector(vector as unknown as string, threshold, limit, userId)
const res = await this.db.$queryRawTyped(sql)

Expand Down Expand Up @@ -181,20 +198,18 @@ export class Memory {
const netNewTags = uniqueTags.filter(t => !similarTags.some(s => s?.body === t))
console.log('netNewTags', netNewTags)

// TODO: use the `insert()` SQL command to return newly-created IDs
// TODO: allow passing tag[] to insert
const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag, { userId })))
const netNewTagIds = tagsInserted.flatMap(t => t[0]?.id || [])
const tagsInserted = await this.insert({ tags: netNewTags }, { userId })
const netNewTagIds = tagsInserted.map(t => t.id)

console.log(`insert completed: ${(performance.now() - start).toFixed(2)}ms`)

const allTagIds = [...uniqueSimilarTagIds, ...netNewTagIds]
const combinedTagIds = [...uniqueSimilarTagIds, ...netNewTagIds]

const relatedFacts = await this.db.relationship.findMany({
where: {
tag: {
id: {
in: allTagIds
in: combinedTagIds
}
}
},
Expand Down Expand Up @@ -229,26 +244,32 @@ export class Memory {
)
.optional()
}),
prompt: clean`Given the following statements and some new information, please identify any statements which should be updated.
Statements are generally single-subject, single-verb sentences.
Corrections should be made only when a statement is explicitly wrong.
prompt: clean`Given the following statements and some new information, please identify any statements which have been falsified.
Prior statements:
"${relatedFacts.map((r, i) => `${i + 1}. ${r.fact.body}`).join('\n')}"
"""
${relatedFacts.map((r, i) => `${i}. ${r.fact.body}`).join('\n')}
"""
New information:
"${facts.map(f => `- ${f}`).join('\n')}"
If nothing needs to be corrected, return an empty array.`
"""
${facts.map(f => `- ${f}`).join('\n')}
"""
It is not always necessary to update the returned statements.
`
})

console.log('corrections', corrections)
console.log(`made corrections: ${(performance.now() - start).toFixed(2)}ms`)

if (corrections) {
for (const { index, newStatement } of corrections) {
if (!relatedFacts[index - 1]) continue
const updates =
corrections?.flatMap(({ index, newStatement }) => {
const outdated = relatedFacts[index]

const outdated = relatedFacts[index - 1] as { id: string }
if (!outdated) return []

await this.db.relationship.update({
return this.db.relationship.update({
where: {
id: outdated.id
},
Expand All @@ -260,11 +281,10 @@ export class Memory {
}
}
})
}
}
}) || []

const creates = facts.flatMap(fact => {
return allTagIds.map(tagID =>
return combinedTagIds.map(tagID =>
this.db.relationship.create({
data: {
fact: {
Expand Down Expand Up @@ -293,7 +313,7 @@ export class Memory {
)
})

const created = await this.db.$transaction(creates)
const created = await this.db.$transaction([...updates, ...creates])

console.log(chalk.green(`Added ${created?.length} facts`))

Expand Down

0 comments on commit 1bccd2c

Please sign in to comment.