Browse Source

fix: tiktoken encoding model issue

Timothy J. Baek 6 months ago
parent
commit
50dcad0f73
3 changed files with 11 additions and 11 deletions:
  1. Dockerfile (+4 −4)
  2. backend/open_webui/apps/retrieval/main.py (+3 −3)
  3. backend/open_webui/config.py (+4 −4)

Dockerfile (+4 −4)

@@ -13,7 +13,7 @@ ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 ARG USE_RERANKING_MODEL=""
 
 # Tiktoken encoding name; models to use can be found at https://huggingface.co/models?library=tiktoken
-ARG USE_TIKTOKEN_ENCODING_NAME="cl100k_base"
+ARG USE_TIKTOKEN_ENCODING_MODEL_NAME="cl100k_base"
 
 ARG BUILD_HASH=dev-build
 # Override at your own risk - non-root configurations are untested
@@ -77,7 +77,7 @@ ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
     SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
 
 ## Tiktoken model settings ##
-ENV TIKTOKEN_ENCODING_NAME="$USE_TIKTOKEN_ENCODING_NAME" \
+ENV TIKTOKEN_ENCODING_MODEL_NAME="$USE_TIKTOKEN_ENCODING_MODEL_NAME" \
     TIKTOKEN_CACHE_DIR="/app/backend/data/cache/tiktoken"
 
 ## Hugging Face download cache ##
@@ -139,13 +139,13 @@ RUN pip3 install uv && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
     python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
     python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
-    python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
+    python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \
     else \
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
     python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
     python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
-    python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \
+    python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_MODEL_NAME'])"; \
     fi; \
     chown -R $UID:$GID /app/backend/data/
 

backend/open_webui/apps/retrieval/main.py (+3 −3)

@@ -50,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
 from open_webui.apps.webui.models.files import Files
 from open_webui.config import (
     BRAVE_SEARCH_API_KEY,
-    TIKTOKEN_ENCODING_NAME,
+    TIKTOKEN_ENCODING_MODEL_NAME,
     RAG_TEXT_SPLITTER,
     CHUNK_OVERLAP,
     CHUNK_SIZE,
@@ -135,7 +135,7 @@ app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
 app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
 
 app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
-app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
+app.state.config.TIKTOKEN_ENCODING_MODEL_NAME = TIKTOKEN_ENCODING_MODEL_NAME
 
 app.state.config.CHUNK_SIZE = CHUNK_SIZE
 app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -667,7 +667,7 @@ def save_docs_to_vector_db(
             )
         elif app.state.config.TEXT_SPLITTER == "token":
             text_splitter = TokenTextSplitter(
-                encoding_name=app.state.config.TIKTOKEN_ENCODING_NAME,
+                model_name=app.state.config.TIKTOKEN_ENCODING_MODEL_NAME,
                 chunk_size=app.state.config.CHUNK_SIZE,
                 chunk_overlap=app.state.config.CHUNK_OVERLAP,
                 add_start_index=True,

backend/open_webui/config.py (+4 −4)

@@ -1074,10 +1074,10 @@ RAG_TEXT_SPLITTER = PersistentConfig(
 
 
 TIKTOKEN_CACHE_DIR = os.environ.get("TIKTOKEN_CACHE_DIR", f"{CACHE_DIR}/tiktoken")
-TIKTOKEN_ENCODING_NAME = PersistentConfig(
-    "TIKTOKEN_ENCODING_NAME",
-    "rag.tiktoken_encoding_name",
-    os.environ.get("TIKTOKEN_ENCODING_NAME", "cl100k_base"),
+TIKTOKEN_ENCODING_MODEL_NAME = PersistentConfig(
+    "TIKTOKEN_ENCODING_MODEL_NAME",
+    "rag.tiktoken_encoding_model_name",
+    os.environ.get("TIKTOKEN_ENCODING_MODEL_NAME", "cl100k_base"),
 )