浏览代码

refac: better migration script

Timothy J. Baek 11 月之前
父节点
当前提交
e316abcfc8

+ 6 - 1
backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py

@@ -30,6 +30,8 @@ import peewee as pw
 from peewee_migrate import Migrator
 import json
 
+from utils.misc import parse_ollama_modelfile
+
 with suppress(ImportError):
     import playhouse.postgres_ext as pw_pext
 
@@ -64,13 +66,16 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
             }
         )
 
+        info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
+
         # Insert the processed data into the 'model' table
         Model.create(
             id=modelfile.tag_name,
             user_id=modelfile.user_id,
+            base_model_id=info.get("base_model_id"),
             name=modelfile.modelfile.get("title"),
             meta=meta,
-            params="{}",
+            params=json.dumps(info.get("params", {})),
             created_at=modelfile.timestamp,
             updated_at=modelfile.timestamp,
         )

+ 74 - 0
backend/utils/misc.py

@@ -1,5 +1,6 @@
 from pathlib import Path
 import hashlib
+import json
 import re
 from datetime import timedelta
 from typing import Optional
@@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]:
             total_duration += timedelta(weeks=number)
 
     return total_duration
+
+
+def parse_ollama_modelfile(model_text):
+    parameters_meta = {
+        "mirostat": int,
+        "mirostat_eta": float,
+        "mirostat_tau": float,
+        "num_ctx": int,
+        "repeat_last_n": int,
+        "repeat_penalty": float,
+        "temperature": float,
+        "seed": int,
+        "stop": str,
+        "tfs_z": float,
+        "num_predict": int,
+        "top_k": int,
+        "top_p": float,
+    }
+
+    data = {"base_model_id": None, "params": {}}
+
+    # Parse base model
+    base_model_match = re.search(
+        r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
+    )
+    if base_model_match:
+        data["base_model_id"] = base_model_match.group(1)
+
+    # Parse template
+    template_match = re.search(
+        r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
+    )
+    if template_match:
+        data["params"] = {"template": template_match.group(1).strip()}
+
+    # Parse stops
+    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
+    if stops:
+        data["params"]["stop"] = stops
+
+    # Parse other parameters from the provided list
+    for param, param_type in parameters_meta.items():
+        param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
+        if param_match:
+            value = param_match.group(1)
+            if param_type == int:
+                value = int(value)
+            elif param_type == float:
+                value = float(value)
+            data["params"][param] = value
+
+    # Parse adapter
+    adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
+    if adapter_match:
+        data["params"]["adapter"] = adapter_match.group(1)
+
+    # Parse system description
+    system_desc_match = re.search(
+        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
+    )
+    if system_desc_match:
+        data["params"]["system"] = system_desc_match.group(1).strip()
+
+    # Parse messages
+    messages = []
+    message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
+    for role, content in message_matches:
+        messages.append({"role": role, "content": content})
+
+    if messages:
+        data["params"]["messages"] = messages
+
+    return data

+ 1 - 1
src/lib/components/workspace/Models.svelte

@@ -139,7 +139,7 @@
 				</div>
 
 				<div class=" flex-1 self-center">
-					<div class=" font-bold capitalize">{model.name}</div>
+					<div class=" font-bold capitalize line-clamp-1">{model.name}</div>
 					<div class=" text-sm overflow-hidden text-ellipsis line-clamp-1">
 						{model?.info?.meta?.description ?? model.id}
 					</div>