models.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import time
  2. import logging
  3. import sys
  4. from aiocache import cached
  5. from fastapi import Request
  6. from open_webui.routers import openai, ollama
  7. from open_webui.functions import get_function_models
  8. from open_webui.models.functions import Functions
  9. from open_webui.models.models import Models
  10. from open_webui.utils.plugin import load_function_module_by_id
  11. from open_webui.utils.access_control import has_access
  12. from open_webui.config import (
  13. DEFAULT_ARENA_MODEL,
  14. )
  15. from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
  16. logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
  17. log = logging.getLogger(__name__)
  18. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  19. async def get_all_base_models(request: Request):
  20. function_models = []
  21. openai_models = []
  22. ollama_models = []
  23. if request.app.state.config.ENABLE_OPENAI_API:
  24. openai_models = await openai.get_all_models(request)
  25. openai_models = openai_models["data"]
  26. if request.app.state.config.ENABLE_OLLAMA_API:
  27. ollama_models = await ollama.get_all_models(request)
  28. ollama_models = [
  29. {
  30. "id": model["model"],
  31. "name": model["name"],
  32. "object": "model",
  33. "created": int(time.time()),
  34. "owned_by": "ollama",
  35. "ollama": model,
  36. }
  37. for model in ollama_models["models"]
  38. ]
  39. function_models = await get_function_models(request)
  40. models = function_models + openai_models + ollama_models
  41. return models
  42. async def get_all_models(request):
  43. models = await get_all_base_models(request)
  44. # If there are no models, return an empty list
  45. if len(models) == 0:
  46. return []
  47. # Add arena models
  48. if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
  49. arena_models = []
  50. if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
  51. arena_models = [
  52. {
  53. "id": model["id"],
  54. "name": model["name"],
  55. "info": {
  56. "meta": model["meta"],
  57. },
  58. "object": "model",
  59. "created": int(time.time()),
  60. "owned_by": "arena",
  61. "arena": True,
  62. }
  63. for model in request.app.state.config.EVALUATION_ARENA_MODELS
  64. ]
  65. else:
  66. # Add default arena model
  67. arena_models = [
  68. {
  69. "id": DEFAULT_ARENA_MODEL["id"],
  70. "name": DEFAULT_ARENA_MODEL["name"],
  71. "info": {
  72. "meta": DEFAULT_ARENA_MODEL["meta"],
  73. },
  74. "object": "model",
  75. "created": int(time.time()),
  76. "owned_by": "arena",
  77. "arena": True,
  78. }
  79. ]
  80. models = models + arena_models
  81. global_action_ids = [
  82. function.id for function in Functions.get_global_action_functions()
  83. ]
  84. enabled_action_ids = [
  85. function.id
  86. for function in Functions.get_functions_by_type("action", active_only=True)
  87. ]
  88. custom_models = Models.get_all_models()
  89. for custom_model in custom_models:
  90. if custom_model.base_model_id is None:
  91. for model in models:
  92. if (
  93. custom_model.id == model["id"]
  94. or custom_model.id == model["id"].split(":")[0]
  95. ):
  96. if custom_model.is_active:
  97. model["name"] = custom_model.name
  98. model["info"] = custom_model.model_dump()
  99. action_ids = []
  100. if "info" in model and "meta" in model["info"]:
  101. action_ids.extend(
  102. model["info"]["meta"].get("actionIds", [])
  103. )
  104. model["action_ids"] = action_ids
  105. else:
  106. models.remove(model)
  107. elif custom_model.is_active and (
  108. custom_model.id not in [model["id"] for model in models]
  109. ):
  110. owned_by = "openai"
  111. pipe = None
  112. action_ids = []
  113. for model in models:
  114. if (
  115. custom_model.base_model_id == model["id"]
  116. or custom_model.base_model_id == model["id"].split(":")[0]
  117. ):
  118. owned_by = model["owned_by"]
  119. if "pipe" in model:
  120. pipe = model["pipe"]
  121. break
  122. if custom_model.meta:
  123. meta = custom_model.meta.model_dump()
  124. if "actionIds" in meta:
  125. action_ids.extend(meta["actionIds"])
  126. models.append(
  127. {
  128. "id": f"{custom_model.id}",
  129. "name": custom_model.name,
  130. "object": "model",
  131. "created": custom_model.created_at,
  132. "owned_by": owned_by,
  133. "info": custom_model.model_dump(),
  134. "preset": True,
  135. **({"pipe": pipe} if pipe is not None else {}),
  136. "action_ids": action_ids,
  137. }
  138. )
  139. # Process action_ids to get the actions
  140. def get_action_items_from_module(function, module):
  141. actions = []
  142. if hasattr(module, "actions"):
  143. actions = module.actions
  144. return [
  145. {
  146. "id": f"{function.id}.{action['id']}",
  147. "name": action.get("name", f"{function.name} ({action['id']})"),
  148. "description": function.meta.description,
  149. "icon_url": action.get(
  150. "icon_url", function.meta.manifest.get("icon_url", None)
  151. ),
  152. }
  153. for action in actions
  154. ]
  155. else:
  156. return [
  157. {
  158. "id": function.id,
  159. "name": function.name,
  160. "description": function.meta.description,
  161. "icon_url": function.meta.manifest.get("icon_url", None),
  162. }
  163. ]
  164. def get_function_module_by_id(function_id):
  165. if function_id in request.app.state.FUNCTIONS:
  166. function_module = request.app.state.FUNCTIONS[function_id]
  167. else:
  168. function_module, _, _ = load_function_module_by_id(function_id)
  169. request.app.state.FUNCTIONS[function_id] = function_module
  170. for model in models:
  171. action_ids = [
  172. action_id
  173. for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
  174. if action_id in enabled_action_ids
  175. ]
  176. model["actions"] = []
  177. for action_id in action_ids:
  178. action_function = Functions.get_function_by_id(action_id)
  179. if action_function is None:
  180. raise Exception(f"Action not found: {action_id}")
  181. function_module = get_function_module_by_id(action_id)
  182. model["actions"].extend(
  183. get_action_items_from_module(action_function, function_module)
  184. )
  185. log.debug(f"get_all_models() returned {len(models)} models")
  186. request.app.state.MODELS = {model["id"]: model for model in models}
  187. return models
  188. def check_model_access(user, model):
  189. if model.get("arena"):
  190. if not has_access(
  191. user.id,
  192. type="read",
  193. access_control=model.get("info", {})
  194. .get("meta", {})
  195. .get("access_control", {}),
  196. ):
  197. raise Exception("Model not found")
  198. else:
  199. model_info = Models.get_model_by_id(model.get("id"))
  200. if not model_info:
  201. raise Exception("Model not found")
  202. elif not (
  203. user.id == model_info.user_id
  204. or has_access(
  205. user.id, type="read", access_control=model_info.access_control
  206. )
  207. ):
  208. raise Exception("Model not found")