# summ.py — interactive news summarizer + RAG Q&A against a local Ollama server
  1. import curses
  2. import json
  3. from utils import get_url_for_topic, topic_urls, menu, getUrls, get_summary, getArticleText, knn_search
  4. import requests
  5. from sentence_transformers import SentenceTransformer
  6. from mattsollamatools import chunker
  7. if __name__ == "__main__":
  8. chosen_topic = curses.wrapper(menu)
  9. print("Here is your news summary:\n")
  10. urls = getUrls(chosen_topic, n=5)
  11. model = SentenceTransformer('all-MiniLM-L6-v2')
  12. allEmbeddings = []
  13. for url in urls:
  14. article={}
  15. article['embeddings'] = []
  16. article['url'] = url
  17. text = getArticleText(url)
  18. summary = get_summary(text)
  19. chunks = chunker(text) # Use the chunk_text function from web_utils
  20. embeddings = model.encode(chunks)
  21. for (chunk, embedding) in zip(chunks, embeddings):
  22. item = {}
  23. item['source'] = chunk
  24. item['embedding'] = embedding.tolist() # Convert NumPy array to list
  25. item['sourcelength'] = len(chunk)
  26. article['embeddings'].append(item)
  27. allEmbeddings.append(article)
  28. print(f"{summary}\n")
  29. while True:
  30. context = []
  31. # Input a question from the user
  32. question = input("Enter your question about the news, or type quit: ")
  33. if question.lower() == 'quit':
  34. break
  35. # Embed the user's question
  36. question_embedding = model.encode([question])
  37. # Perform KNN search to find the best matches (indices and source text)
  38. best_matches = knn_search(question_embedding, allEmbeddings, k=10)
  39. sourcetext=""
  40. for i, (index, source_text) in enumerate(best_matches, start=1):
  41. sourcetext += f"{i}. Index: {index}, Source Text: {source_text}"
  42. systemPrompt = f"Only use the following information to answer the question. Do not use anything else: {sourcetext}"
  43. url = "http://localhost:11434/api/generate"
  44. payload = {
  45. "model": "mistral-openorca",
  46. "prompt": question,
  47. "system": systemPrompt,
  48. "stream": False,
  49. "context": context
  50. }
  51. # Convert the payload to a JSON string
  52. payload_json = json.dumps(payload)
  53. # Set the headers to specify JSON content
  54. headers = {
  55. "Content-Type": "application/json"
  56. }
  57. # Send the POST request
  58. response = requests.post(url, data=payload_json, headers=headers)
  59. # Check the response
  60. if response.status_code == 200:
  61. output = json.loads(response.text)
  62. context = output['context']
  63. print(output['response']+ "\n")
  64. else:
  65. print(f"Request failed with status code {response.status_code}")