server.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/grammar"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/model"
)

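// LlamaServer is the interface the Ollama server uses to drive a model
// runner subprocess: health checks, completion and embedding requests,
// tokenization, and memory-usage estimates.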
type LlamaServer interface {
	Ping(ctx context.Context) error
	WaitUntilRunning(ctx context.Context) error
	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
	Embedding(ctx context.Context, input string) ([]float32, error)
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
	Close() error
	EstimatedVRAM() uint64 // Total VRAM across all GPUs
	EstimatedTotal() uint64
	EstimatedVRAMByGPU(gpuID string) uint64
}

// llmServer is an instance of the llama.cpp server
type llmServer struct {
	port        int
	cmd         *exec.Cmd
	done        chan error // Channel to signal when the process exits
	status      *StatusWriter
	options     api.Options
	numParallel int
	modelPath   string

	// llamaModel is an instance of the cgo llama.cpp model definition
	// nil if this server is running the new engine
	llamaModel     *llama.Model
	llamaModelLock sync.Mutex

	// textProcessor handles text encoding/decoding for the model in the Ollama engine
	// nil if this server is running the llama.cpp based engine
	textProcessor model.TextProcessor

	estimate    MemoryEstimate
	totalLayers uint64
	// gpuCount int
	gpus         discover.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
	loadDuration time.Duration        // Record how long it took the model to load
	loadProgress float32

	sem *semaphore.Weighted
}

// LoadModel will load a model from disk. The model must be in the GGML format.
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
// the maxArraySize is negative, all arrays are collected.
func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, _, err := ggml.Decode(f, maxArraySize)
	return ggml, err
}

// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
	systemInfo := discover.GetSystemInfo()
	systemTotalMemory := systemInfo.System.TotalMemory
	systemFreeMemory := systemInfo.System.FreeMemory
	systemSwapFreeMemory := systemInfo.System.FreeSwap
	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
	if opts.NumGPU == 0 {
		gpus = discover.GetCPUInfo()
	}

	estimate := EstimateGPULayers(gpus, f, projectors, opts)
	if len(gpus) > 1 || gpus[0].Library != "cpu" {
		switch {
		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
			// disable partial offloading when model is greater than total system memory as this
			// can lead to locking up the system
			opts.NumGPU = 0
		case gpus[0].Library != "metal" && estimate.Layers == 0:
			// Don't bother loading into the GPU if no layers can fit
			gpus = discover.GetCPUInfo()
		case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
			opts.NumGPU = estimate.Layers
		}
	}

	// On linux and windows, over-allocating CPU memory will almost always result in an error
	// Darwin has fully dynamic swap so has no direct concept of free swap space
	if runtime.GOOS != "darwin" {
		systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
		available := systemFreeMemory + systemSwapFreeMemory
		if systemMemoryRequired > available {
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
		}
	}

	slog.Info("offload", "", estimate)

	params := []string{
		"--model", modelPath,
		"--ctx-size", strconv.Itoa(opts.NumCtx),
		"--batch-size", strconv.Itoa(opts.NumBatch),
	}

	if opts.NumGPU >= 0 {
		params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU))
	}

	if envconfig.Debug() {
		params = append(params, "--verbose")
	}

	if opts.MainGPU > 0 {
		params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU))
	}

	if len(adapters) > 0 {
		for _, adapter := range adapters {
			params = append(params, "--lora", adapter)
		}
	}

	defaultThreads := systemInfo.GetOptimalThreadCount()
	if opts.NumThread > 0 {
		params = append(params, "--threads", strconv.Itoa(opts.NumThread))
	} else if defaultThreads > 0 {
		params = append(params, "--threads", strconv.Itoa(defaultThreads))
	}

	fa := envconfig.FlashAttention()
	if fa && !gpus.FlashAttentionSupported() {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}

	if fa && !f.SupportsFlashAttention() {
		slog.Warn("flash attention enabled but not supported by model")
		fa = false
	}

	kvct := strings.ToLower(envconfig.KvCacheType())

	if fa {
		slog.Info("enabling flash attention")
		params = append(params, "--flash-attn")

		// Flash Attention also supports kv cache quantization
		// Enable if the requested and kv cache type is supported by the model
		if kvct != "" && f.SupportsKVCacheType(kvct) {
			params = append(params, "--kv-cache-type", kvct)
		} else {
			slog.Warn("kv cache type not supported by model", "type", kvct)
		}
	} else if kvct != "" && kvct != "f16" {
		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
	}

	// mmap has issues with partial offloading on metal
	for _, g := range gpus {
		if g.Library == "metal" &&
			uint64(opts.NumGPU) > 0 &&
			uint64(opts.NumGPU) < f.KV().BlockCount()+1 {
			opts.UseMMap = new(bool)
			*opts.UseMMap = false
		}
	}

	// Windows CUDA should not use mmap for best performance
	// Linux with a model larger than free space, mmap leads to thrashing
	// For CPU loads we want the memory to be allocated, not FS cache
	if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
		(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
		(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
		(opts.UseMMap != nil && !*opts.UseMMap) {
		params = append(params, "--no-mmap")
	}

	if opts.UseMLock {
		params = append(params, "--mlock")
	}

	// TODO - NUMA support currently doesn't work properly
	params = append(params, "--parallel", strconv.Itoa(numParallel))

	if estimate.TensorSplit != "" {
		params = append(params, "--tensor-split", estimate.TensorSplit)
	}

	if envconfig.MultiUserCache() {
		params = append(params, "--multiuser-cache")
	}

	libs := make(map[string]string)
	if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil {
		for _, entry := range entries {
			libs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name())
		}
	}

	lib := gpus[0].RunnerName()
	requested := envconfig.LLMLibrary()
	if libs[requested] != "" {
		slog.Info("using requested gpu library", "requested", requested)
		lib = requested
	}

	var compatible []string
	for k := range libs {
		// exact match first
		if k == lib {
			compatible = append([]string{k}, compatible...)
			continue
		}

		// then match the family (e.g. 'cuda')
		if strings.Split(k, "_")[0] == strings.Split(lib, "_")[0] {
			compatible = append(compatible, k)
		}
	}
	slog.Debug("compatible gpu libraries", "compatible", compatible)

	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
	}

	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}

	var llamaModel *llama.Model
	var textProcessor model.TextProcessor
	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
		textProcessor, err = model.NewTextProcessor(modelPath)
		if err != nil {
			// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
		}
	}
	if textProcessor == nil {
		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
		if err != nil {
			return nil, err
		}
	}

	if len(projectors) > 0 && llamaModel != nil {
		params = append(params, "--mmproj", projectors[0])
	}

	// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
	// without any LD_LIBRARY_PATH flags
	for {
		port := 0
		if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
			var l *net.TCPListener
			if l, err = net.ListenTCP("tcp", a); err == nil {
				port = l.Addr().(*net.TCPAddr).Port
				l.Close()
			}
		}
		if port == 0 {
			slog.Debug("ResolveTCPAddr failed, using random port")
			port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		}

		finalParams := []string{"runner"}
		if textProcessor != nil {
			// New engine
			// TODO - if we have failure to load scenarios, add logic to retry with the old runner
			finalParams = append(finalParams, "--ollama-engine")
		}
		finalParams = append(finalParams, params...)
		finalParams = append(finalParams, "--port", strconv.Itoa(port))

		var pathEnv string
		switch runtime.GOOS {
		case "windows":
			pathEnv = "PATH"
		case "darwin":
			pathEnv = "DYLD_LIBRARY_PATH"
		default:
			pathEnv = "LD_LIBRARY_PATH"
		}

		var libraryPaths []string
		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
		}

		if len(compatible) > 0 {
			c := compatible[0]
			if libpath, ok := libs[c]; ok {
				slog.Debug("adding gpu library", "path", libpath)
				libraryPaths = append(libraryPaths, libpath)
			}
		}

		// Note: we always put the dependency path first
		// since this was the exact version we compiled/linked against
		if gpus[0].DependencyPath != nil {
			slog.Debug("adding gpu dependency paths", "paths", gpus[0].DependencyPath)
			// assume gpus from the same library have the same dependency path
			libraryPaths = append(gpus[0].DependencyPath, libraryPaths...)
		}

		// finally, add the root library path
		libraryPaths = append(libraryPaths, discover.LibOllamaPath)

		s := &llmServer{
			port:          port,
			cmd:           exec.Command(exe, finalParams...),
			status:        NewStatusWriter(os.Stderr),
			options:       opts,
			modelPath:     modelPath,
			llamaModel:    llamaModel,
			textProcessor: textProcessor,
			estimate:      estimate,
			numParallel:   numParallel,
			sem:           semaphore.NewWeighted(int64(numParallel)),
			totalLayers:   f.KV().BlockCount() + 1,
			gpus:          gpus,
			done:          make(chan error, 1),
		}

		s.cmd.Env = os.Environ()
		s.cmd.Stdout = os.Stdout
		s.cmd.Stderr = s.status
		s.cmd.SysProcAttr = LlamaServerSysProcAttr

		envWorkarounds := [][2]string{}
		for _, gpu := range gpus {
			envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
		}
		visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add the path and visible devices variable with our adjusted version
		pathNeeded := true
		devicesNeeded := visibleDevicesEnv != ""
		for i := range s.cmd.Env {
			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
			if strings.EqualFold(cmp[0], pathEnv) {
				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
				pathNeeded = false
			} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
				s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
				devicesNeeded = false
			} else if len(envWorkarounds) != 0 {
				for _, kv := range envWorkarounds {
					if strings.EqualFold(cmp[0], kv[0]) {
						s.cmd.Env[i] = kv[0] + "=" + kv[1]
					}
				}
			}
		}
		if pathNeeded {
			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
		}
		if devicesNeeded {
			s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
		}

		slog.Info("starting llama server", "cmd", s.cmd)
		if envconfig.Debug() {
			filteredEnv := []string{}
			for _, ev := range s.cmd.Env {
				if strings.HasPrefix(ev, "CUDA_") ||
					strings.HasPrefix(ev, "ROCR_") ||
					strings.HasPrefix(ev, "ROCM_") ||
					strings.HasPrefix(ev, "HIP_") ||
					strings.HasPrefix(ev, "GPU_") ||
					strings.HasPrefix(ev, "HSA_") ||
					strings.HasPrefix(ev, "GGML_") ||
					strings.HasPrefix(ev, "PATH=") ||
					strings.HasPrefix(ev, "LD_LIBRARY_PATH=") ||
					strings.HasPrefix(ev, "DYLD_LIBRARY_PATH=") {
					filteredEnv = append(filteredEnv, ev)
				}
			}
			// Log at debug as the environment is inherited and might contain sensitive information
			slog.Debug("subprocess", "environment", filteredEnv)
		}

		if err = s.cmd.Start(); err != nil {
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			err := fmt.Errorf("error starting runner: %v %s", err, msg)
			if len(compatible) == 0 {
				if llamaModel != nil {
					llama.FreeModel(llamaModel)
				}
				return nil, err
			}

			slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible)
			compatible = compatible[1:]
			continue
		}

		// reap subprocess when it exits
		go func() {
			err := s.cmd.Wait()
			// Favor a more detailed message over the process exit status
			if err != nil && s.status != nil && s.status.LastErrMsg != "" {
				slog.Error("llama runner terminated", "error", err)
				if strings.Contains(s.status.LastErrMsg, "unknown model") {
					s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
				}
				s.done <- errors.New(s.status.LastErrMsg)
			} else {
				s.done <- err
			}
		}()

		return s, nil
	}
}

type ServerStatus int

const ( // iota is reset to 0
	ServerStatusReady ServerStatus = iota
	ServerStatusNoSlotsAvailable
	ServerStatusLoadingModel
	ServerStatusNotResponding
	ServerStatusError
)

func (s ServerStatus) String() string {
	switch s {
	case ServerStatusReady:
		return "llm server ready"
	case ServerStatusNoSlotsAvailable:
		return "llm busy - no slots available"
	case ServerStatusLoadingModel:
		return "llm server loading model"
	case ServerStatusNotResponding:
		return "llm server not responding"
	default:
		return "llm server error"
	}
}

type ServerStatusResponse struct {
	Status   ServerStatus `json:"status"`
	Progress float32      `json:"progress"`
}

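// getServerStatus polls the runner's /health endpoint and maps the reply
// (and any subprocess exit) onto a ServerStatus value.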
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
	// Fail fast if it's exited
	if s.cmd.ProcessState != nil {
		msg := ""
		if s.status != nil && s.status.LastErrMsg != "" {
			msg = s.status.LastErrMsg
		}
		if s.cmd.ProcessState.ExitCode() == -1 {
			// Most likely a signal killed it, log some more details to try to help troubleshoot
			slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
		}
		return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
	if err != nil {
		return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			return ServerStatusNotResponding, errors.New("server not responding")
		}
		return ServerStatusError, fmt.Errorf("health resp: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return ServerStatusError, fmt.Errorf("read health request: %w", err)
	}

	var ssr ServerStatusResponse
	if err := json.Unmarshal(body, &ssr); err != nil {
		return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
	}

	switch ssr.Status {
	case ServerStatusLoadingModel:
		s.loadProgress = ssr.Progress
		return ssr.Status, nil
	case ServerStatusReady, ServerStatusNoSlotsAvailable:
		return ssr.Status, nil
	default:
		return ssr.Status, fmt.Errorf("server error: %+v", ssr)
	}
}

// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
	var retries int
	for {
		status, err := s.getServerStatus(ctx)
		if err != nil {
			return status, err
		}

		if status == ServerStatusNoSlotsAvailable {
			if retries >= 10 {
				return status, fmt.Errorf("no slots available after %d retries", retries)
			}

			time.Sleep(5 * time.Millisecond)
			retries++
			continue
		}

		return status, nil
	}
}

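// Ping reports whether the runner subprocess is currently healthy.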
func (s *llmServer) Ping(ctx context.Context) error {
	_, err := s.getServerStatus(ctx)
	if err != nil {
		slog.Debug("server unhealthy", "error", err)
		return err
	}
	return nil
}

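// WaitUntilRunning blocks until the runner reports ServerStatusReady,
// resetting its stall timer whenever load progress advances, and returns an
// error if the context is cancelled, the process exits, or loading stalls.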
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
	start := time.Now()
	stallDuration := envconfig.LoadTimeout()    // If no progress happens
	stallTimer := time.Now().Add(stallDuration) // give up if we stall

	slog.Info("waiting for llama runner to start responding")
	var lastStatus ServerStatus = -1
	fullyLoaded := false

	for {
		select {
		case <-ctx.Done():
			slog.Warn("client connection closed before server finished loading, aborting load")
			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
		case err := <-s.done:
			return fmt.Errorf("llama runner process has terminated: %w", err)
		default:
		}
		if time.Now().After(stallTimer) {
			// timeout
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("timed out waiting for llama runner to start - progress %0.2f - %s", s.loadProgress, msg)
		}
		if s.cmd.ProcessState != nil {
			msg := ""
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
		}
		ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
		defer cancel()
		priorProgress := s.loadProgress
		status, _ := s.getServerStatus(ctx)
		if lastStatus != status && status != ServerStatusReady {
			// Only log on status changes
			slog.Info("waiting for server to become available", "status", status)
		}
		switch status {
		case ServerStatusReady:
			s.loadDuration = time.Since(start)
			slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", s.loadDuration.Seconds()))
			return nil
		default:
			lastStatus = status
			// Reset the timer as long as we're making forward progress on the load
			if priorProgress != s.loadProgress {
				slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
				stallTimer = time.Now().Add(stallDuration)
			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
				slog.Debug("model load completed, waiting for server to become available", "status", status)
				stallTimer = time.Now().Add(stallDuration)
				fullyLoaded = true
			}
			time.Sleep(time.Millisecond * 250)
			continue
		}
	}
}

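// grammarJSON is a GBNF-style grammar that constrains sampling to valid JSON
// output; it is used when a completion request sets Format to "json".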
var grammarJSON = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws
array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws
string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`

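// maxBufferSize bounds the scanner buffer used to read a single streamed
// completion line from the runner.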
const maxBufferSize = 512 * format.KiloByte

type ImageData struct {
	Data          []byte `json:"data"`
	ID            int    `json:"id"`
	AspectRatioID int    `json:"aspect_ratio_id"`
}

type CompletionRequest struct {
	Prompt  string
	Format  json.RawMessage
	Images  []ImageData
	Options *api.Options

	Grammar string // set before sending the request to the subprocess
}

type CompletionResponse struct {
	Content            string        `json:"content"`
	DoneReason         string        `json:"done_reason"`
	Done               bool          `json:"done"`
	PromptEvalCount    int           `json:"prompt_eval_count"`
	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
	EvalCount          int           `json:"eval_count"`
	EvalDuration       time.Duration `json:"eval_duration"`
}

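// Completion streams a prediction from the runner, invoking fn for each
// partial response and once more with the final, Done-marked response.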
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
	if len(req.Format) > 0 {
		switch string(req.Format) {
		case `null`, `""`:
			// Field was set, but "missing" a value. We accept
			// these as "not set".
			break
		case `"json"`:
			req.Grammar = grammarJSON
		default:
			if req.Format[0] != '{' {
				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
			}

			// User provided a JSON schema
			g, err := grammar.FromSchema(nil, req.Format)
			if err != nil {
				return fmt.Errorf("invalid JSON schema in format: %w", err)
			}
			req.Grammar = string(g)
		}
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
	}

	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting completion request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return err
	}
	defer s.sem.Release(1)

	// put an upper limit on num_predict to avoid the model running on forever
	if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
		req.Options.NumPredict = 10 * s.options.NumCtx
	}

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return err
	} else if status != ServerStatusReady {
		return fmt.Errorf("unexpected server status: %s", status)
	}

	// Handling JSON marshaling with special characters unescaped.
	buffer := &bytes.Buffer{}
	enc := json.NewEncoder(buffer)
	enc.SetEscapeHTML(false)

	if err := enc.Encode(req); err != nil {
		return fmt.Errorf("failed to marshal data: %v", err)
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	serverReq.Header.Set("Content-Type", "application/json")

	res, err := http.DefaultClient.Do(serverReq)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer res.Body.Close()

	if res.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(res.Body)
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)

	// keep track of the last token generated, this is used to abort if the model starts looping
	var lastToken string
	var tokenRepeat int

	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Bytes()
			if len(line) == 0 {
				continue
			}

			// slog.Debug("got line", "line", string(line))
			evt, ok := bytes.CutPrefix(line, []byte("data: "))
			if !ok {
				evt = line
			}

			var c CompletionResponse
			if err := json.Unmarshal(evt, &c); err != nil {
				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
			}
			switch {
			case strings.TrimSpace(c.Content) == lastToken:
				tokenRepeat++
			default:
				lastToken = strings.TrimSpace(c.Content)
				tokenRepeat = 0
			}

			// 30 picked as an arbitrary max token repeat limit, modify as needed
			if tokenRepeat > 30 {
				slog.Debug("prediction aborted, token repeat limit reached")
				return ctx.Err()
			}

			if c.Content != "" {
				fn(CompletionResponse{
					Content: c.Content,
				})
			}

			if c.Done {
				fn(c)
				return nil
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if strings.Contains(err.Error(), "unexpected EOF") || strings.Contains(err.Error(), "forcibly closed") {
			s.Close()
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			} else {
				msg = err.Error()
			}
			return fmt.Errorf("an error was encountered while running the model: %s", msg)
		}

		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

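// Embedding requests an embedding vector for the given input from the
// runner's /embedding endpoint.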
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
	if err := s.sem.Acquire(ctx, 1); err != nil {
		if errors.Is(err, context.Canceled) {
			slog.Info("aborting embedding request due to client closing the connection")
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return nil, err
	}
	defer s.sem.Release(1)

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return nil, err
	} else if status != ServerStatusReady {
		return nil, fmt.Errorf("unexpected server status: %s", status)
	}

	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	r.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		return nil, fmt.Errorf("do embedding request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var e EmbeddingResponse
	if err := json.Unmarshal(body, &e); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return e.Embedding, nil
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

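// Tokenize encodes content into token IDs using either the in-process
// llama.cpp model or the Ollama engine's text processor.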
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		return s.llamaModel.Tokenize(content, false, true)
	}
	if s.textProcessor != nil {
		tokens, err := s.textProcessor.Encode(content, false)
		if err != nil {
			return nil, err
		}
		toks := make([]int, len(tokens))
		for i, t := range tokens {
			toks[i] = int(t)
		}
		return toks, nil
	}
	// not reached
	return nil, fmt.Errorf("no tokenizer configured")
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

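// Detokenize converts token IDs back into text using whichever tokenizer
// backend this server was started with.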
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
	s.llamaModelLock.Lock()
	defer s.llamaModelLock.Unlock()

	if s.llamaModel != nil {
		var resp string
		for _, token := range tokens {
			resp += s.llamaModel.TokenToPiece(token)
		}
		return resp, nil
	}
	if s.textProcessor != nil {
		toks := make([]int32, len(tokens))
		for i, t := range tokens {
			toks[i] = int32(t)
		}
		content, err := s.textProcessor.Decode(toks)
		if err != nil {
			return "", err
		}
		return content, nil
	}
	// not reached
	return "", fmt.Errorf("no tokenizer configured")
}

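// Close frees the in-process llama.cpp model (if any), kills the runner
// subprocess, and waits for it to exit.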
func (s *llmServer) Close() error {
	s.llamaModelLock.Lock()
	if s.llamaModel != nil {
		llama.FreeModel(s.llamaModel)
		s.llamaModel = nil
	}
	s.llamaModelLock.Unlock()

	if s.cmd != nil {
		slog.Debug("stopping llama server")
		if err := s.cmd.Process.Kill(); err != nil {
			return err
		}
		// if ProcessState is already populated, Wait already completed, no need to wait again
		if s.cmd.ProcessState == nil {
			slog.Debug("waiting for llama server to exit")
			<-s.done
		}

		slog.Debug("llama server stopped")
	}

	return nil
}

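// The Estimated* accessors expose the memory estimate recorded when the
// server was created; they do not re-measure live usage.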
func (s *llmServer) EstimatedVRAM() uint64 {
	return s.estimate.VRAMSize
}

func (s *llmServer) EstimatedTotal() uint64 {
	return s.estimate.TotalSize
}

func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
	for i, gpu := range s.gpus {
		if gpu.ID == gpuID {
			if i < len(s.estimate.GPUSizes) {
				return s.estimate.GPUSizes[i]
			}
		}
	}
	return 0
}