Skip to content

Commit

Permalink
Scaffold multi-turn evals
Browse files Browse the repository at this point in the history
  • Loading branch information
tedspare committed Oct 9, 2024
1 parent f0c568c commit d71a86c
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
- [2024-10-09] [Scaffold multi-turn evals](https://github.com/RubricLab/memory/commit/ecb5531acef6a924b81684660150eeb71d93e704)
- [2024-10-09] [Add TSConfig](https://github.com/RubricLab/memory/commit/ed521824cc492e46adff6d38a994e18cc08166b2)
- [2024-10-09] [Extract memory to class](https://github.com/RubricLab/memory/commit/5e165608ffad822c5b77ee03f1dfc308dcb1787a)
- [2024-10-09] [Fix precision calc](https://github.com/RubricLab/memory/commit/52fc41e151c47e276c37a24b3489ba414d032a0b)
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@rubriclab/memory",
"module": "src/index.ts",
"version": "0.0.9",
"version": "0.0.10",
"private": false,
"type": "module",
"devDependencies": {
Expand Down
17 changes: 13 additions & 4 deletions src/evals/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ const args = parseArgs({
type: 'boolean',
default: false
},
dataset: {
type: 'string',
default: '1',
choices: ['1', '2']
},
help: {
type: 'boolean',
default: false
Expand All @@ -18,7 +23,8 @@ const args = parseArgs({
})

if (import.meta.path === Bun.main) {
if (args.values.help) {
const { help, fast, dataset } = args.values
if (help) {
console.log(`
Usage: bun evals/index.ts [options]
Expand All @@ -29,8 +35,11 @@ if (import.meta.path === Bun.main) {
process.exit(0)
}

const model = args.values.fast ? 'gpt-4o-mini' : 'gpt-4o-2024-08-06'
const model = fast ? 'gpt-4o-mini' : 'gpt-4o-2024-08-06'

await runOneShotExamples({ model })
await runMultiTurnExamples({ model })
if (dataset === '1') {
await runOneShotExamples({ model })
} else if (dataset === '2') {
await runMultiTurnExamples({ model })
}
}
33 changes: 32 additions & 1 deletion src/evals/multi-turn/examples.ts
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
export const EXAMPLES = []
import type { Fact } from '@/types'

/** One conversation turn: the message text paired with the facts listed for it. */
type Message = {
	content: string
	facts: Fact[]
}

/** A multi-turn eval case, expressed as an ordered list of turns. */
type Example = {
	messages: Message[]
}

// Builds a fact triple so the fixture below stays compact.
const triple = (
	subject: Fact['subject'],
	relation: Fact['relation'],
	object: Fact['object']
): Fact => ({ subject, relation, object })

/**
 * Multi-turn eval fixtures. Each example is a sequence of messages, and each
 * message carries the facts listed for that turn.
 */
export const EXAMPLES: Example[] = [
	{
		// The second turn contradicts the first one.
		messages: [
			{ content: 'I am vegan', facts: [triple('user', 'is', 'vegan')] },
			{ content: 'I am not vegan', facts: [triple('user', 'is not', 'vegan')] }
		]
	}
]
86 changes: 84 additions & 2 deletions src/evals/multi-turn/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,93 @@
import { Database } from 'bun:sqlite'
import { Memory } from '@/index'
import type { Fact } from '@/types'
import { format } from '@/utils/string'
import type { openai } from '@ai-sdk/openai'
import { Memory } from '../..'
import chalk from 'chalk'
import { EXAMPLES } from './examples'

// In-memory SQLite database that holds the facts written during the eval run.
const db = new Database(':memory:', { create: true, strict: true })

// Schema: one row per (subject, object) pair, so a later fact about the same
// pair can replace the stored relation instead of adding a second row.
// This is DDL and yields no rows, so it is executed directly rather than
// fetched with `.get()`; the call is synchronous and needs no `await`.
db.run(
	'create table if not exists facts (subject text, relation text, object text, primary key (subject, object))'
)

/**
 * Runs the multi-turn eval examples.
 *
 * For every message of every example this:
 *  1. extracts facts from the message content with `Memory.extract`,
 *  2. upserts the extracted facts into the `facts` table (a conflicting
 *     `(subject, object)` pair has its relation overwritten),
 *  3. compares each expected fact of the message against the rows currently
 *     in the table, counting an expected fact as recalled when one row matches
 *     on subject, relation and object. A row can satisfy at most one expected
 *     fact per message.
 *
 * Finally it prints two ratios: recalled facts over extracted facts
 * ("Precision") and recalled facts over expected facts ("Recall").
 *
 * @param model - Model identifier forwarded to the `Memory` constructor.
 */
export const runMultiTurnExamples = async ({ model }: { model: Parameters<typeof openai>[0] }) => {
	const memory = new Memory({ model })

	// Prepared once and reused for every extracted fact instead of being
	// re-prepared inside the innermost loops.
	const upsertFact = db.prepare(`
insert into facts (subject, relation, object)
values ($1, $2, $3)
on conflict (subject, object) do update set relation = $2
`)

	let totalFacts = 0
	let totalRecall = 0
	let totalAttempts = 0

	// NOTE(review): the facts table is never cleared between examples, so rows
	// written by an earlier example stay visible to later ones. Confirm whether
	// that is intended once EXAMPLES holds more than one conversation.
	for (const eg of EXAMPLES) {
		console.log(eg)

		for (const message of eg.messages) {
			totalFacts += message.facts.length

			console.log(chalk.yellow(`\n\n"${message.content}"`))

			const { facts: attempts } = await memory.extract({
				content: message.content
			})

			// Persist this turn's extraction exactly once, independent of how
			// many expected facts the turn has (including zero), so later
			// turns always see it.
			for (const { subject, relation, object } of attempts) {
				upsertFact.run(subject, relation, object)
			}

			// The table does not change while this turn is scored, so read it once.
			const storedFacts = db.query('select * from facts').all() as Fact[]

			// Row indices already credited to an expected fact of this message.
			const matchedRows = new Set<number>()

			for (const [i, fact] of message.facts.entries()) {
				console.log(
					`\n🎯 ${i + 1} of ${message.facts.length}: ${chalk.magenta(fact.subject)} ${chalk.yellow(fact.relation)} ${chalk.blue(fact.object)}`
				)

				console.log({ newFacts: storedFacts })

				let recalled = false

				for (const [k, stored] of storedFacts.entries()) {
					const { subject, relation, object } = stored

					const correctSubject = fact.subject === subject
					const correctRelation = fact.relation === relation
					const correctObject = fact.object === object

					console.log(
						`🤖 ${k + 1} of ${storedFacts.length}: ${chalk.magenta(format(subject, correctSubject))} ${chalk.yellow(
							format(relation, correctRelation)
						)} ${chalk.blue(format(object, correctObject))}`
					)

					// A row that already satisfied another expected fact is
					// still printed above but cannot be counted twice.
					if (matchedRows.has(k)) continue

					if (correctSubject && correctRelation && correctObject) {
						matchedRows.add(k)
						recalled = true
						break
					}
				}

				totalRecall += Number(recalled)
			}

			totalAttempts += attempts.length
		}
	}

	// `~~` truncates to an integer and maps NaN (0 / 0) to 0.
	console.log(
		`\n\nPrecision (% of attempts true): ${totalRecall} of ${totalAttempts} ${chalk.green(`(${~~((totalRecall / totalAttempts) * 100)}%)`)}`
	)
	console.log(
		`Recall (% of total facts correctly returned): ${totalRecall} of ${totalFacts} ${chalk.green(`(${~~((totalRecall / totalFacts) * 100)}%)`)}`
	)
}
6 changes: 3 additions & 3 deletions src/evals/one-shot/index.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import { Memory } from '@/'
import { Memory } from '@/index'
import { format } from '@/utils/string'
import type { openai } from '@ai-sdk/openai'
import chalk from 'chalk'
import { EXAMPLES } from './examples'

export const runOneShotExamples = async ({ model }: { model: Parameters<typeof openai>[0] }) => {
const memory = new Memory({ model })

let totalFacts = 0
let totalRecall = 0
let totalAttempts = 0

const memory = new Memory({ model })

for await (const eg of EXAMPLES) {
totalFacts += eg.facts.length

Expand Down
1 change: 1 addition & 0 deletions tsconfig.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"compilerOptions": {
"target": "ESNext",
"baseUrl": ".",
"paths": {
"@/*": ["./src/*"]
Expand Down

0 comments on commit d71a86c

Please sign in to comment.