engine.py

import os
import json
import sys
from contextlib import contextmanager

from llama_cpp import Llama as LLM

from template import template
import ollama.model


@contextmanager
def suppress_stderr():
    # Temporarily redirect stderr to /dev/null (llama.cpp writes its loading
    # output there) and restore the original stderr even if the body raises.
    stderr = os.dup(sys.stderr.fileno())
    with open(os.devnull, "w") as devnull:
        os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            yield
        finally:
            os.dup2(stderr, sys.stderr.fileno())
            os.close(stderr)


def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
    llm = load(model, models_home=models_home, llms=llms)
    prompt = template(model, prompt)

    # fill in default generation options unless the caller overrides them
    if "max_tokens" not in kwargs:
        kwargs.update({"max_tokens": 16384})
    if "stop" not in kwargs:
        kwargs.update({"stop": ["Q:"]})
    if "stream" not in kwargs:
        kwargs.update({"stream": True})

    # stream each completion chunk back as a JSON string
    for output in llm(prompt, *args, **kwargs):
        yield json.dumps(output)


def load(model, models_home=".", llms={}):
    # the mutable default for `llms` is intentional: it persists across calls
    # and caches loaded Llama instances, keyed by model name (or path)
    llm = llms.get(model, None)
    if not llm:
        model_path = {
            name: path for name, path in ollama.model.models(models_home)
        }.get(model, None)

        if model_path is None:
            # try loading this as a path to a model, rather than a model name
            if os.path.isfile(model):
                model_path = model
            else:
                raise ValueError("Model not found")

        # suppress LLM's output
        with suppress_stderr():
            llm = LLM(model_path, verbose=False)

        llms.update({model: llm})

    return llm


def unload(model, llms={}):
    # drop the cached Llama instance (if any) so it can be garbage collected
    if model in llms:
        llms.pop(model)
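

# Usage sketch (an addition, not part of the original module): a minimal example
# of driving generate() from the command line. The default model name "orca" and
# the prompt below are hypothetical placeholders; any model known to
# ollama.model.models() or a direct path to a model file should work. An explicit
# cache dict is shared between generate() and unload(), since each function's
# default `llms` argument is a separate dict.
if __name__ == "__main__":
    model_name = sys.argv[1] if len(sys.argv) > 1 else "orca"
    llms = {}  # shared cache of loaded Llama instances

    try:
        # generate() yields JSON-encoded chunks as the model streams tokens;
        # print each one as it arrives
        for chunk in generate(model_name, "Q: Why is the sky blue? A:", llms=llms):
            print(chunk)
    finally:
        # remove the cached Llama instance before exiting
        unload(model_name, llms=llms)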