瀏覽代碼

feat: autocompletion

Timothy Jaeryang Baek 5 月之前
父節點
當前提交
a07213b5be

+ 6 - 0
backend/open_webui/config.py

@@ -1037,6 +1037,12 @@ Only output a continuation. If you are unsure how to proceed, output nothing.
 <context>Search</context>
 <text>Best destinations for hiking in</text> 
 **Output**: Europe, such as the Alps or the Scottish Highlands.
+
+### Input:
+<context>{{CONTEXT}}</context>
+<text>
+{{PROMPT}}
+</text>
 """
 
 

+ 6 - 2
backend/open_webui/main.py

@@ -1991,7 +1991,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)):
 
 @app.post("/api/task/auto/completions")
 async def generate_autocompletion(form_data: dict, user=Depends(get_verified_user)):
-    context = form_data.get("context")
 
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}
@@ -2021,8 +2020,11 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use
     else:
         template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
 
+    context = form_data.get("context")
+    prompt = form_data.get("prompt")
+
     content = autocomplete_generation_template(
-        template, form_data["messages"], context, {"name": user.name}
+        template, prompt, context, {"name": user.name}
     )
 
     payload = {
@@ -2036,6 +2038,8 @@ async def generate_autocompletion(form_data: dict, user=Depends(get_verified_use
         },
     }
 
+    print(payload)
+
     # Handle pipeline filters
     try:
         payload = filter_pipeline(payload, user, models)

+ 7 - 10
backend/open_webui/utils/task.py

@@ -53,7 +53,9 @@ def prompt_template(
 
 def replace_prompt_variable(template: str, prompt: str) -> str:
     def replacement_function(match):
-        full_match = match.group(0)
+        full_match = match.group(
+            0
+        ).lower()  # Normalize to lowercase for consistent handling
         start_length = match.group(1)
         end_length = match.group(2)
         middle_length = match.group(3)
@@ -73,11 +75,9 @@ def replace_prompt_variable(template: str, prompt: str) -> str:
             return f"{start}...{end}"
         return ""
 
-    template = re.sub(
-        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
-        replacement_function,
-        template,
-    )
+    # Updated regex pattern to make it case-insensitive with the `(?i)` flag
+    pattern = r"(?i){{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}"
+    template = re.sub(pattern, replacement_function, template)
     return template
 
 
@@ -214,15 +214,12 @@ def emoji_generation_template(
 
 def autocomplete_generation_template(
     template: str,
-    messages: list[dict],
+    prompt: Optional[str] = None,
     context: Optional[str] = None,
     user: Optional[dict] = None,
 ) -> str:
-    prompt = get_last_user_message(messages)
     template = template.replace("{{CONTEXT}}", context if context else "")
-
     template = replace_prompt_variable(template, prompt)
-    template = replace_messages_variable(template, messages)
 
     template = prompt_template(
         template,

+ 47 - 0
src/lib/apis/index.ts

@@ -397,6 +397,53 @@ export const generateQueries = async (
 	}
 };
 
+
+
+export const generateAutoCompletion = async (
+	token: string = '',
+	model: string,
+	prompt: string,
+	context: string = 'search',
+) => {
+	const controller = new AbortController();
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/task/auto/completions`, {
+		signal: controller.signal,
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			model: model,
+			prompt: prompt,
+			context: context,
+			stream: false
+		})
+	})
+	.then(async (res) => {
+		if (!res.ok) throw await res.json();
+		return res.json();
+	})
+	.catch((err) => {
+		console.log(err);
+		if ('detail' in err) {
+			error = err.detail;
+		}
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	const response = res?.choices[0]?.message?.content ?? '';
+	return response;
+};
+
+
 export const generateMoACompletion = async (
 	token: string = '',
 	model: string,

+ 25 - 1
src/lib/components/chat/MessageInput.svelte

@@ -34,6 +34,8 @@
 	import Commands from './MessageInput/Commands.svelte';
 	import XMark from '../icons/XMark.svelte';
 	import RichTextInput from '../common/RichTextInput.svelte';
+	import { generateAutoCompletion } from '$lib/apis';
+	import { error, text } from '@sveltejs/kit';
 
 	const i18n = getContext('i18n');
 
@@ -47,6 +49,9 @@
 	export let atSelectedModel: Model | undefined;
 	export let selectedModels: [''];
 
+	let selectedModelIds = [];
+	$: selectedModelIds = atSelectedModel !== undefined ? [atSelectedModel.id] : selectedModels;
+
 	export let history;
 
 	export let prompt = '';
@@ -581,6 +586,7 @@
 										>
 											<RichTextInput
 												bind:this={chatInputElement}
+												bind:value={prompt}
 												id="chat-input"
 												messageInput={true}
 												shiftEnter={!$mobile ||
@@ -592,7 +598,25 @@
 												placeholder={placeholder ? placeholder : $i18n.t('Send a Message')}
 												largeTextAsFile={$settings?.largeTextAsFile ?? false}
 												autocomplete={true}
-												bind:value={prompt}
+												generateAutoCompletion={async (text) => {
+													if (selectedModelIds.length === 0 || !selectedModelIds.at(0)) {
+														toast.error($i18n.t('Please select a model first.'));
+													}
+
+													const res = await generateAutoCompletion(
+														localStorage.token,
+														selectedModelIds.at(0),
+														text
+													).catch((error) => {
+														console.log(error);
+														toast.error(error);
+														return null;
+													});
+
+													console.log(res);
+
+													return res;
+												}}
 												on:keydown={async (e) => {
 													e = e.detail.event;
 

+ 7 - 1
src/lib/components/common/RichTextInput.svelte

@@ -34,6 +34,7 @@
 	export let value = '';
 	export let id = '';
 
+	export let generateAutoCompletion: Function = async () => null;
 	export let autocomplete = false;
 	export let messageInput = false;
 	export let shiftEnter = false;
@@ -159,7 +160,12 @@
 										return null;
 									}
 
-									return 'AI-generated suggestion';
+									const suggestion = await generateAutoCompletion(text).catch(() => null);
+									if (!suggestion || suggestion.trim().length === 0) {
+										return null;
+									}
+
+									return suggestion;
 								}
 							})
 						]