main.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. """Simple example to demonstrate how to use the bespoke-minicheck model."""
  2. import ollama
  3. # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
  4. def check(document, claim):
  5. """Checks if the claim is supported by the document by calling bespoke-minicheck.
  6. Returns Yes/yes if the claim is supported by the document, No/no otherwise.
  7. Support for logits will be added in the future.
  8. bespoke-minicheck's system prompt is defined as:
  9. 'Determine whether the provided claim is consistent with the corresponding
  10. document. Consistency in this context implies that all information presented in the claim
  11. is substantiated by the document. If not, it should be considered inconsistent. Please
  12. assess the claim's consistency with the document by responding with either "Yes" or "No".'
  13. bespoke-minicheck's user prompt is defined as:
  14. "Document: {document}\nClaim: {claim}"
  15. """
  16. prompt = f"Document: {document}\nClaim: {claim}"
  17. response = ollama.generate(
  18. model="bespoke-minicheck", prompt=prompt, options={"num_predict": 2, "temperature": 0.0}
  19. )
  20. return response["response"].strip()
  21. def get_user_input(prompt):
  22. user_input = input(prompt)
  23. if not user_input:
  24. exit()
  25. print()
  26. return user_input
  27. def main():
  28. while True:
  29. # Get a document from the user (e.g. "Ryan likes running and biking.")
  30. document = get_user_input("Enter a document: ")
  31. # Get a claim from the user (e.g. "Ryan likes to run.")
  32. claim = get_user_input("Enter a claim: ")
  33. # Check if the claim is supported by the document
  34. grounded_factuality_check = check(document, claim)
  35. print(
  36. f"Is the claim supported by the document according to bespoke-minicheck? {grounded_factuality_check}"
  37. )
  38. print("\n\n")
# Start the interactive loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()