misc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. from pathlib import Path
  2. import hashlib
  3. import re
  4. from datetime import timedelta
  5. from typing import Optional, Callable
  6. import uuid
  7. import time
  8. from utils.task import prompt_template
  9. def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
  10. for message in reversed(messages):
  11. if message["role"] == "user":
  12. return message
  13. return None
  14. def get_content_from_message(message: dict) -> Optional[str]:
  15. if isinstance(message["content"], list):
  16. for item in message["content"]:
  17. if item["type"] == "text":
  18. return item["text"]
  19. else:
  20. return message["content"]
  21. return None
  22. def get_last_user_message(messages: list[dict]) -> Optional[str]:
  23. message = get_last_user_message_item(messages)
  24. if message is None:
  25. return None
  26. return get_content_from_message(message)
  27. def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
  28. for message in reversed(messages):
  29. if message["role"] == "assistant":
  30. return get_content_from_message(message)
  31. return None
  32. def get_system_message(messages: list[dict]) -> Optional[dict]:
  33. for message in messages:
  34. if message["role"] == "system":
  35. return message
  36. return None
  37. def remove_system_message(messages: list[dict]) -> list[dict]:
  38. return [message for message in messages if message["role"] != "system"]
  39. def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
  40. return get_system_message(messages), remove_system_message(messages)
  41. def prepend_to_first_user_message_content(
  42. content: str, messages: list[dict]
  43. ) -> list[dict]:
  44. for message in messages:
  45. if message["role"] == "user":
  46. if isinstance(message["content"], list):
  47. for item in message["content"]:
  48. if item["type"] == "text":
  49. item["text"] = f"{content}\n{item['text']}"
  50. else:
  51. message["content"] = f"{content}\n{message['content']}"
  52. break
  53. return messages
  54. def add_or_update_system_message(content: str, messages: list[dict]):
  55. """
  56. Adds a new system message at the beginning of the messages list
  57. or updates the existing system message at the beginning.
  58. :param msg: The message to be added or appended.
  59. :param messages: The list of message dictionaries.
  60. :return: The updated list of message dictionaries.
  61. """
  62. if messages and messages[0].get("role") == "system":
  63. messages[0]["content"] += f"{content}\n{messages[0]['content']}"
  64. else:
  65. # Insert at the beginning
  66. messages.insert(0, {"role": "system", "content": content})
  67. return messages
  68. def openai_chat_message_template(model: str):
  69. return {
  70. "id": f"{model}-{str(uuid.uuid4())}",
  71. "created": int(time.time()),
  72. "model": model,
  73. "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
  74. }
  75. def openai_chat_chunk_message_template(model: str, message: str) -> dict:
  76. template = openai_chat_message_template(model)
  77. template["object"] = "chat.completion.chunk"
  78. template["choices"][0]["delta"] = {"content": message}
  79. return template
  80. def openai_chat_completion_message_template(model: str, message: str) -> dict:
  81. template = openai_chat_message_template(model)
  82. template["object"] = "chat.completion"
  83. template["choices"][0]["message"] = {"content": message, "role": "assistant"}
  84. template["choices"][0]["finish_reason"] = "stop"
  85. return template
  86. # inplace function: form_data is modified
  87. def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
  88. system = params.get("system", None)
  89. if not system:
  90. return form_data
  91. if user:
  92. template_params = {
  93. "user_name": user.name,
  94. "user_location": user.info.get("location") if user.info else None,
  95. }
  96. else:
  97. template_params = {}
  98. system = prompt_template(system, **template_params)
  99. form_data["messages"] = add_or_update_system_message(
  100. system, form_data.get("messages", [])
  101. )
  102. return form_data
  103. # inplace function: form_data is modified
  104. def apply_model_params_to_body(
  105. params: dict, form_data: dict, mappings: dict[str, Callable]
  106. ) -> dict:
  107. if not params:
  108. return form_data
  109. for key, cast_func in mappings.items():
  110. if (value := params.get(key)) is not None:
  111. form_data[key] = cast_func(value)
  112. return form_data
  113. # inplace function: form_data is modified
  114. def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
  115. mappings = {
  116. "temperature": float,
  117. "top_p": int,
  118. "max_tokens": int,
  119. "frequency_penalty": int,
  120. "seed": lambda x: x,
  121. "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
  122. }
  123. return apply_model_params_to_body(params, form_data, mappings)
  124. def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
  125. opts = [
  126. "temperature",
  127. "top_p",
  128. "seed",
  129. "mirostat",
  130. "mirostat_eta",
  131. "mirostat_tau",
  132. "num_ctx",
  133. "num_batch",
  134. "num_keep",
  135. "repeat_last_n",
  136. "tfs_z",
  137. "top_k",
  138. "min_p",
  139. "use_mmap",
  140. "use_mlock",
  141. "num_thread",
  142. "num_gpu",
  143. ]
  144. mappings = {i: lambda x: x for i in opts}
  145. form_data = apply_model_params_to_body(params, form_data, mappings)
  146. name_differences = {
  147. "max_tokens": "num_predict",
  148. "frequency_penalty": "repeat_penalty",
  149. }
  150. for key, value in name_differences.items():
  151. if (param := params.get(key, None)) is not None:
  152. form_data[value] = param
  153. return form_data
  154. def get_gravatar_url(email):
  155. # Trim leading and trailing whitespace from
  156. # an email address and force all characters
  157. # to lower case
  158. address = str(email).strip().lower()
  159. # Create a SHA256 hash of the final string
  160. hash_object = hashlib.sha256(address.encode())
  161. hash_hex = hash_object.hexdigest()
  162. # Grab the actual image URL
  163. return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
  164. def calculate_sha256(file):
  165. sha256 = hashlib.sha256()
  166. # Read the file in chunks to efficiently handle large files
  167. for chunk in iter(lambda: file.read(8192), b""):
  168. sha256.update(chunk)
  169. return sha256.hexdigest()
  170. def calculate_sha256_string(string):
  171. # Create a new SHA-256 hash object
  172. sha256_hash = hashlib.sha256()
  173. # Update the hash object with the bytes of the input string
  174. sha256_hash.update(string.encode("utf-8"))
  175. # Get the hexadecimal representation of the hash
  176. hashed_string = sha256_hash.hexdigest()
  177. return hashed_string
  178. def validate_email_format(email: str) -> bool:
  179. if email.endswith("@localhost"):
  180. return True
  181. return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
  182. def sanitize_filename(file_name):
  183. # Convert to lowercase
  184. lower_case_file_name = file_name.lower()
  185. # Remove special characters using regular expression
  186. sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
  187. # Replace spaces with dashes
  188. final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
  189. return final_file_name
  190. def extract_folders_after_data_docs(path):
  191. # Convert the path to a Path object if it's not already
  192. path = Path(path)
  193. # Extract parts of the path
  194. parts = path.parts
  195. # Find the index of '/data/docs' in the path
  196. try:
  197. index_data_docs = parts.index("data") + 1
  198. index_docs = parts.index("docs", index_data_docs) + 1
  199. except ValueError:
  200. return []
  201. # Exclude the filename and accumulate folder names
  202. tags = []
  203. folders = parts[index_docs:-1]
  204. for idx, _ in enumerate(folders):
  205. tags.append("/".join(folders[: idx + 1]))
  206. return tags
  207. def parse_duration(duration: str) -> Optional[timedelta]:
  208. if duration == "-1" or duration == "0":
  209. return None
  210. # Regular expression to find number and unit pairs
  211. pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
  212. matches = re.findall(pattern, duration)
  213. if not matches:
  214. raise ValueError("Invalid duration string")
  215. total_duration = timedelta()
  216. for number, _, unit in matches:
  217. number = float(number)
  218. if unit == "ms":
  219. total_duration += timedelta(milliseconds=number)
  220. elif unit == "s":
  221. total_duration += timedelta(seconds=number)
  222. elif unit == "m":
  223. total_duration += timedelta(minutes=number)
  224. elif unit == "h":
  225. total_duration += timedelta(hours=number)
  226. elif unit == "d":
  227. total_duration += timedelta(days=number)
  228. elif unit == "w":
  229. total_duration += timedelta(weeks=number)
  230. return total_duration
  231. def parse_ollama_modelfile(model_text):
  232. parameters_meta = {
  233. "mirostat": int,
  234. "mirostat_eta": float,
  235. "mirostat_tau": float,
  236. "num_ctx": int,
  237. "repeat_last_n": int,
  238. "repeat_penalty": float,
  239. "temperature": float,
  240. "seed": int,
  241. "tfs_z": float,
  242. "num_predict": int,
  243. "top_k": int,
  244. "top_p": float,
  245. "num_keep": int,
  246. "typical_p": float,
  247. "presence_penalty": float,
  248. "frequency_penalty": float,
  249. "penalize_newline": bool,
  250. "numa": bool,
  251. "num_batch": int,
  252. "num_gpu": int,
  253. "main_gpu": int,
  254. "low_vram": bool,
  255. "f16_kv": bool,
  256. "vocab_only": bool,
  257. "use_mmap": bool,
  258. "use_mlock": bool,
  259. "num_thread": int,
  260. }
  261. data = {"base_model_id": None, "params": {}}
  262. # Parse base model
  263. base_model_match = re.search(
  264. r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
  265. )
  266. if base_model_match:
  267. data["base_model_id"] = base_model_match.group(1)
  268. # Parse template
  269. template_match = re.search(
  270. r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  271. )
  272. if template_match:
  273. data["params"] = {"template": template_match.group(1).strip()}
  274. # Parse stops
  275. stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
  276. if stops:
  277. data["params"]["stop"] = stops
  278. # Parse other parameters from the provided list
  279. for param, param_type in parameters_meta.items():
  280. param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
  281. if param_match:
  282. value = param_match.group(1)
  283. try:
  284. if param_type is int:
  285. value = int(value)
  286. elif param_type is float:
  287. value = float(value)
  288. elif param_type is bool:
  289. value = value.lower() == "true"
  290. except Exception as e:
  291. print(e)
  292. continue
  293. data["params"][param] = value
  294. # Parse adapter
  295. adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
  296. if adapter_match:
  297. data["params"]["adapter"] = adapter_match.group(1)
  298. # Parse system description
  299. system_desc_match = re.search(
  300. r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  301. )
  302. system_desc_match_single = re.search(
  303. r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
  304. )
  305. if system_desc_match:
  306. data["params"]["system"] = system_desc_match.group(1).strip()
  307. elif system_desc_match_single:
  308. data["params"]["system"] = system_desc_match_single.group(1).strip()
  309. # Parse messages
  310. messages = []
  311. message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
  312. for role, content in message_matches:
  313. messages.append({"role": role, "content": content})
  314. if messages:
  315. data["params"]["messages"] = messages
  316. return data