server.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/model"
)

type LlamaServer interface {
	Ping(ctx context.Context) error
	WaitUntilRunning(ctx context.Context) error
	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
	Embedding(ctx context.Context, input string) ([]float32, error)
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
	Close() error
	EstimatedVRAM() uint64 // Total VRAM across all GPUs
	EstimatedTotal() uint64
	EstimatedVRAMByGPU(gpuID string) uint64
}

// llmServer is an instance of the llama.cpp server
type llmServer struct {
	port        int
	cmd         *exec.Cmd
	done        chan error // Channel to signal when the process exits
	status      *StatusWriter
	options     api.Options
	numParallel int
	modelPath   string

	// llamaModel is an instance of the cgo llama.cpp model definition
	// nil if this server is running the new engine
	llamaModel     *llama.Model
	llamaModelLock sync.Mutex

	// textProcessor handles text encoding/decoding for the model in the Ollama engine
	// nil if this server is running the llama.cpp based engine
	textProcessor model.TextProcessor

	estimate    MemoryEstimate
	totalLayers uint64
	// gpuCount  int
	gpus         discover.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
	loadDuration time.Duration        // Record how long it took the model to load
	loadProgress float32

	sem *semaphore.Weighted
}

// LoadModel will load a model from disk. The model must be in the GGML format.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, _, err := ggml.Decode(f, maxArraySize)
	return ggml, err
}

// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
	systemInfo := discover.GetSystemInfo()
	systemTotalMemory := systemInfo.System.TotalMemory
	systemFreeMemory := systemInfo.System.FreeMemory
	systemSwapFreeMemory := systemInfo.System.FreeSwap
	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
	if opts.NumGPU == 0 {
		gpus = discover.GetCPUInfo()
	}

	estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
	if len(gpus) > 1 || gpus[0].Library != "cpu" {
		switch {
		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
			// disable partial offloading when model is greater than total system memory as this
			// can lead to locking up the system
			opts.NumGPU = 0
		case gpus[0].Library != "metal" && estimate.Layers == 0:
			// Don't bother loading into the GPU if no layers can fit
			gpus = discover.GetCPUInfo()
		case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
			opts.NumGPU = estimate.Layers
		}
	}

	// On Linux and Windows, over-allocating CPU memory will almost always result in an error
	// Darwin has fully dynamic swap so has no direct concept of free swap space
	if runtime.GOOS != "darwin" {
		systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
		available := systemFreeMemory + systemSwapFreeMemory
		if systemMemoryRequired > available {
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
		}
	}

	slog.Info("offload", "", estimate)

	params := []string{
		"--model", modelPath,
		"--ctx-size", strconv.Itoa(opts.NumCtx),
		"--batch-size", strconv.Itoa(opts.NumBatch),
	}

	if opts.NumGPU >= 0 {
		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
	}

	if envconfig.Debug() {
		params = append(params, "--verbose")
	}

	if opts.MainGPU > 0 {
		params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
	}

	if len(adapters) > 0 {
		for _, adapter := range adapters {
			params = append(params, "--lora", adapter)
		}
	}

	defaultThreads := systemInfo.GetOptimalThreadCount()
	if opts.NumThread > 0 {
		params = append(params, "--threads", strconv.Itoa(opts.NumThread))
	} else if defaultThreads > 0 {
		params = append(params, "--threads", strconv.Itoa(defaultThreads))
	}

	fa := envconfig.FlashAttention()
	if fa && !gpus.FlashAttentionSupported() {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}

	if fa && !f.SupportsFlashAttention() {
		slog.Warn("flash attention enabled but not supported by model")
		fa = false
	}

	kvct := strings.ToLower(envconfig.KvCacheType())

	if fa {
		slog.Info("enabling flash attention")
		params = append(params, "--flash-attn")

		// Flash Attention also supports kv cache quantization
		// Enable it if requested and the kv cache type is supported by the model
		if kvct != "" && f.SupportsKVCacheType(kvct) {
			params = append(params, "--kv-cache-type", kvct)
		} else {
			slog.Warn("kv cache type not supported by model", "type", kvct)
		}
	} else if kvct != "" && kvct != "f16" {
		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
	}

	// mmap has issues with partial offloading on metal
	for _, g := range gpus {
		if g.Library == "metal" &&
			uint64(opts.NumGPU) > 0 &&
			uint64(opts.NumGPU) < f.KV().BlockCount()+1 {
			opts.UseMMap = new(bool)
			*opts.UseMMap = false
		}
	}

	// Windows CUDA should not use mmap for best performance
	// On Linux with a model larger than free memory, mmap leads to thrashing
	// For CPU loads we want the memory to be allocated, not FS cache
	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
		(opts.UseMMap != nil && !*opts.UseMMap) {
		params = append(params, "--no-mmap")
	}

	if opts.UseMLock {
		params = append(params, "--mlock")
	}

	// TODO - NUMA support currently doesn't work properly

	params = append(params, "--parallel", strconv.Itoa(numParallel))

	if estimate.TensorSplit != "" {
		params = append(params, "--tensor-split", estimate.TensorSplit)
	}

	if envconfig.MultiUserCache() {
		params = append(params, "--multiuser-cache")
	}

	libs := make(map[string]string)
	if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil {
		for _, entry := range entries {
			libs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name())
		}
	}

	lib := gpus[0].RunnerName()
	requested := envconfig.LLMLibrary()
	if libs[requested] != "" {
		slog.Info("using requested gpu library", "requested", requested)
		lib = requested
	}

	var compatible []string
	for k := range libs {
		// exact match first
		if k == lib {
			compatible = append([]string{k}, compatible...)
			continue
		}

		// then match the family (e.g. 'cuda')
		if strings.Split(k, "_")[0] == strings.Split(lib, "_")[0] {
			compatible = append(compatible, k)
		}
	}
	slog.Debug("compatible gpu libraries", "compatible", compatible)

	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
	}

	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}

	var llamaModel *llama.Model
	var textProcessor model.TextProcessor
	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
		textProcessor, err = model.NewTextProcessor(modelPath)
		if err != nil {
			// To prepare for opt-out mode, instead of treating this as an error, we fall back to the old runner
			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
		}
	}
	if textProcessor == nil {
		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
		if err != nil {
			return nil, err
		}
	}

	if len(projectors) > 0 && llamaModel != nil {
		params = append(params, "--mmproj", projectors[0])
	}

	// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
	// without any LD_LIBRARY_PATH flags
	for {
		port := 0
		if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
			var l *net.TCPListener
			if l, err = net.ListenTCP("tcp", a); err == nil {
				port = l.Addr().(*net.TCPAddr).Port
				l.Close()
			}
		}
		if port == 0 {
			slog.Debug("ResolveTCPAddr failed, using random port")
			port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		}

		finalParams := []string{"runner"}
		if textProcessor != nil {
			// New engine
			// TODO - if we have failure to load scenarios, add logic to retry with the old runner
			finalParams = append(finalParams, "--ollama-engine")
		}
		finalParams = append(finalParams, params...)
		finalParams = append(finalParams, "--port", strconv.Itoa(port))

		var pathEnv string
		switch runtime.GOOS {
		case "windows":
			pathEnv = "PATH"
		case "darwin":
			pathEnv = "DYLD_LIBRARY_PATH"
		default:
			pathEnv = "LD_LIBRARY_PATH"
		}

		var libraryPaths []string
		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
		}

		if len(compatible) > 0 {
			c := compatible[0]
			if libpath, ok := libs[c]; ok {
				slog.Debug("adding gpu library", "path", libpath)
				libraryPaths = append(libraryPaths, libpath)
			}
		}

		// Note: we always put the dependency path first
		// since this was the exact version we compiled/linked against
		if gpus[0].DependencyPath != nil {
			slog.Debug("adding gpu dependency paths", "paths", gpus[0].DependencyPath)
			// assume gpus from the same library have the same dependency path
			libraryPaths = append(gpus[0].DependencyPath, libraryPaths...)
		}

		// finally, add the root library path
		libraryPaths = append(libraryPaths, discover.LibOllamaPath)

		s := &llmServer{
			port:          port,
			cmd:           exec.Command(exe, finalParams...),
			status:        NewStatusWriter(os.Stderr),
			options:       opts,
			modelPath:     modelPath,
			llamaModel:    llamaModel,
			textProcessor: textProcessor,
			estimate:      estimate,
			numParallel:   numParallel,
			sem:           semaphore.NewWeighted(int64(numParallel)),
			totalLayers:   f.KV().BlockCount() + 1,
			gpus:          gpus,
			done:          make(chan error, 1),
		}

		s.cmd.Env = os.Environ()
		s.cmd.Stdout = os.Stdout
		s.cmd.Stderr = s.status
		s.cmd.SysProcAttr = LlamaServerSysProcAttr

		envWorkarounds := [][2]string{}
		for _, gpu := range gpus {
			envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
		}
		visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add the path and visible devices variable with our adjusted version
		pathNeeded := true
		devicesNeeded := visibleDevicesEnv != ""
		for i := range s.cmd.Env {
			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
			if strings.EqualFold(cmp[0], pathEnv) {
				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
				pathNeeded = false
			} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
				s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
				devicesNeeded = false
			} else if len(envWorkarounds) != 0 {
				for _, kv := range envWorkarounds {
					if strings.EqualFold(cmp[0], kv[0]) {
						s.cmd.Env[i] = kv[0] + "=" + kv[1]
					}
				}
			}
		}
		if pathNeeded {
			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
		}
		if devicesNeeded {
			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
		}

		slog.Info("starting llama server", "cmd", s.cmd)
		if envconfig.Debug() {
			filteredEnv := []string{}
			for _, ev := range s.cmd.Env {
				if strings.HasPrefix(ev, "CUDA_") ||
					strings.HasPrefix(ev, "ROCR_") ||
					strings.HasPrefix(ev, "ROCM_") ||
					strings.HasPrefix(ev, "HIP_") ||
					strings.HasPrefix(ev, "GPU_") ||
					strings.HasPrefix(ev, "HSA_") ||
					strings.HasPrefix(ev, "GGML_") ||
					strings.HasPrefix(ev, "PATH=") ||
					strings.HasPrefix(ev, "LD_LIBRARY_PATH=") ||
					strings.HasPrefix(ev, "DYLD_LIBRARY_PATH=") {
					filteredEnv = append(filteredEnv, ev)
				}
			}
			// Log at debug as the environment is inherited and might contain sensitive information
			slog.Debug("subprocess", "environment", filteredEnv)
		}

		if err = s.cmd.Start(); err != nil {
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			err := fmt.Errorf("error starting runner: %v %s", err, msg)
			if len(compatible) == 0 {
				if llamaModel != nil {
					llama.FreeModel(llamaModel)
				}
				return nil, err
			}

			slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible)
			compatible = compatible[1:]
			continue
		}

		// reap subprocess when it exits
		go func() {
			err := s.cmd.Wait()
			// Favor a more detailed message over the process exit status
			if err != nil && s.status != nil && s.status.LastErrMsg != "" {
				slog.Error("llama runner terminated", "error", err)
				if strings.Contains(s.status.LastErrMsg, "unknown model") {
					s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
				}
				s.done <- errors.New(s.status.LastErrMsg)
			} else {
				s.done <- err
			}
		}()

		return s, nil
	}
}

type ServerStatus int

const ( // iota is reset to 0
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
	ServerStatusNotResponding
	ServerStatusError
)

func (s ServerStatus) String() string {
	switch s {
	case ServerStatusReady:
		return "llm server ready"
	case ServerStatusNoSlotsAvailable:
		return "llm busy - no slots available"
	case ServerStatusLoadingModel:
		return "llm server loading model"
	case ServerStatusNotResponding:
		return "llm server not responding"
	default:
		return "llm server error"
	}
}

type ServerStatusResponse struct {
	Status   ServerStatus `json:"status"`
	Progress float32      `json:"progress"`
}

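// getServerStatus polls the runner's /health endpoint and maps the reply onto
// a ServerStatus, failing fast if the subprocess has already exited.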
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
	// Fail fast if it's exited
	if s.cmd.ProcessState != nil {
		msg := ""
		if s.status != nil && s.status.LastErrMsg != "" {
			msg = s.status.LastErrMsg
		}
		if s.cmd.ProcessState.ExitCode() == -1 {
			// Most likely a signal killed it; log some more details to try to help troubleshoot
			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
		}
		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
	if err != nil {
		return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return ServerStatusNotResponding, errors.New("server not responding")
		}
		return ServerStatusError, fmt.Errorf("health resp: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ServerStatusError, fmt.Errorf("read health request: %w", err)
	}

	var ssr ServerStatusResponse
	if err := json.Unmarshal(body, &ssr); err != nil {
		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
	}

	switch ssr.Status {
	case ServerStatusLoadingModel:
		s.loadProgress = ssr.Progress
		return ssr.Status, nil
	case ServerStatusReady, ServerStatusNoSlotsAvailable:
		return ssr.Status, nil
	default:
		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
	}
}

// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
	var retries int
	for {
		status, err := s.getServerStatus(ctx)
		if err != nil {
			return status, err
		}

		if status == ServerStatusNoSlotsAvailable {
			if retries >= 10 {
				return status, fmt.Errorf("no slots available after %d retries", retries)
			}

			time.Sleep(5 * time.Millisecond)
			retries++
			continue
		}

		return status, nil
	}
}

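// Ping performs a single health check against the runner and returns any
// error from the probe to the caller.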
func (s *llmServer) Ping(ctx context.Context) error {
	_, err := s.getServerStatus(ctx)
	if err != nil {
		slog.Debug("server unhealthy", "error", err)
		return err
	}
	return nil
}

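// WaitUntilRunning blocks until the runner reports ServerStatusReady. It polls
// the health endpoint roughly every 250ms and resets its stall timer whenever
// load progress advances, so a slow but progressing load is not treated as a
// timeout.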
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
	start := time.Now()
	stallDuration := envconfig.LoadTimeout()    // If no progress happens
	stallTimer := time.Now().Add(stallDuration) // give up if we stall

	slog.Info("waiting for llama runner to start responding")
	var lastStatus ServerStatus = -1
	fullyLoaded := false

	for {
		select {
		case <-ctx.Done():
			slog.Warn("client connection closed before server finished loading, aborting load")
			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
		case err := <-s.done:
			return fmt.Errorf("llama runner process has terminated: %w", err)
		default:
		}
		if time.Now().After(stallTimer) {
			// timeout
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
		}
		if s.cmd.ProcessState != nil {
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
		}
		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
		defer cancel()
		priorProgress := s.loadProgress
		status, _ := s.getServerStatus(ctx)
		if lastStatus != status && status != ServerStatusReady {
			// Only log on status changes
			slog.Info("waiting for server to become available", "status", status)
		}
		switch status {
		case ServerStatusReady:
			s.loadDuration = time.Since(start)
			slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", s.loadDuration.Seconds()))
			return nil
		default:
			lastStatus = status
			// Reset the timer as long as we're making forward progress on the load
			if priorProgress != s.loadProgress {
				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
				stallTimer = time.Now().Add(stallDuration)
			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
				slog.Debug("model load completed, waiting for server to become available", "status", status)
				stallTimer = time.Now().Add(stallDuration)
				fullyLoaded = true
			}
			time.Sleep(time.Millisecond * 250)
			continue
		}
	}
}

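// grammarJSON is a llama.cpp grammar that constrains runner output to valid
// JSON objects; it is used as the request grammar when Format is "json".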
var grammarJSON = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws
array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws
string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`

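// maxBufferSize caps the bufio.Scanner buffer used to read streamed
// completion lines from the runner (see Completion below).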
const maxBufferSize = 512 * format.KiloByte

type ImageData struct {
	Data          []byte `json:"data"`
	ID            int    `json:"id"`
	AspectRatioID int    `json:"aspect_ratio_id"`
}

type CompletionRequest struct {
	Prompt  string
	Format  json.RawMessage
	Images  []ImageData
	Options *api.Options

	Grammar string // set before sending the request to the subprocess
}

type CompletionResponse struct {
	Content            string        `json:"content"`
	DoneReason         string        `json:"done_reason"`
	Done               bool          `json:"done"`
	PromptEvalCount    int           `json:"prompt_eval_count"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
	EvalCount          int           `json:"eval_count"`
	EvalDuration       time.Duration `json:"eval_duration"`
}

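// Completion streams a completion from the runner. It translates the Format
// field into a grammar, limits concurrency with the server's semaphore, POSTs
// the request to the subprocess, and invokes fn for each streamed response
// line until the final response arrives with Done set.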
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
	if len(req.Format) > 0 {
		switch string(req.Format) {
		case `null`, `""`:
			// Field was set, but "missing" a value. We accept
			// these as "not set".
			break
		case `"json"`:
			req.Grammar = grammarJSON
		default:
			if req.Format[0] != '{' {
				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
			}

			// User provided a JSON schema
			g := llama.SchemaToGrammar(req.Format)
			if g == nil {
				return fmt.Errorf("invalid JSON schema in format")
			}
			req.Grammar = string(g)
		}
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return err
	}
	defer s.sem.Release(1)

	// put an upper limit on num_predict to avoid the model running forever
	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
		req.Options.NumPredict = 10 * s.options.NumCtx
	}

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return err
	} else if status != ServerStatusReady {
		return fmt.Errorf("unexpected server status: %s", status)
	}

	// Handling JSON marshaling with special characters unescaped.
	buffer := &bytes.Buffer{}
	enc := json.NewEncoder(buffer)
	enc.SetEscapeHTML(false)

	if err := enc.Encode(req); err != nil {
		return fmt.Errorf("failed to marshal data: %v", err)
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	serverReq.Header.Set("Content-Type", "application/json")

	res, err := http.DefaultClient.Do(serverReq)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(res.Body)
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)

	// keep track of the last token generated; this is used to abort if the model starts looping
	var lastToken string
	var tokenRepeat int

	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Bytes()
			if len(line) == 0 {
				continue
			}

			// slog.Debug("got line", "line", string(line))
			evt, ok := bytes.CutPrefix(line, []byte("data: "))
			if !ok {
				evt = line
			}

			var c CompletionResponse
			if err := json.Unmarshal(evt, &c); err != nil {
				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
			}
			switch {
			case strings.TrimSpace(c.Content) == lastToken:
				tokenRepeat++
			default:
				lastToken = strings.TrimSpace(c.Content)
				tokenRepeat = 0
			}

			// 30 picked as an arbitrary max token repeat limit, modify as needed
			if tokenRepeat > 30 {
				slog.Debug("prediction aborted, token repeat limit reached")
				return ctx.Err()
			}

			if c.Content != "" {
				fn(CompletionResponse{
					Content: c.Content,
				})
			}

			if c.Done {
				fn(c)
				return nil
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
			s.Close()
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			} else {
				msg = err.Error()
			}
			return fmt.Errorf("an error was encountered while running the model: %s", msg)
		}

		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

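// Embedding requests an embedding for input from the runner's /embedding
// endpoint, serializing access through the same semaphore used for
// completions.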
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return nil, err
	}
	defer s.sem.Release(1)

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return nil, err
	} else if status != ServerStatusReady {
		return nil, fmt.Errorf("unexpected server status: %s", status)
	}

	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	r.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		return nil, fmt.Errorf("do embedding request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var e EmbeddingResponse
	if err := json.Unmarshal(body, &e); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return e.Embedding, nil
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

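// Tokenize converts content into token IDs using whichever tokenizer this
// server was constructed with: the cgo llama.cpp model or the Ollama engine's
// text processor.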
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		return s.llamaModel.Tokenize(content, false, true)
	}
	if s.textProcessor != nil {
		tokens, err := s.textProcessor.Encode(content, false)
		if err != nil {
			return nil, err
		}
		toks := make([]int, len(tokens))
		for i, t := range tokens {
			toks[i] = int(t)
		}
		return toks, nil
	}
	// not reached
	return nil, fmt.Errorf("no tokenizer configured")
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

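// Detokenize converts token IDs back into text, mirroring Tokenize's choice
// of tokenizer.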
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		var resp string
		for _, token := range tokens {
			resp += s.llamaModel.TokenToPiece(token)
		}
		return resp, nil
	}
	if s.textProcessor != nil {
		toks := make([]int32, len(tokens))
		for i, t := range tokens {
			toks[i] = int32(t)
		}
		content, err := s.textProcessor.Decode(toks)
		if err != nil {
			return "", err
		}
		return content, nil
	}
	// not reached
	return "", fmt.Errorf("no tokenizer configured")
}

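// Close frees the cgo llama.cpp model (if any) and kills the runner
// subprocess, waiting for the exit to be reaped if it has not completed yet.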
func (s *llmServer) Close() error {
	s.llamaModelLock.Lock()
	if s.llamaModel != nil {
		llama.FreeModel(s.llamaModel)
		s.llamaModel = nil
	}
	s.llamaModelLock.Unlock()

	if s.cmd != nil {
		slog.Debug("stopping llama server")
		if err := s.cmd.Process.Kill(); err != nil {
			return err
		}
		// if ProcessState is already populated, Wait already completed, no need to wait again
		if s.cmd.ProcessState == nil {
			slog.Debug("waiting for llama server to exit")
			<-s.done
		}

		slog.Debug("llama server stopped")
	}

	return nil
}

func (s *llmServer) EstimatedVRAM() uint64 {
	return s.estimate.VRAMSize
}

func (s *llmServer) EstimatedTotal() uint64 {
	return s.estimate.TotalSize
}

func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
	for i, gpu := range s.gpus {
		if gpu.ID == gpuID {
			if i < len(s.estimate.GPUSizes) {
				return s.estimate.GPUSizes[i]
			}
		}
	}
	return 0
}