main.py

import ollama
import warnings
from mattsollamatools import chunker
from newspaper import Article
import numpy as np
from sklearn.neighbors import NearestNeighbors
import nltk

warnings.filterwarnings(
    "ignore", category=FutureWarning, module="transformers.tokenization_utils_base"
)
nltk.download("punkt_tab", quiet=True)
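
# RAG pipeline with grounded-factuality checking: fetch an article, chunk and
# embed it, retrieve the chunks nearest the user's question, answer with
# llama3.2, then verify each sentence of the answer against the retrieved
# context using bespoke-minicheck.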


def getArticleText(url):
    """Gets the text of an article from a URL.

    News pages often include ads and navigation menus; this uses newspaper3k
    to extract just the article text.
    """
    article = Article(url)
    article.download()
    article.parse()
    return article.text


def knn_search(question_embedding, embeddings, k=5):
    """Performs a k-nearest neighbors (KNN) search over all article chunks."""
    X = np.array(
        [item["embedding"] for article in embeddings for item in article["embeddings"]]
    )
    source_texts = [
        item["source"] for article in embeddings for item in article["embeddings"]
    ]

    # Fit a KNN model on the embeddings
    knn = NearestNeighbors(n_neighbors=k, metric="cosine")
    knn.fit(X)

    # Find the indices and distances of the k-nearest neighbors
    _, indices = knn.kneighbors(question_embedding, n_neighbors=k)
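    # The distances are discarded; `indices` already comes back ranked by
    # cosine distance, nearest first.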

    # Get the indices and source texts of the best matches
    best_matches = [(indices[0][i], source_texts[indices[0][i]]) for i in range(k)]

    return best_matches


def check(document, claim):
    """Checks if the claim is supported by the document by calling bespoke-minicheck.

    Returns Yes/yes if the claim is supported by the document, No/no otherwise.
    Support for logits will be added in the future.

    bespoke-minicheck's system prompt is defined as:
      'Determine whether the provided claim is consistent with the corresponding
      document. Consistency in this context implies that all information presented in the claim
      is substantiated by the document. If not, it should be considered inconsistent. Please
      assess the claim's consistency with the document by responding with either "Yes" or "No".'

    bespoke-minicheck's user prompt is defined as:
      "Document: {document}\nClaim: {claim}"
    """
    prompt = f"Document: {document}\nClaim: {claim}"
    response = ollama.generate(
        model="bespoke-minicheck",
        prompt=prompt,
        options={"num_predict": 2, "temperature": 0.0},
    )
    return response["response"].strip()


if __name__ == "__main__":
    allEmbeddings = []
    default_url = "https://www.theverge.com/2024/9/12/24242439/openai-o1-model-reasoning-strawberry-chatgpt"
    user_input = input(
        "Enter the URL of an article you want to chat with, or press Enter for default example: "
    )
    article_url = user_input.strip() if user_input.strip() else default_url
    article = {}
    article["embeddings"] = []
    article["url"] = article_url
    text = getArticleText(article_url)
    chunks = chunker(text)
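
    # `chunks` holds the article text split into smaller passages by
    # mattsollamatools' chunker, ready to be embedded individually.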

    # Embed (batch) chunks using ollama
    embeddings = ollama.embed(model="all-minilm", input=chunks)["embeddings"]

    for chunk, embedding in zip(chunks, embeddings):
        item = {}
        item["source"] = chunk
        item["embedding"] = embedding
        item["sourcelength"] = len(chunk)
        article["embeddings"].append(item)

    allEmbeddings.append(article)
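    # allEmbeddings stores one entry per article; only one article is loaded
    # here, but knn_search flattens and searches across every entry.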

    print(f"\nLoaded, chunked, and embedded text from {article_url}.\n")

    while True:
        # Read a question from the user
        # For example, "Who is the chief research officer?"
        question = input("Enter your question or type quit: ")

        if question.lower() == "quit":
            break

        # Embed the user's question using ollama.embed
        question_embedding = ollama.embed(model="all-minilm", input=question)[
            "embeddings"
        ]

        # Perform KNN search to find the best matches (indices and source text)
        best_matches = knn_search(question_embedding, allEmbeddings, k=4)

        sourcetext = "\n\n".join([source_text for (_, source_text) in best_matches])

        print(f"\nRetrieved chunks: \n{sourcetext}\n")

        # Give the retrieved chunks and question to the chat model
        system_prompt = f"Only use the following information to answer the question. Do not use anything else: {sourcetext}"
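
        # Restricting the model to the retrieved chunks keeps the answer
        # grounded in the article; bespoke-minicheck verifies this below.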

        ollama_response = ollama.generate(
            model="llama3.2",
            prompt=question,
            system=system_prompt,
            stream=False,
        )

        answer = ollama_response["response"]
        print(f"LLM Answer:\n{answer}\n")

        # Check each sentence in the response for grounded factuality
        if answer:
            for claim in nltk.sent_tokenize(answer):
                print(f"LLM Claim: {claim}")
                print(
                    f"Is this claim supported by the context according to bespoke-minicheck? {check(sourcetext, claim)}\n"
                )
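
    # Illustrative session (actual output varies with the article and models):
    #   Enter your question or type quit: Who is the chief research officer?
    #   LLM Answer:
    #   <answer generated from the retrieved chunks>
    #   LLM Claim: <one sentence of the answer>
    #   Is this claim supported by the context according to bespoke-minicheck? Yes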