From 66ac80392800a64994d10ccefba29ba369891803 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Tue, 24 Sep 2024 20:31:18 +0100 Subject: [PATCH] add support code qwen 2.5 coder, fix stopwords --- src/common/constants.ts | 2 ++ src/extension/fim-templates.ts | 26 ++++++++++++++++---- src/extension/providers/completion.ts | 35 ++++++++++++++++----------- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/common/constants.ts b/src/common/constants.ts index 0ec1fe58..0861d15b 100644 --- a/src/common/constants.ts +++ b/src/common/constants.ts @@ -226,6 +226,8 @@ export const STOP_CODEGEMMA = ['<|file_separator|>', '<|end_of_turn|>', ''] export const STOP_CODESTRAL = ['[PREFIX]', '[SUFFIX]'] +export const STOP_QWEN = ['<|endoftext|>'] + export const DEFAULT_TEMPLATE_NAMES = defaultTemplates.map(({ name }) => name) export const DEFAULT_ACTION_TEMPLATES = [ diff --git a/src/extension/fim-templates.ts b/src/extension/fim-templates.ts index c704b8af..2e04f0d9 100644 --- a/src/extension/fim-templates.ts +++ b/src/extension/fim-templates.ts @@ -4,7 +4,8 @@ import { STOP_LLAMA, STOP_STARCODER, STOP_CODEGEMMA, - STOP_CODESTRAL + STOP_CODESTRAL, + STOP_QWEN } from '../common/constants' import { supportedLanguages } from '../common/languages' import { FimPromptTemplate } from '../common/types' @@ -79,6 +80,13 @@ export const getFimPromptTemplateCodestral = ({ return `${fileContext}\n\n[SUFFIX]${suffix}[PREFIX]${heading}${prefix}` } +export const getFimPromptTemplateQwen = ({ + prefixSuffix, +}: FimPromptTemplate) => { + const { prefix, suffix } = prefixSuffix + return `<|fim_prefix|>${prefix}<|fim_suffix|>${suffix}<|fim_middle|>` +} + export const getFimPromptTemplateOther = ({ context, header, @@ -112,10 +120,13 @@ function getFimTemplateAuto(fimModel: string, args: FimPromptTemplate) { return getFimPromptTemplateCodestral(args) } + if (fimModel.includes(FIM_TEMPLATE_FORMAT.codeqwen)) { + return getFimPromptTemplateQwen(args) + } + if ( fimModel.includes(FIM_TEMPLATE_FORMAT.stableCode) || fimModel.includes(FIM_TEMPLATE_FORMAT.starcoder) || - fimModel.includes(FIM_TEMPLATE_FORMAT.codeqwen) || fimModel.includes(FIM_TEMPLATE_FORMAT.codegemma) ) { return getFimPromptTemplateOther(args) @@ -137,11 +148,14 @@ function getFimTemplateChosen(format: string, args: FimPromptTemplate) { return getFimPromptTemplateCodestral(args) } + if (format === FIM_TEMPLATE_FORMAT.codeqwen) { + return getFimPromptTemplateQwen(args) + } + if ( format === FIM_TEMPLATE_FORMAT.stableCode || format === FIM_TEMPLATE_FORMAT.starcoder || - format === FIM_TEMPLATE_FORMAT.codegemma || - format === FIM_TEMPLATE_FORMAT.codeqwen + format === FIM_TEMPLATE_FORMAT.codegemma ) { return getFimPromptTemplateOther(args) } @@ -174,7 +188,8 @@ export const getStopWordsAuto = (fimModel: string) => { if ( fimModel.includes(FIM_TEMPLATE_FORMAT.stableCode) || - fimModel.includes(FIM_TEMPLATE_FORMAT.starcoder) + fimModel.includes(FIM_TEMPLATE_FORMAT.starcoder) || + fimModel.includes(FIM_TEMPLATE_FORMAT.codeqwen) ) { return ['<|endoftext|>'] } @@ -193,6 +208,7 @@ export const getStopWordsAuto = (fimModel: string) => { export const getStopWordsChosen = (format: string) => { if (format === FIM_TEMPLATE_FORMAT.codellama) return STOP_LLAMA if (format === FIM_TEMPLATE_FORMAT.deepseek) return STOP_DEEPSEEK + if (format === FIM_TEMPLATE_FORMAT.codeqwen) return STOP_QWEN if ( format === FIM_TEMPLATE_FORMAT.stableCode || format === FIM_TEMPLATE_FORMAT.starcoder diff --git a/src/extension/providers/completion.ts b/src/extension/providers/completion.ts index d0fb17cd..0ee1431c 100644 --- a/src/extension/providers/completion.ts +++ b/src/extension/providers/completion.ts @@ -98,6 +98,7 @@ export class CompletionProvider implements InlineCompletionItemProvider { private _templateProvider: TemplateProvider private _fileContextEnabled = this._config.get('fileContextEnabled') as boolean private _usingFimTemplate = false + private _provider: TwinnyProvider | undefined constructor( statusBar: StatusBarItem, @@ -122,7 +123,7 @@ export class CompletionProvider implements InlineCompletionItemProvider { context: InlineCompletionContext ): Promise { const editor = window.activeTextEditor - + this._provider = this.getProvider() const isLastCompletionAccepted = this._acceptedLastCompletion && !this.enableSubsequentCompletions @@ -248,11 +249,15 @@ export class CompletionProvider implements InlineCompletionItemProvider { } private onData(data: StreamResponse | undefined): string { - const provider = this.getProvider() - if (!provider) return '' + if (!this._provider) return '' + + const stopWords = getStopWords( + this._provider.modelName, + this._provider.fimTemplate || FIM_TEMPLATE_FORMAT.automatic + ) try { - const providerFimData = getProviderFimData(provider.provider, data) + const providerFimData = getProviderFimData(this._provider.provider, data) if (providerFimData === undefined) return '' this._completion = this._completion + providerFimData @@ -268,6 +273,10 @@ export class CompletionProvider implements InlineCompletionItemProvider { ) } + if (stopWords.some(stopWord => this._completion.includes(stopWord))) { + return this._completion + } + if ( !this._multilineCompletionsEnabled && this._chunkCount >= MIN_COMPLETION_CHUNKS && @@ -430,12 +439,11 @@ export class CompletionProvider implements InlineCompletionItemProvider { } private removeStopWords(completion: string) { - const provider = this.getProvider() - if (!provider) return completion + if (!this._provider) return completion let filteredCompletion = completion const stopWords = getStopWords( - provider.modelName, - provider.fimTemplate || FIM_TEMPLATE_FORMAT.automatic + this._provider.modelName, + this._provider.fimTemplate || FIM_TEMPLATE_FORMAT.automatic ) stopWords.forEach((stopWord) => { filteredCompletion = filteredCompletion.split(stopWord).join('') @@ -444,14 +452,13 @@ export class CompletionProvider implements InlineCompletionItemProvider { } private async getPrompt(prefixSuffix: PrefixSuffix) { - const provider = this.getProvider() - if (!provider) return '' - if (!this._document || !this._position || !provider) return '' + if (!this._provider) return '' + if (!this._document || !this._position || !this._provider) return '' const documentLanguage = this._document.languageId const fileInteractionContext = await this.getFileInteractionContext() - if (provider.fimTemplate === FIM_TEMPLATE_FORMAT.custom) { + if (this._provider.fimTemplate === FIM_TEMPLATE_FORMAT.custom) { const systemMessage = await this._templateProvider.readSystemMessageTemplate('fim-system.hbs') @@ -472,8 +479,8 @@ export class CompletionProvider implements InlineCompletionItemProvider { } return getFimPrompt( - provider.modelName, - provider.fimTemplate || FIM_TEMPLATE_FORMAT.automatic, + this._provider.modelName, + this._provider.fimTemplate || FIM_TEMPLATE_FORMAT.automatic, { context: fileInteractionContext || '', prefixSuffix,