Skip to content

Commit

Permalink
Image Settings rework (#1049)
Browse files Browse the repository at this point in the history
- merge image settings into one place
- merge updates and notifications
- better control of image model config
- change default image dims
- open currently used image settings
- pass through chat image settings correctly
- refactor char form into smaller components
- add model override to generate avatar
- allow ltm to be disabled
- Make ltm off by default
  • Loading branch information
sceuick authored Oct 12, 2024
1 parent 9585df1 commit 6f9dd31
Show file tree
Hide file tree
Showing 44 changed files with 1,412 additions and 1,007 deletions.
4 changes: 2 additions & 2 deletions common/horde-gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ export async function generateImage(
const payload = {
prompt: `${prompt.slice(0, 500)} ### ${negative}`,
params: {
height: base?.height ?? 384,
width: base?.width ?? 384,
height: base?.height ?? 1024,
width: base?.width ?? 1024,
cfg_scale: base?.cfg ?? 9,
seed: Math.trunc(Math.random() * 1_000_000_000).toString(),
karras: false,
Expand Down
2 changes: 2 additions & 0 deletions common/types/admin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ export interface AppConfig {
}

export type ImageModel = {
id: string
name: string
desc: string
override: string
init: { clipSkip?: number; steps: number; cfg: number; height: number; width: number }
limit: { clipSkip?: number; steps: number; cfg: number; height: number; width: number }
}
Expand Down
6 changes: 2 additions & 4 deletions common/types/library.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { PersonaFormat } from '../adapters'
import { JsonField } from '../prompt'
import { BaseImageSettings, ImageSettings } from './image-schema'
import { ImageSettings } from './image-schema'
import { MemoryBook } from './memory'
import { FullSprite } from './sprite'
import { VoiceSettings } from './texttospeech-schema'
Expand Down Expand Up @@ -44,8 +44,6 @@ export interface Character extends BaseCharacter {
voice?: VoiceSettings
voiceDisabled?: boolean

image?: ImageSettings

json?: ResponseSchema

folder?: string
Expand All @@ -58,7 +56,7 @@ export interface Character extends BaseCharacter {
insert?: { depth: number; prompt: string }
creator?: string
characterVersion?: string
imageSettings?: BaseImageSettings
imageSettings?: ImageSettings
}

export interface LibraryCharacter extends Omit<Character, 'kind' | 'tags'> {
Expand Down
10 changes: 9 additions & 1 deletion common/types/presets.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { AIAdapter, OpenRouterModel, ThirdPartyFormat } from '../adapters'
import { ModelFormat } from '../presets/templates'
import { BaseImageSettings } from './image-schema'
import { BaseImageSettings, ImageSettings } from './image-schema'
import { ResponseSchema } from './library'

export interface SubscriptionTier {
Expand Down Expand Up @@ -197,3 +197,11 @@ export interface PromptTemplate {
createdAt: string
updatedAt: string
}

export interface ImagePreset extends ImageSettings {
kind: 'image-preset'
_id: string
userId: string
name: string
description: string
}
6 changes: 4 additions & 2 deletions common/types/schema.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { AIAdapter, ChatAdapter, ThirdPartyFormat } from '../adapters'
import * as Memory from './memory'
import type { GenerationPreset } from '../presets'
import type { BaseImageSettings, ImageSettings } from './image-schema'
import type { ImageSettings } from './image-schema'
import type { TTSSettings } from './texttospeech-schema'
import type { UISettings } from './ui'
import * as Saga from './saga'
Expand Down Expand Up @@ -92,6 +92,8 @@ export namespace AppSchema {
admin: boolean
role?: 'moderator' | 'admin'

disableLTM?: boolean

novelApiKey: string
novelModel: string
novelVerified?: boolean
Expand Down Expand Up @@ -261,7 +263,7 @@ export namespace AppSchema {
treeLeafId?: string

imageSource?: 'last-character' | 'main-character' | 'chat' | 'settings'
imageSettings?: BaseImageSettings
imageSettings?: ImageSettings
}

export interface ChatMember {
Expand Down
10 changes: 7 additions & 3 deletions srv/api/character.ts
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ const editPartCharacter = handle(async ({ body, params, userId }) => {

if (body.imageSettings) {
try {
update.imageSettings = JSON.parse(body.imageSettings)
update.imageSettings =
typeof body.imageSettings === 'string' ? JSON.parse(body.imageSettings) : body.imageSettings
} catch (ex: any) {
throw new StatusError(`Character 'imageSettings' could not be parsed: ${ex.message}`, 400)
}
Expand Down Expand Up @@ -480,12 +481,14 @@ export const createImage = handle(async ({ body, userId, socketId, log }) => {
chatId: 'string?',
requestId: 'string?',
parent: 'string?',
model: 'string?',
},
body
)
const user = userId ? await store.users.getUser(userId) : body.user

const guestId = userId ? undefined : socketId
const requestId = body.requestId || v4()
generateImage(
{
user,
Expand All @@ -495,13 +498,14 @@ export const createImage = handle(async ({ body, userId, socketId, log }) => {
noAffix: body.noAffix,
chatId: body.chatId,
characterId: body.characterId,
requestId: body.requestId,
requestId,
parentId: body.parent,
model: body.model,
},
log,
guestId
)
return { success: true }
return { success: true, requestId }
})

router.post('/image', createImage)
Expand Down
15 changes: 15 additions & 0 deletions srv/api/user/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ const validConfig = {
defaultPreset: 'string?',
chargenPreset: 'string?',
adapterConfig: 'any?',
disableLTM: 'boolean?',
} as const

/**
Expand All @@ -198,6 +199,8 @@ export const updatePartialConfig = handle(async ({ userId, body }) => {
announcement: 'string?',
defaultPreset: 'string?',
chargenPreset: 'string?',
images: 'any?',
disableLTM: 'boolean?',
},
body
)
Expand All @@ -212,6 +215,10 @@ export const updatePartialConfig = handle(async ({ userId, body }) => {
update.defaultPreset = body.defaultPreset
}

if (body.disableLTM !== undefined) {
update.disableLTM = body.disableLTM
}

if (body.chargenPreset) {
const preset = await store.presets.getUserPreset(body.chargenPreset)
if (!preset || preset.userId !== userId) {
Expand Down Expand Up @@ -256,6 +263,10 @@ export const updatePartialConfig = handle(async ({ userId, body }) => {
update.thirdPartyPassword = encryptText(body.thirdPartyPassword)
}

if (body.images) {
update.images = body.images
}

await store.users.updateUser(userId, update)
const next = await getSafeUserConfig(userId)
return next
Expand All @@ -276,6 +287,10 @@ export const updateConfig = handle(async ({ userId, body }) => {
useLocalPipeline: body.useLocalPipeline,
}

if (body.disableLTM !== undefined) {
update.disableLTM = body.disableLTM
}

if (body.hordeKey || body.hordeApiKey) {
const prevKey = prevUser.hordeKey
const incomingKey = body.hordeKey || body.hordeApiKey!
Expand Down
1 change: 1 addition & 0 deletions srv/db/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export async function createUser(newUser: NewUser, admin?: boolean) {
const user: AppSchema.User = {
_id: v4(),
kind: 'user',
disableLTM: true,
username,
hash,
admin: !!admin,
Expand Down
2 changes: 1 addition & 1 deletion srv/image/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ export async function generateImage(
case 'sd':
case 'agnai':
image = await handleSDImage(
{ user, prompt, negative, settings: imageSettings },
{ user, prompt, negative, settings: imageSettings, override: opts.model },
log,
guestId
)
Expand Down
47 changes: 30 additions & 17 deletions srv/image/stable-diffusion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ export type SDRequest = {
hr_scale?: number
hr_upscaler?: string
hr_second_pass_steps?: number
model_override?: string
}

export const handleSDImage: ImageAdapter = async (opts, log, guestId) => {
const config = await getConfig(opts)
const payload = getPayload(config.kind, opts, config.model)
const payload = getPayload(config.kind, opts, config.model, config.temp)

logger.debug(payload, 'Image: Stable Diffusion payload')

Expand Down Expand Up @@ -80,14 +81,16 @@ export const handleSDImage: ImageAdapter = async (opts, log, guestId) => {
return { ext: 'png', content: buffer }
}

async function getConfig({ user, settings }: ImageRequestOpts): Promise<{
async function getConfig({ user, settings, override }: ImageRequestOpts): Promise<{
kind: 'user' | 'agnai'
host: string
params?: string
model?: AppSchema.ImageModel
temp?: AppSchema.ImageModel
}> {
const type = settings?.type || user.images?.type

// Stable Diffusion URL always comes from user settings
const userHost = user.images?.sd.url || defaultSettings.url
if (type !== 'agnai') {
return { kind: 'user', host: userHost }
Expand All @@ -102,12 +105,16 @@ async function getConfig({ user, settings }: ImageRequestOpts): Promise<{
if (!sub?.tier?.imagesAccess && !user.admin) return { kind: 'user', host: userHost }

const models = getAgnaiModels(srv.imagesModels)
const model =
models.length === 1
? models[0]
: models.find((m) => m.name === user.images?.agnai?.model) ?? models[0]

if (!model) {
const temp = override ? models.find((m) => m.id === override || m.name === override) : undefined

const match = models.find((m) => {
return m.id === settings?.agnai?.model || m.name === settings?.agnai?.model
})

const model = models.length === 1 ? models[0] : match ?? models[0]

if (!temp && !model) {
return { kind: 'user', host: userHost }
}

Expand All @@ -116,35 +123,41 @@ async function getConfig({ user, settings }: ImageRequestOpts): Promise<{
`key=${config.auth.inferenceKey}`,
`id=${user._id}`,
`level=${user.admin ? 99999 : sub?.level ?? -1}`,
`model=${model.name}`,
].join('&')
`model=${temp?.name || model.name}`,
]

return { kind: 'agnai', host: srv.imagesHost, params: `?${params}`, model }
return { kind: 'agnai', host: srv.imagesHost, params: `?${params.join('&')}`, model, temp }
}

function getPayload(kind: 'agnai' | 'user', opts: ImageRequestOpts, model?: AppSchema.ImageModel) {
function getPayload(
kind: 'agnai' | 'user',
opts: ImageRequestOpts,
model: AppSchema.ImageModel | undefined,
temp: AppSchema.ImageModel | undefined
) {
const sampler =
(kind === 'agnai' ? opts.user.images?.agnai?.sampler : opts.user.images?.sd?.sampler) ||
(kind === 'agnai' ? opts.settings?.agnai?.sampler : opts.settings?.sd?.sampler) ||
defaultSettings.sampler
const payload: SDRequest = {
prompt: opts.prompt,
// enable_hr: true,
// hr_scale: 1.5,
// hr_second_pass_steps: 15,
// hr_upscaler: "",
clip_skip: opts.user.images?.clipSkip ?? model?.init.clipSkip ?? 0,
height: opts.user.images?.height ?? model?.init.height ?? 384,
width: opts.user?.images?.width ?? model?.init.width ?? 384,
clip_skip: opts.settings?.clipSkip ?? model?.init.clipSkip ?? 0,
height: opts.settings?.height ?? model?.init.height ?? 1024,
width: opts.settings?.width ?? model?.init.width ?? 1024,
n_iter: 1,
batch_size: 1,
negative_prompt: opts.negative,
sampler_name: (SD_SAMPLER_REV as any)[sampler],
cfg_scale: opts.user.images?.cfg ?? model?.init.cfg ?? 9,
cfg_scale: opts.settings?.cfg ?? model?.init.cfg ?? 9,
seed: Math.trunc(Math.random() * 1_000_000_000),
steps: opts.user.images?.steps ?? model?.init.steps ?? 28,
steps: opts.settings?.steps ?? model?.init.steps ?? 28,
restore_faces: false,
save_images: false,
send_images: true,
model_override: temp ? temp.override : model?.override,
}

if (model) {
Expand Down
6 changes: 4 additions & 2 deletions srv/image/types.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { AppSchema } from '../../common/types/schema'
import { AppLog } from '../middleware'
import { BaseImageSettings } from '/common/types/image-schema'
import { ImageSettings } from '/common/types/image-schema'

export type ImageGenerateRequest = {
user: AppSchema.User
prompt: string
chatId?: string
model?: string
messageId?: string
ephemeral?: boolean
append?: boolean
Expand All @@ -20,7 +21,8 @@ export type ImageRequestOpts = {
user: AppSchema.User
prompt: string
negative: string
settings: BaseImageSettings | undefined
settings: ImageSettings | undefined
override?: string
}

export type ImageAdapter = (
Expand Down
4 changes: 4 additions & 0 deletions web/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import SoundsPage from './pages/Sounds'
import PatreonOauth from './pages/Settings/PatreonOauth'
import { SagaDetail } from './pages/Saga/Detail'
import { SagaList } from './pages/Saga/List'
import { ImageSettingsModal } from './pages/Settings/Image/ImageSettings'

const App: Component = () => {
const state = userStore()
Expand Down Expand Up @@ -212,6 +213,9 @@ const Layout: Component<{ children?: any }> = (props) => {
<ProfileModal />
<For each={rootModals.modals}>{(modal) => modal.element}</For>
<ImageModal />
<Show when={cfg.showImgSettings}>
<ImageSettingsModal />
</Show>
<SettingsModal />
<div
class="absolute bottom-0 left-0 right-0 top-0 z-10 h-[100vh] w-full bg-black bg-opacity-20 sm:hidden"
Expand Down
25 changes: 23 additions & 2 deletions web/Layout.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
import { Component } from 'solid-js'
import { Component, createMemo } from 'solid-js'
import { Portal } from 'solid-js/web'

export const HeaderPortal: Component<{ children: any }> = (props) => {
return <Portal mount={document.getElementById('site-header')!}>{props.children}</Portal>
}

export const Page: Component<{ children: any; class?: string; classList?: any }> = (props) => {
const hasHorzPaddingClass = createMemo(() => {
if (isHorzPadding(props.class)) return true

for (const [cls, value] of Object.entries(props.classList || {})) {
if (!isHorzPadding(cls)) continue
if (value) return true
}

return false
})

return (
<main class={`h-full w-full px-2 sm:px-3 ${props.class || ''}`} classList={props.classList}>
<main
class={`h-full w-full ${props.class || ''}`}
classList={{
'px-2 sm:px-3': !hasHorzPaddingClass(),
...props.classList,
}}
>
{props.children}
</main>
)
}

function isHorzPadding(cls?: string) {
return cls?.startsWith('p-') || cls?.startsWith('px-')
}
Loading

0 comments on commit 6f9dd31

Please sign in to comment.