misc.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from pathlib import Path
  2. import hashlib
  3. import json
  4. import re
  5. from datetime import timedelta
  6. from typing import Optional, List, Tuple
  7. import uuid
  8. import time
  9. def get_last_user_message_item(messages: List[dict]) -> str:
  10. for message in reversed(messages):
  11. if message["role"] == "user":
  12. return message
  13. return None
  14. def get_last_user_message(messages: List[dict]) -> str:
  15. message = get_last_user_message_item(messages)
  16. if message is not None:
  17. if isinstance(message["content"], list):
  18. for item in message["content"]:
  19. if item["type"] == "text":
  20. return item["text"]
  21. return message["content"]
  22. return None
  23. def get_last_assistant_message(messages: List[dict]) -> str:
  24. for message in reversed(messages):
  25. if message["role"] == "assistant":
  26. if isinstance(message["content"], list):
  27. for item in message["content"]:
  28. if item["type"] == "text":
  29. return item["text"]
  30. return message["content"]
  31. return None
  32. def get_system_message(messages: List[dict]) -> dict:
  33. for message in messages:
  34. if message["role"] == "system":
  35. return message
  36. return None
  37. def remove_system_message(messages: List[dict]) -> List[dict]:
  38. return [message for message in messages if message["role"] != "system"]
  39. def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
  40. return get_system_message(messages), remove_system_message(messages)
  41. def add_or_update_system_message(content: str, messages: List[dict]):
  42. """
  43. Adds a new system message at the beginning of the messages list
  44. or updates the existing system message at the beginning.
  45. :param msg: The message to be added or appended.
  46. :param messages: The list of message dictionaries.
  47. :return: The updated list of message dictionaries.
  48. """
  49. if messages and messages[0].get("role") == "system":
  50. messages[0]["content"] += f"{content}\n{messages[0]['content']}"
  51. else:
  52. # Insert at the beginning
  53. messages.insert(0, {"role": "system", "content": content})
  54. return messages
  55. def stream_message_template(model: str, message: str):
  56. return {
  57. "id": f"{model}-{str(uuid.uuid4())}",
  58. "object": "chat.completion.chunk",
  59. "created": int(time.time()),
  60. "model": model,
  61. "choices": [
  62. {
  63. "index": 0,
  64. "delta": {"content": message},
  65. "logprobs": None,
  66. "finish_reason": None,
  67. }
  68. ],
  69. }
  70. def get_gravatar_url(email):
  71. # Trim leading and trailing whitespace from
  72. # an email address and force all characters
  73. # to lower case
  74. address = str(email).strip().lower()
  75. # Create a SHA256 hash of the final string
  76. hash_object = hashlib.sha256(address.encode())
  77. hash_hex = hash_object.hexdigest()
  78. # Grab the actual image URL
  79. return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
  80. def calculate_sha256(file):
  81. sha256 = hashlib.sha256()
  82. # Read the file in chunks to efficiently handle large files
  83. for chunk in iter(lambda: file.read(8192), b""):
  84. sha256.update(chunk)
  85. return sha256.hexdigest()
  86. def calculate_sha256_string(string):
  87. # Create a new SHA-256 hash object
  88. sha256_hash = hashlib.sha256()
  89. # Update the hash object with the bytes of the input string
  90. sha256_hash.update(string.encode("utf-8"))
  91. # Get the hexadecimal representation of the hash
  92. hashed_string = sha256_hash.hexdigest()
  93. return hashed_string
  94. def validate_email_format(email: str) -> bool:
  95. if email.endswith("@localhost"):
  96. return True
  97. return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
  98. def sanitize_filename(file_name):
  99. # Convert to lowercase
  100. lower_case_file_name = file_name.lower()
  101. # Remove special characters using regular expression
  102. sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)
  103. # Replace spaces with dashes
  104. final_file_name = re.sub(r"\s+", "-", sanitized_file_name)
  105. return final_file_name
  106. def extract_folders_after_data_docs(path):
  107. # Convert the path to a Path object if it's not already
  108. path = Path(path)
  109. # Extract parts of the path
  110. parts = path.parts
  111. # Find the index of '/data/docs' in the path
  112. try:
  113. index_data_docs = parts.index("data") + 1
  114. index_docs = parts.index("docs", index_data_docs) + 1
  115. except ValueError:
  116. return []
  117. # Exclude the filename and accumulate folder names
  118. tags = []
  119. folders = parts[index_docs:-1]
  120. for idx, part in enumerate(folders):
  121. tags.append("/".join(folders[: idx + 1]))
  122. return tags
  123. def parse_duration(duration: str) -> Optional[timedelta]:
  124. if duration == "-1" or duration == "0":
  125. return None
  126. # Regular expression to find number and unit pairs
  127. pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
  128. matches = re.findall(pattern, duration)
  129. if not matches:
  130. raise ValueError("Invalid duration string")
  131. total_duration = timedelta()
  132. for number, _, unit in matches:
  133. number = float(number)
  134. if unit == "ms":
  135. total_duration += timedelta(milliseconds=number)
  136. elif unit == "s":
  137. total_duration += timedelta(seconds=number)
  138. elif unit == "m":
  139. total_duration += timedelta(minutes=number)
  140. elif unit == "h":
  141. total_duration += timedelta(hours=number)
  142. elif unit == "d":
  143. total_duration += timedelta(days=number)
  144. elif unit == "w":
  145. total_duration += timedelta(weeks=number)
  146. return total_duration
  147. def parse_ollama_modelfile(model_text):
  148. parameters_meta = {
  149. "mirostat": int,
  150. "mirostat_eta": float,
  151. "mirostat_tau": float,
  152. "num_ctx": int,
  153. "repeat_last_n": int,
  154. "repeat_penalty": float,
  155. "temperature": float,
  156. "seed": int,
  157. "tfs_z": float,
  158. "num_predict": int,
  159. "top_k": int,
  160. "top_p": float,
  161. "num_keep": int,
  162. "typical_p": float,
  163. "presence_penalty": float,
  164. "frequency_penalty": float,
  165. "penalize_newline": bool,
  166. "numa": bool,
  167. "num_batch": int,
  168. "num_gpu": int,
  169. "main_gpu": int,
  170. "low_vram": bool,
  171. "f16_kv": bool,
  172. "vocab_only": bool,
  173. "use_mmap": bool,
  174. "use_mlock": bool,
  175. "num_thread": int,
  176. }
  177. data = {"base_model_id": None, "params": {}}
  178. # Parse base model
  179. base_model_match = re.search(
  180. r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
  181. )
  182. if base_model_match:
  183. data["base_model_id"] = base_model_match.group(1)
  184. # Parse template
  185. template_match = re.search(
  186. r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  187. )
  188. if template_match:
  189. data["params"] = {"template": template_match.group(1).strip()}
  190. # Parse stops
  191. stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
  192. if stops:
  193. data["params"]["stop"] = stops
  194. # Parse other parameters from the provided list
  195. for param, param_type in parameters_meta.items():
  196. param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
  197. if param_match:
  198. value = param_match.group(1)
  199. try:
  200. if param_type == int:
  201. value = int(value)
  202. elif param_type == float:
  203. value = float(value)
  204. elif param_type == bool:
  205. value = value.lower() == "true"
  206. except Exception as e:
  207. print(e)
  208. continue
  209. data["params"][param] = value
  210. # Parse adapter
  211. adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
  212. if adapter_match:
  213. data["params"]["adapter"] = adapter_match.group(1)
  214. # Parse system description
  215. system_desc_match = re.search(
  216. r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
  217. )
  218. system_desc_match_single = re.search(
  219. r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
  220. )
  221. if system_desc_match:
  222. data["params"]["system"] = system_desc_match.group(1).strip()
  223. elif system_desc_match_single:
  224. data["params"]["system"] = system_desc_match_single.group(1).strip()
  225. # Parse messages
  226. messages = []
  227. message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
  228. for role, content in message_matches:
  229. messages.append({"role": role, "content": content})
  230. if messages:
  231. data["params"]["messages"] = messages
  232. return data