engine.py

import os
import json
import sys
from contextlib import contextmanager

from llama_cpp import Llama as LLM

import ollama.model


@contextmanager
def suppress_stderr():
    # Temporarily redirect stderr to /dev/null, restoring the original
    # stream once the block exits.
    stderr = os.dup(sys.stderr.fileno())
    with open(os.devnull, 'w') as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())
        yield
    os.dup2(stderr, sys.stderr.fileno())


def generate(model, prompt, models_home='.', llms={}, *args, **kwargs):
    llm = load(model, models_home=models_home, llms=llms)

    # default generation settings, applied only if the caller didn't set them
    if 'max_tokens' not in kwargs:
        kwargs.update({'max_tokens': 16384})

    if 'stop' not in kwargs:
        kwargs.update({'stop': ['Q:', '\n']})

    if 'stream' not in kwargs:
        kwargs.update({'stream': True})

    # stream each completion chunk back as a JSON string
    for output in llm(prompt, *args, **kwargs):
        yield json.dumps(output)


def load(model, models_home='.', llms={}):
    # llms doubles as a cache of already-loaded models, shared across calls
    llm = llms.get(model, None)
    if not llm:
        # resolve the model name to a path on disk
        model_path = {
            name: path
            for name, path in ollama.model.models(models_home)
        }.get(model, None)

        if model_path is None:
            raise ValueError('Model not found')

        # suppress LLM's output
        with suppress_stderr():
            llm = LLM(model_path, verbose=False)

        llms.update({model: llm})

    return llm


def unload(model, llms={}):
    # drop the cached instance so it can be garbage collected
    if model in llms:
        llms.pop(model)
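

# Minimal usage sketch, assuming a model named 'orca' has already been placed
# under models_home and that ollama.model.models() can resolve it to a file path.
# Both the model name and the models directory below are placeholder assumptions.
if __name__ == '__main__':
    models_home = os.path.expanduser('~/.ollama/models')
    for chunk in generate('orca', 'Q: Why is the sky blue?\nA:', models_home=models_home):
        print(chunk)
    unload('orca')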