From 1bccd2c0fcb3f4205e2861e3a628bafd68763aa5 Mon Sep 17 00:00:00 2001 From: tedspare Date: Tue, 15 Oct 2024 12:25:16 -0400 Subject: [PATCH] Parallelize embedding, upsertion, and update for perf --- CHANGELOG.md | 1 + package.json | 2 +- prisma/sql/insertVector.sql | 8 ---- src/index.ts | 86 +++++++++++++++++++++++-------------- 4 files changed, 55 insertions(+), 42 deletions(-) delete mode 100644 prisma/sql/insertVector.sql diff --git a/CHANGELOG.md b/CHANGELOG.md index 6acd6fe..7b5a335 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +- [2024-10-15] [Parallelize embedding, upsertion, and update for perf](https://github.com/RubricLab/memory/commit/38f574d2ba7c00f74f9a97a8e21e05315b47f118) - [2024-10-15] [Clean up](https://github.com/RubricLab/memory/commit/0d5f66305ae9c973f7667576b8b7a966c67acbd7) - [2024-10-15] [Require userId in getAll](https://github.com/RubricLab/memory/commit/17a1f03ce3817a9aed9391380ff340db15214046) - [2024-10-11] [Add ability to correct old facts](https://github.com/RubricLab/memory/commit/2fc23586e30ca4e5366cada37c2e04d8647db3e5) diff --git a/package.json b/package.json index 5b35c6c..13e4204 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@rubriclab/memory", "module": "src/index.ts", "main": "src/index.ts", - "version": "0.0.24", + "version": "0.0.25", "private": false, "type": "module", "devDependencies": { diff --git a/prisma/sql/insertVector.sql b/prisma/sql/insertVector.sql deleted file mode 100644 index fd78d40..0000000 --- a/prisma/sql/insertVector.sql +++ /dev/null @@ -1,8 +0,0 @@ --- @param {String} $1:id --- @param {String} $2:body --- @param $3:vector --- @param {String} $4:userId -insert into tag (id, body, vector, "userId") -values ($1, $2, $3, $4) -on conflict (body, "userId") do nothing -returning id \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index f12158a..93f5bc3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,9 +1,8 @@ import { openai } from '@ai-sdk/openai' -import { PrismaClient } from '@prisma/client' +import { Prisma, PrismaClient } from '@prisma/client' import { createVectorExtension, createVectorIndex, - insertVector, searchVector, setMaxWorkers, setWorkerMemory @@ -93,27 +92,45 @@ export class Memory { return { tags } } - async embed(text: string): Promise { + embed(props: { text: string }): Promise + embed(props: { texts: string[] }): Promise + async embed(props: { text: string } | { texts: string[] }) { + const texts = 'texts' in props ? props.texts : [props.text] + const res = await openai .embedding(this.embeddingsModel, { dimensions: this.embeddingsDimension }) - .doEmbed({ values: [text] }) + .doEmbed({ values: texts }) if (!res) throw 'Failed to reach embeddings API' const { embeddings } = res - if (!embeddings[0]) throw 'No embedding found' + if (!embeddings || embeddings?.length === 0) throw 'No embedding found' - return embeddings[0] + return 'texts' in props ? embeddings : embeddings[0] } - async insert(body: string, { userId = this.userId }: { userId?: string }) { - const id = uid() - const vector = await this.embed(body) - const sql = insertVector(id, body, vector as unknown as string, userId) - const inserted: { id: string }[] = await this.db.$queryRawTyped(sql) + async insert( + props: { tag: string } | { tags: string[] }, + { userId = this.userId }: { userId?: string } + ) { + const tags = 'tags' in props ? props.tags : [props.tag] + const ids = tags.map(() => uid()) + const vectors = await this.embed({ texts: tags }) + const vectorStrings = vectors.map(v => JSON.stringify(v)) + const rows = tags.map((t, i) => `('${ids[i]}', '${t}', '${vectorStrings[i]}', ${userId})`) + + const template = clean`insert into tag (id, body, vector, "userId") + values + ${rows.join(', ')} + on conflict (body, "userId") do nothing + returning id` + + const query = Prisma.sql([template]) + + const inserted: { id: string }[] = await this.db.$queryRaw(query) return inserted } @@ -122,7 +139,7 @@ export class Memory { query: string, { threshold = 0.5, limit = 10, userId = this.userId } ): Promise { - const vector = await this.embed(query) + const vector = await this.embed({ text: query }) const sql = searchVector(vector as unknown as string, threshold, limit, userId) const res = await this.db.$queryRawTyped(sql) @@ -181,20 +198,18 @@ export class Memory { const netNewTags = uniqueTags.filter(t => !similarTags.some(s => s?.body === t)) console.log('netNewTags', netNewTags) - // TODO: use the `insert()` SQL command to return newly-created IDs - // TODO: allow passing tag[] to insert - const tagsInserted = await Promise.all(netNewTags.map(tag => this.insert(tag, { userId }))) - const netNewTagIds = tagsInserted.flatMap(t => t[0]?.id || []) + const tagsInserted = await this.insert({ tags: netNewTags }, { userId }) + const netNewTagIds = tagsInserted.map(t => t.id) console.log(`insert completed: ${(performance.now() - start).toFixed(2)}ms`) - const allTagIds = [...uniqueSimilarTagIds, ...netNewTagIds] + const combinedTagIds = [...uniqueSimilarTagIds, ...netNewTagIds] const relatedFacts = await this.db.relationship.findMany({ where: { tag: { id: { - in: allTagIds + in: combinedTagIds } } }, @@ -229,26 +244,32 @@ export class Memory { ) .optional() }), - prompt: clean`Given the following statements and some new information, please identify any statements which should be updated. - Statements are generally single-subject, single-verb sentences. - Corrections should be made only when a statement is explicitly wrong. + prompt: clean`Given the following statements and some new information, please identify any statements which have been falsified. + Prior statements: - "${relatedFacts.map((r, i) => `${i + 1}. ${r.fact.body}`).join('\n')}" + """ + ${relatedFacts.map((r, i) => `${i}. ${r.fact.body}`).join('\n')} + """ + New information: - "${facts.map(f => `- ${f}`).join('\n')}" - If nothing needs to be corrected, return an empty array.` + """ + ${facts.map(f => `- ${f}`).join('\n')} + """ + + It is not always necessary to update the returned statements. + ` }) console.log('corrections', corrections) console.log(`made corrections: ${(performance.now() - start).toFixed(2)}ms`) - if (corrections) { - for (const { index, newStatement } of corrections) { - if (!relatedFacts[index - 1]) continue + const updates = + corrections?.flatMap(({ index, newStatement }) => { + const outdated = relatedFacts[index] - const outdated = relatedFacts[index - 1] as { id: string } + if (!outdated) return [] - await this.db.relationship.update({ + return this.db.relationship.update({ where: { id: outdated.id }, @@ -260,11 +281,10 @@ export class Memory { } } }) - } - } + }) || [] const creates = facts.flatMap(fact => { - return allTagIds.map(tagID => + return combinedTagIds.map(tagID => this.db.relationship.create({ data: { fact: { @@ -293,7 +313,7 @@ export class Memory { ) }) - const created = await this.db.$transaction(creates) + const created = await this.db.$transaction([...updates, ...creates]) console.log(chalk.green(`Added ${created?.length} facts`))