package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/model"
)
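
// LlamaServer is the interface to a running model runner subprocess. It
// covers health checks, streamed completions, embeddings, tokenization, and
// the memory estimates recorded when the model was loaded.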
type LlamaServer interface {
	Ping(ctx context.Context) error
	WaitUntilRunning(ctx context.Context) error
	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
	Embedding(ctx context.Context, input string) ([]float32, error)
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
	Close() error
	EstimatedVRAM() uint64 // Total VRAM across all GPUs
	EstimatedTotal() uint64
	EstimatedVRAMByGPU(gpuID string) uint64
}

// llmServer is an instance of the llama.cpp server
type llmServer struct {
	port        int
	cmd         *exec.Cmd
	done        chan error // Channel to signal when the process exits
	status      *StatusWriter
	options     api.Options
	numParallel int
	modelPath   string

	// llamaModel is an instance of the cgo llama.cpp model definition
	// nil if this server is running the new engine
	llamaModel     *llama.Model
	llamaModelLock sync.Mutex

	// textProcessor handles text encoding/decoding for the model in the Ollama engine
	// nil if this server is running the llama.cpp based engine
	textProcessor model.TextProcessor

	estimate    MemoryEstimate
	totalLayers uint64
	// gpuCount int
	gpus         discover.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
	loadDuration time.Duration        // Record how long it took the model to load
	loadProgress float32

	sem *semaphore.Weighted
}

// LoadModel will load a model from disk. The model must be in the GGML format.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, _, err := ggml.Decode(f, maxArraySize)
	return ggml, err
}

// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
	systemInfo := discover.GetSystemInfo()
	systemTotalMemory := systemInfo.System.TotalMemory
	systemFreeMemory := systemInfo.System.FreeMemory
	systemSwapFreeMemory := systemInfo.System.FreeSwap
	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
	if opts.NumGPU == 0 {
		gpus = discover.GetCPUInfo()
	}

	estimate := EstimateGPULayers(gpus, f, projectors, opts)
	if len(gpus) > 1 || gpus[0].Library != "cpu" {
		switch {
		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
			// disable partial offloading when model is greater than total system memory as this
			// can lead to locking up the system
			opts.NumGPU = 0
		case gpus[0].Library != "metal" && estimate.Layers == 0:
			// Don't bother loading into the GPU if no layers can fit
			gpus = discover.GetCPUInfo()
		case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
			opts.NumGPU = estimate.Layers
		}
	}

	// On linux and windows, over-allocating CPU memory will almost always result in an error
	// Darwin has fully dynamic swap so has no direct concept of free swap space
	if runtime.GOOS != "darwin" {
		systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
		available := systemFreeMemory + systemSwapFreeMemory
		if systemMemoryRequired > available {
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
		}
	}

	slog.Info("offload", "", estimate)

	params := []string{
		"--model", modelPath,
		"--ctx-size", strconv.Itoa(opts.NumCtx),
		"--batch-size", strconv.Itoa(opts.NumBatch),
	}

	if opts.NumGPU >= 0 {
		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
	}

	if envconfig.Debug() {
		params = append(params, "--verbose")
	}

	if opts.MainGPU > 0 {
		params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
	}

	if len(adapters) > 0 {
		for _, adapter := range adapters {
			params = append(params, "--lora", adapter)
		}
	}

	defaultThreads := systemInfo.GetOptimalThreadCount()
	if opts.NumThread > 0 {
		params = append(params, "--threads", strconv.Itoa(opts.NumThread))
	} else if defaultThreads > 0 {
		params = append(params, "--threads", strconv.Itoa(defaultThreads))
	}

	fa := envconfig.FlashAttention()
	if fa && !gpus.FlashAttentionSupported() {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}

	if fa && !f.SupportsFlashAttention() {
		slog.Warn("flash attention enabled but not supported by model")
		fa = false
	}

	kvct := strings.ToLower(envconfig.KvCacheType())

	if fa {
		slog.Info("enabling flash attention")
		params = append(params, "--flash-attn")

		// Flash Attention also supports kv cache quantization
		// Enable it if requested and the kv cache type is supported by the model
		if kvct != "" && f.SupportsKVCacheType(kvct) {
			params = append(params, "--kv-cache-type", kvct)
		} else {
			slog.Warn("kv cache type not supported by model", "type", kvct)
		}
	} else if kvct != "" && kvct != "f16" {
		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
	}

	// mmap has issues with partial offloading on metal
	for _, g := range gpus {
		if g.Library == "metal" &&
			uint64(opts.NumGPU) > 0 &&
			uint64(opts.NumGPU) < f.KV().BlockCount()+1 {
			opts.UseMMap = new(bool)
			*opts.UseMMap = false
		}
	}

	// Windows CUDA should not use mmap for best performance
	// On Linux, when the model is larger than free space, mmap leads to thrashing
	// For CPU loads we want the memory to be allocated, not FS cache
	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
		(opts.UseMMap != nil && !*opts.UseMMap) {
		params = append(params, "--no-mmap")
	}

	if opts.UseMLock {
		params = append(params, "--mlock")
	}

	// TODO - NUMA support currently doesn't work properly

	params = append(params, "--parallel", strconv.Itoa(numParallel))

	if estimate.TensorSplit != "" {
		params = append(params, "--tensor-split", estimate.TensorSplit)
	}

	if envconfig.MultiUserCache() {
		params = append(params, "--multiuser-cache")
	}

	libs := make(map[string]string)
	if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil {
		for _, entry := range entries {
			libs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name())
		}
	}

	lib := gpus[0].RunnerName()
	requested := envconfig.LLMLibrary()
	if libs[requested] != "" {
		slog.Info("using requested gpu library", "requested", requested)
		lib = requested
	}

	var compatible []string
	for k := range libs {
		// exact match first
		if k == lib {
			compatible = append([]string{k}, compatible...)
			continue
		}

		// then match the family (e.g. 'cuda')
		if strings.Split(k, "_")[0] == strings.Split(lib, "_")[0] {
			compatible = append(compatible, k)
		}
	}
	slog.Debug("compatible gpu libraries", "compatible", compatible)

	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
	}

	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}

	var llamaModel *llama.Model
	var textProcessor model.TextProcessor
	if envconfig.NewEngine() {
		textProcessor, err = model.NewTextProcessor(modelPath)
		if err != nil {
			// To prepare for opt-out mode, instead of treating this as an error, we fall back to the old runner
			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
		}
	}
	if textProcessor == nil {
		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
		if err != nil {
			return nil, err
		}
	}

	if len(projectors) > 0 && llamaModel != nil {
		params = append(params, "--mmproj", projectors[0])
	}

	// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
	// without any LD_LIBRARY_PATH flags
	for {
		port := 0
		if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
			var l *net.TCPListener
			if l, err = net.ListenTCP("tcp", a); err == nil {
				port = l.Addr().(*net.TCPAddr).Port
				l.Close()
			}
		}
		if port == 0 {
			slog.Debug("ResolveTCPAddr failed, using random port")
			port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		}

		finalParams := []string{"runner"}
		if textProcessor != nil {
			// New engine
			// TODO - if we have failure to load scenarios, add logic to retry with the old runner
			finalParams = append(finalParams, "--ollama-engine")
		}
		finalParams = append(finalParams, params...)
		finalParams = append(finalParams, "--port", strconv.Itoa(port))

		var pathEnv string
		switch runtime.GOOS {
		case "windows":
			pathEnv = "PATH"
		case "darwin":
			pathEnv = "DYLD_LIBRARY_PATH"
		default:
			pathEnv = "LD_LIBRARY_PATH"
		}

		var libraryPaths []string
		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
		}

		if len(compatible) > 0 {
			c := compatible[0]
			if libpath, ok := libs[c]; ok {
				slog.Debug("adding gpu library", "path", libpath)
				libraryPaths = append(libraryPaths, libpath)
			}
		}

		// Note: we always put the dependency path first
		// since this was the exact version we compiled/linked against
		if gpus[0].DependencyPath != nil {
			slog.Debug("adding gpu dependency paths", "paths", gpus[0].DependencyPath)
			// assume gpus from the same library have the same dependency path
			libraryPaths = append(gpus[0].DependencyPath, libraryPaths...)
		}

		// finally, add the root library path
		libraryPaths = append(libraryPaths, discover.LibOllamaPath)

		s := &llmServer{
			port:          port,
			cmd:           exec.Command(exe, finalParams...),
			status:        NewStatusWriter(os.Stderr),
			options:       opts,
			modelPath:     modelPath,
			llamaModel:    llamaModel,
			textProcessor: textProcessor,
			estimate:      estimate,
			numParallel:   numParallel,
			sem:           semaphore.NewWeighted(int64(numParallel)),
			totalLayers:   f.KV().BlockCount() + 1,
			gpus:          gpus,
			done:          make(chan error, 1),
		}

		s.cmd.Env = os.Environ()
		s.cmd.Stdout = os.Stdout
		s.cmd.Stderr = s.status
		s.cmd.SysProcAttr = LlamaServerSysProcAttr

		envWorkarounds := [][2]string{}
		for _, gpu := range gpus {
			envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
		}
		visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add the path and visible devices variable with our adjusted version
		pathNeeded := true
		devicesNeeded := visibleDevicesEnv != ""
		for i := range s.cmd.Env {
			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
			if strings.EqualFold(cmp[0], pathEnv) {
				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
				pathNeeded = false
			} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
				s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
				devicesNeeded = false
			} else if len(envWorkarounds) != 0 {
				for _, kv := range envWorkarounds {
					if strings.EqualFold(cmp[0], kv[0]) {
						s.cmd.Env[i] = kv[0] + "=" + kv[1]
					}
				}
			}
		}
		if pathNeeded {
			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
		}
		if devicesNeeded {
			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
		}

		slog.Info("starting llama server", "cmd", s.cmd.String())
		if envconfig.Debug() {
			filteredEnv := []string{}
			for _, ev := range s.cmd.Env {
				if strings.HasPrefix(ev, "CUDA_") ||
					strings.HasPrefix(ev, "ROCR_") ||
					strings.HasPrefix(ev, "ROCM_") ||
					strings.HasPrefix(ev, "HIP_") ||
					strings.HasPrefix(ev, "GPU_") ||
					strings.HasPrefix(ev, "HSA_") ||
					strings.HasPrefix(ev, "GGML_") ||
					strings.HasPrefix(ev, "PATH=") ||
					strings.HasPrefix(ev, "LD_LIBRARY_PATH=") ||
					strings.HasPrefix(ev, "DYLD_LIBRARY_PATH=") {
					filteredEnv = append(filteredEnv, ev)
				}
			}
			// Log at debug as the environment is inherited and might contain sensitive information
			slog.Debug("subprocess", "environment", filteredEnv)
		}

		if err = s.cmd.Start(); err != nil {
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			err := fmt.Errorf("error starting runner: %v %s", err, msg)
			if len(compatible) == 0 {
				if llamaModel != nil {
					llama.FreeModel(llamaModel)
				}
				return nil, err
			}

			slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible)
			compatible = compatible[1:]
			continue
		}

		// reap subprocess when it exits
		go func() {
			err := s.cmd.Wait()
			// Favor a more detailed message over the process exit status
			if err != nil && s.status != nil && s.status.LastErrMsg != "" {
				slog.Error("llama runner terminated", "error", err)
				if strings.Contains(s.status.LastErrMsg, "unknown model") {
					s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
				}
				s.done <- errors.New(s.status.LastErrMsg)
			} else {
				s.done <- err
			}
		}()

		return s, nil
	}
}
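
// ServerStatus describes the coarse state reported by the runner's /health
// endpoint.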
type ServerStatus int

const ( // iota is reset to 0
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
	ServerStatusNotResponding
	ServerStatusError
)

func (s ServerStatus) ToString() string {
	switch s {
	case ServerStatusReady:
		return "llm server ready"
	case ServerStatusNoSlotsAvailable:
		return "llm busy - no slots available"
	case ServerStatusLoadingModel:
		return "llm server loading model"
	case ServerStatusNotResponding:
		return "llm server not responding"
	default:
		return "llm server error"
	}
}
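
// ServerStatusResp is the JSON body returned by the runner's /health endpoint.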
type ServerStatusResp struct {
	Status          string  `json:"status"`
	SlotsIdle       int     `json:"slots_idle"`
	SlotsProcessing int     `json:"slots_processing"`
	Error           string  `json:"error"`
	Progress        float32 `json:"progress"`
}

func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
	// Fail fast if it's exited
	if s.cmd.ProcessState != nil {
		msg := ""
		if s.status != nil && s.status.LastErrMsg != "" {
			msg = s.status.LastErrMsg
		}
		if s.cmd.ProcessState.ExitCode() == -1 {
			// Most likely a signal killed it; log some more details to try to help troubleshoot
			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
		}
		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
	if err != nil {
		return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return ServerStatusNotResponding, errors.New("server not responding")
		}
		return ServerStatusError, fmt.Errorf("health resp: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ServerStatusError, fmt.Errorf("read health request: %w", err)
	}

	var status ServerStatusResp
	if err := json.Unmarshal(body, &status); err != nil {
		return ServerStatusError, fmt.Errorf("unmarshal health response: %w", err)
	}

	switch status.Status {
	case "ok":
		return ServerStatusReady, nil
	case "no slot available":
		return ServerStatusNoSlotsAvailable, nil
	case "loading model":
		s.loadProgress = status.Progress
		return ServerStatusLoadingModel, nil
	default:
		return ServerStatusError, fmt.Errorf("server error: %+v", status)
	}
}

// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
	var retries int
	for {
		status, err := s.getServerStatus(ctx)
		if err != nil {
			return status, err
		}

		if status == ServerStatusNoSlotsAvailable {
			if retries >= 10 {
				return status, fmt.Errorf("no slots available after %d retries", retries)
			}

			time.Sleep(5 * time.Millisecond)
			retries++
			continue
		}

		return status, nil
	}
}
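
// Ping checks that the runner process is alive and responding on its health
// endpoint.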
func (s *llmServer) Ping(ctx context.Context) error {
	_, err := s.getServerStatus(ctx)
	if err != nil {
		slog.Debug("server unhealthy", "error", err)
		return err
	}
	return nil
}
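
// WaitUntilRunning blocks until the runner reports ready, returning an error
// if the load stalls past the configured timeout, the process exits, or ctx
// is cancelled.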
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
	start := time.Now()
	stallDuration := envconfig.LoadTimeout()    // If no progress happens
	stallTimer := time.Now().Add(stallDuration) // give up if we stall

	slog.Info("waiting for llama runner to start responding")
	var lastStatus ServerStatus = -1
	fullyLoaded := false

	for {
		select {
		case <-ctx.Done():
			slog.Warn("client connection closed before server finished loading, aborting load")
			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
		case err := <-s.done:
			return fmt.Errorf("llama runner process has terminated: %w", err)
		default:
		}
		if time.Now().After(stallTimer) {
			// timeout
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
		}
		if s.cmd.ProcessState != nil {
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
		}
		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
		defer cancel()
		priorProgress := s.loadProgress
		status, _ := s.getServerStatus(ctx)
		if lastStatus != status && status != ServerStatusReady {
			// Only log on status changes
			slog.Info("waiting for server to become available", "status", status.ToString())
		}
		switch status {
		case ServerStatusReady:
			s.loadDuration = time.Since(start)
			slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", s.loadDuration.Seconds()))
			return nil
		default:
			lastStatus = status
			// Reset the timer as long as we're making forward progress on the load
			if priorProgress != s.loadProgress {
				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
				stallTimer = time.Now().Add(stallDuration)
			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
				stallTimer = time.Now().Add(stallDuration)
				fullyLoaded = true
			}
			time.Sleep(time.Millisecond * 250)
			continue
		}
	}
}
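
// grammarJSON is a llama.cpp grammar that constrains sampling to valid JSON.
// It is sent as the request grammar when the caller asks for format "json"
// on the llama.cpp based engine.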
var grammarJSON = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
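
// maxBufferSize caps the bufio.Scanner buffer used to read streamed
// completion responses from the runner.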
const maxBufferSize = 512 * format.KiloByte
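
// ImageData is an image attached to a completion request; it is forwarded to
// the runner as part of the request payload.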
type ImageData struct {
	Data          []byte `json:"data"`
	ID            int    `json:"id"`
	AspectRatioID int    `json:"aspect_ratio_id"`
}
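
// completion is a single streamed response line from the runner's /completion
// endpoint.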
type completion struct {
	Content      string `json:"content"`
	Model        string `json:"model"`
	Prompt       string `json:"prompt"`
	Stop         bool   `json:"stop"`
	StoppedLimit bool   `json:"stopped_limit"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}
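
// CompletionRequest describes a prompt to run along with its sampling
// options, attached images, and optional response format constraint.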
type CompletionRequest struct {
	Prompt  string
	Format  json.RawMessage
	Images  []ImageData
	Options *api.Options
}
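
// CompletionResponse is one chunk of generated output passed to the
// Completion callback; the final chunk has Done set and carries the timing
// counters.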
type CompletionResponse struct {
	Content            string
	DoneReason         string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}
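
// Completion streams a completion for req from the runner, calling fn for
// each chunk of generated content until the runner reports it has stopped.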
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
	request := map[string]any{
		"prompt":            req.Prompt,
		"stream":            true,
		"n_predict":         req.Options.NumPredict,
		"n_keep":            req.Options.NumKeep,
		"main_gpu":          req.Options.MainGPU,
		"temperature":       req.Options.Temperature,
		"top_k":             req.Options.TopK,
		"top_p":             req.Options.TopP,
		"min_p":             req.Options.MinP,
		"typical_p":         req.Options.TypicalP,
		"repeat_last_n":     req.Options.RepeatLastN,
		"repeat_penalty":    req.Options.RepeatPenalty,
		"presence_penalty":  req.Options.PresencePenalty,
		"frequency_penalty": req.Options.FrequencyPenalty,
		"mirostat":          req.Options.Mirostat,
		"mirostat_tau":      req.Options.MirostatTau,
		"mirostat_eta":      req.Options.MirostatEta,
		"seed":              req.Options.Seed,
		"stop":              req.Options.Stop,
		"image_data":        req.Images,
		"cache_prompt":      true,
	}

	if len(req.Format) > 0 {
		format := string(req.Format)
		if format != `null` && format != `""` {
			if s.textProcessor != nil {
				// New engine handles this on the backend
				request["format"] = req.Format
			} else {
				// old engine
				switch format {
				case `"json"`:
					request["grammar"] = grammarJSON
				default:
					if req.Format[0] != '{' {
						return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
					}

					// User provided a JSON schema
					g := llama.SchemaToGrammar(req.Format)
					if g == nil {
						return fmt.Errorf("invalid JSON schema in format")
					}
					request["grammar"] = string(g)
				}
			}
		}
	}

	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return err
	}
	defer s.sem.Release(1)

	// put an upper limit on num_predict to avoid the model running forever
	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
		req.Options.NumPredict = 10 * s.options.NumCtx
	}

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return err
	} else if status != ServerStatusReady {
		return fmt.Errorf("unexpected server status: %s", status.ToString())
	}

	// Handling JSON marshaling with special characters unescaped.
	buffer := &bytes.Buffer{}
	enc := json.NewEncoder(buffer)
	enc.SetEscapeHTML(false)

	if err := enc.Encode(request); err != nil {
		return fmt.Errorf("failed to marshal data: %v", err)
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	serverReq.Header.Set("Content-Type", "application/json")

	res, err := http.DefaultClient.Do(serverReq)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(res.Body)
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)

	// keep track of the last token generated; this is used to abort if the model starts looping
	var lastToken string
	var tokenRepeat int

	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Bytes()
			if len(line) == 0 {
				continue
			}

			// slog.Debug("got line", "line", string(line))
			evt, ok := bytes.CutPrefix(line, []byte("data: "))
			if !ok {
				evt = line
			}

			var c completion
			if err := json.Unmarshal(evt, &c); err != nil {
				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
			}

			switch {
			case strings.TrimSpace(c.Content) == lastToken:
				tokenRepeat++
			default:
				lastToken = strings.TrimSpace(c.Content)
				tokenRepeat = 0
			}

			// 30 picked as an arbitrary max token repeat limit, modify as needed
			if tokenRepeat > 30 {
				slog.Debug("prediction aborted, token repeat limit reached")
				return ctx.Err()
			}

			if c.Content != "" {
				fn(CompletionResponse{
					Content: c.Content,
				})
			}

			if c.Stop {
				doneReason := "stop"
				if c.StoppedLimit {
					doneReason = "length"
				}

				fn(CompletionResponse{
					Done:               true,
					DoneReason:         doneReason,
					PromptEvalCount:    c.Timings.PromptN,
					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
					EvalCount:          c.Timings.PredictedN,
					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
				})
				return nil
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
			s.Close()
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			} else {
				msg = err.Error()
			}
			return fmt.Errorf("an error was encountered while running the model: %s", msg)
		}

		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}
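
// EmbeddingRequest is the JSON body sent to the runner's /embedding endpoint.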
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}
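
// Embedding requests an embedding vector for input from the runner.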
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return nil, err
	}
	defer s.sem.Release(1)

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return nil, err
	} else if status != ServerStatusReady {
		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
	}

	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	r.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		return nil, fmt.Errorf("do embedding request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var e EmbeddingResponse
	if err := json.Unmarshal(body, &e); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return e.Embedding, nil
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}
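
// Tokenize converts content into token IDs using the in-process tokenizer:
// the llama.cpp model when loaded, otherwise the Ollama engine's text
// processor.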
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		return s.llamaModel.Tokenize(content, false, true)
	}
	if s.textProcessor != nil {
		tokens, err := s.textProcessor.Encode(content, false)
		if err != nil {
			return nil, err
		}
		toks := make([]int, len(tokens))
		for i, t := range tokens {
			toks[i] = int(t)
		}
		return toks, nil
	}

	// not reached
	return nil, fmt.Errorf("no tokenizer configured")
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}
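
// Detokenize converts token IDs back into text using the in-process
// tokenizer, mirroring Tokenize.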
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		var resp string
		for _, token := range tokens {
			resp += s.llamaModel.TokenToPiece(token)
		}
		return resp, nil
	}
	if s.textProcessor != nil {
		toks := make([]int32, len(tokens))
		for i, t := range tokens {
			toks[i] = int32(t)
		}
		content, err := s.textProcessor.Decode(toks)
		if err != nil {
			return "", err
		}
		return content, nil
	}

	// not reached
	return "", fmt.Errorf("no tokenizer configured")
}
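
// Close frees the in-process llama model, if any, and stops the runner
// subprocess, waiting for it to exit.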
func (s *llmServer) Close() error {
	s.llamaModelLock.Lock()
	if s.llamaModel != nil {
		llama.FreeModel(s.llamaModel)
		s.llamaModel = nil
	}
	s.llamaModelLock.Unlock()

	if s.cmd != nil {
		slog.Debug("stopping llama server")
		if err := s.cmd.Process.Kill(); err != nil {
			return err
		}
		// if ProcessState is already populated, Wait already completed, no need to wait again
		if s.cmd.ProcessState == nil {
			slog.Debug("waiting for llama server to exit")
			<-s.done
		}

		slog.Debug("llama server stopped")
	}

	return nil
}
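
// EstimatedVRAM reports the estimated VRAM the loaded model uses across all
// GPUs.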
func (s *llmServer) EstimatedVRAM() uint64 {
	return s.estimate.VRAMSize
}
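
// EstimatedTotal reports the estimated total memory (VRAM plus system memory)
// required by the loaded model.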
func (s *llmServer) EstimatedTotal() uint64 {
	return s.estimate.TotalSize
}
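
// EstimatedVRAMByGPU reports the estimated VRAM attributed to the GPU with
// the given ID, or 0 if that GPU is not in use.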
func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
	for i, gpu := range s.gpus {
		if gpu.ID == gpuID {
			if i < len(s.estimate.GPUSizes) {
				return s.estimate.GPUSizes[i]
			}
		}
	}
	return 0
}
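
// parseDurationMs converts a millisecond count reported by the runner into a
// time.Duration.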
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		panic(err)
	}

	return dur
}