|
@@ -3,12 +3,15 @@ from langchain.chains import RetrievalQA
|
|
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
|
|
from langchain.vectorstores import Chroma
|
|
|
-from langchain.llms import GPT4All, Ollama
|
|
|
+from langchain.llms import Ollama
|
|
|
import os
|
|
|
import argparse
|
|
|
import time
|
|
|
|
|
|
model = os.environ.get("MODEL", "llama2-uncensored")
|
|
|
+# For the embeddings model, the example uses a sentence-transformers model
|
|
|
+# https://www.sbert.net/docs/pretrained_models.html
|
|
|
+# "The all-mpnet-base-v2 model provides the best quality, while all-MiniLM-L6-v2 is 5 times faster and still offers good quality."
|
|
|
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
|
|
|
persist_directory = os.environ.get("PERSIST_DIRECTORY", "db")
|
|
|
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
|
|
@@ -44,7 +47,6 @@ def main():
|
|
|
# Print the result
|
|
|
print("\n\n> Question:")
|
|
|
print(query)
|
|
|
- print(f"\n> Answer (took {round(end - start, 2)} s.):")
|
|
|
print(answer)
|
|
|
|
|
|
# Print the relevant sources used for the answer
|