Просмотр исходного кода

feat: include num_thread in advanced params

Timothy J. Baek 11 месяцев назад
Родитель
Сommit
e0ba585204

+ 65 - 32
backend/apps/ollama/main.py

@@ -906,44 +906,77 @@ async def generate_chat_completion(
         if model_info.params:
             payload["options"] = {}
 
-            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
-            payload["options"]["mirostat_eta"] = model_info.params.get(
-                "mirostat_eta", None
-            )
-            payload["options"]["mirostat_tau"] = model_info.params.get(
-                "mirostat_tau", None
-            )
-            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+            if model_info.params.get("mirostat", None):
+                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
 
-            payload["options"]["repeat_last_n"] = model_info.params.get(
-                "repeat_last_n", None
-            )
-            payload["options"]["repeat_penalty"] = model_info.params.get(
-                "frequency_penalty", None
-            )
+            if model_info.params.get("mirostat_eta", None):
+                payload["options"]["mirostat_eta"] = model_info.params.get(
+                    "mirostat_eta", None
+                )
 
-            payload["options"]["temperature"] = model_info.params.get(
-                "temperature", None
-            )
-            payload["options"]["seed"] = model_info.params.get("seed", None)
+            if model_info.params.get("mirostat_tau", None):
 
-            payload["options"]["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
+                payload["options"]["mirostat_tau"] = model_info.params.get(
+                    "mirostat_tau", None
+                )
 
-            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+            if model_info.params.get("num_ctx", None):
+                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
 
-            payload["options"]["num_predict"] = model_info.params.get(
-                "max_tokens", None
-            )
-            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+            if model_info.params.get("repeat_last_n", None):
+                payload["options"]["repeat_last_n"] = model_info.params.get(
+                    "repeat_last_n", None
+                )
+
+            if model_info.params.get("frequency_penalty", None):
+                payload["options"]["repeat_penalty"] = model_info.params.get(
+                    "frequency_penalty", None
+                )
+
+            if model_info.params.get("temperature", None):
+                payload["options"]["temperature"] = model_info.params.get(
+                    "temperature", None
+                )
+
+            if model_info.params.get("seed", None):
+                payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["options"]["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+            if model_info.params.get("tfs_z", None):
+                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
 
-            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+            if model_info.params.get("max_tokens", None):
+                payload["options"]["num_predict"] = model_info.params.get(
+                    "max_tokens", None
+                )
+
+            if model_info.params.get("top_k", None):
+                payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            if model_info.params.get("top_p", None):
+                payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+            if model_info.params.get("use_mmap", None):
+                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
+
+            if model_info.params.get("use_mlock", None):
+                payload["options"]["use_mlock"] = model_info.params.get(
+                    "use_mlock", None
+                )
+
+            if model_info.params.get("num_thread", None):
+                payload["options"]["num_thread"] = model_info.params.get(
+                    "num_thread", None
+                )
 
         if model_info.params.get("system", None):
             # Check if the payload already has a system message

+ 91 - 0
src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte

@@ -20,6 +20,9 @@
 		tfs_z: '',
 		num_ctx: '',
 		max_tokens: '',
+		use_mmap: null,
+		use_mlock: null,
+		num_thread: null,
 		template: null
 	};
 
@@ -559,6 +562,7 @@
 			</div>
 		{/if}
 	</div>
+
 	<div class=" py-0.5 w-full justify-between">
 		<div class="flex w-full justify-between">
 			<div class=" self-center text-xs font-medium">{$i18n.t('Max Tokens (num_predict)')}</div>
@@ -604,6 +608,93 @@
 			</div>
 		{/if}
 	</div>
+
+	<div class=" py-0.5 w-full justify-between">
+		<div class="flex w-full justify-between">
+			<div class=" self-center text-xs font-medium">{$i18n.t('use_mmap (Ollama)')}</div>
+
+			<button
+				class="p-1 px-3 text-xs flex rounded transition"
+				type="button"
+				on:click={() => {
+					params.use_mmap = (params?.use_mmap ?? null) === null ? true : null;
+				}}
+			>
+				{#if (params?.use_mmap ?? null) === null}
+					<span class="ml-2 self-center">{$i18n.t('Default')}</span>
+				{:else}
+					<span class="ml-2 self-center">{$i18n.t('On')}</span>
+				{/if}
+			</button>
+		</div>
+	</div>
+
+	<div class=" py-0.5 w-full justify-between">
+		<div class="flex w-full justify-between">
+			<div class=" self-center text-xs font-medium">{$i18n.t('use_mlock (Ollama)')}</div>
+
+			<button
+				class="p-1 px-3 text-xs flex rounded transition"
+				type="button"
+				on:click={() => {
+					params.use_mlock = (params?.use_mlock ?? null) === null ? true : null;
+				}}
+			>
+				{#if (params?.use_mlock ?? null) === null}
+					<span class="ml-2 self-center">{$i18n.t('Default')}</span>
+				{:else}
+					<span class="ml-2 self-center">{$i18n.t('On')}</span>
+				{/if}
+			</button>
+		</div>
+	</div>
+
+	<div class=" py-0.5 w-full justify-between">
+		<div class="flex w-full justify-between">
+			<div class=" self-center text-xs font-medium">{$i18n.t('num_thread (Ollama)')}</div>
+
+			<button
+				class="p-1 px-3 text-xs flex rounded transition"
+				type="button"
+				on:click={() => {
+					params.num_thread = (params?.num_thread ?? null) === null ? 2 : null;
+				}}
+			>
+				{#if (params?.num_thread ?? null) === null}
+					<span class="ml-2 self-center">{$i18n.t('Default')}</span>
+				{:else}
+					<span class="ml-2 self-center">{$i18n.t('Custom')}</span>
+				{/if}
+			</button>
+		</div>
+
+		{#if (params?.num_thread ?? null) !== null}
+			<div class="flex mt-0.5 space-x-2">
+				<div class=" flex-1">
+					<input
+						id="steps-range"
+						type="range"
+						min="1"
+						max="256"
+						step="1"
+						bind:value={params.num_thread}
+						class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
+					/>
+				</div>
+				<div class="">
+					<input
+						bind:value={params.num_thread}
+						type="number"
+						class=" bg-transparent text-center w-14"
+						min="1"
+						max="256"
+						step="1"
+					/>
+				</div>
+			</div>
+		{/if}
+	</div>
+
 	<div class=" py-0.5 w-full justify-between">
 		<div class="flex w-full justify-between">
 			<div class=" self-center text-xs font-medium">{$i18n.t('Template')}</div>