llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)
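
// jsonGrammar is a GBNF grammar (the grammar format used by llama.cpp) that
// constrains sampling to valid JSON; Predict attaches it to a request when
// PredictOpts.Format is "json".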
const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

type ModelRunner struct {
	Type        string // "gguf" or "ggml"
	Path        string // path to the model runner executable
	Accelerated bool
}
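
// chooseRunners lists the llama.cpp runner binaries to try for the given
// runner type ("gguf" or "ggml"), extracts the matching embedded binaries
// into workDir, and returns them in priority order (accelerated runners
// first). For example, on Linux the CUDA runner is expected at
// llama.cpp/<runnerType>/build/cuda/bin/ollama-runner inside the embed.
// If no runner files were packed into the binary at all, the process exits.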
func chooseRunners(workDir, runnerType string) []ModelRunner {
	buildPath := path.Join("llama.cpp", runnerType, "build")
	var runners []ModelRunner

	// set the runners based on the OS
	// IMPORTANT: the order of the runners in the array is the priority order
	switch runtime.GOOS {
	case "darwin":
		if runtime.GOARCH == "arm64" {
			runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
		} else {
			runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
		}
	case "linux":
		runners = []ModelRunner{
			{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
			{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "windows":
		// TODO: select windows GPU runner here when available
		runners = []ModelRunner{
			{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
			{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
		}
	default:
		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
		runners = []ModelRunner{
			{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	}

	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
	for _, r := range runners {
		// find all the files in the runner's bin directory
		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
		if err != nil {
			// this is expected, ollama may be compiled without all runners packed in
			log.Printf("%s runner not found: %v", r.Path, err)
			continue
		}

		for _, f := range files {
			runnerAvailable = true

			srcFile, err := llamaCppEmbed.Open(f)
			if err != nil {
				log.Fatalf("read llama runner %s: %v", f, err)
			}
			defer srcFile.Close()

			// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
			destPath := filepath.Join(workDir, filepath.Dir(f))
			if err := os.MkdirAll(destPath, 0o755); err != nil {
				log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
			}

			// create the path to the destination file, filepath.Base() converts the file path to the OS's format
			destFile := filepath.Join(destPath, filepath.Base(f))

			_, err = os.Stat(destFile)
			switch {
			case errors.Is(err, os.ErrNotExist):
				destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
				if err != nil {
					log.Fatalf("write llama runner %s: %v", f, err)
				}
				defer destFile.Close()

				if _, err := io.Copy(destFile, srcFile); err != nil {
					log.Fatalf("copy llama runner %s: %v", f, err)
				}
			case err != nil:
				log.Fatalf("stat llama runner %s: %v", f, err)
			}
		}
	}
	if !runnerAvailable {
		log.Fatalf("%s runner not found", runnerType)
	}

	// return the runners to try in priority order
	localRunnersByPriority := []ModelRunner{}
	for _, r := range runners {
		// clean the ModelRunner paths so that they match the OS we are running on
		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
			Type:        r.Type,
			Path:        filepath.Clean(path.Join(workDir, r.Path)),
			Accelerated: r.Accelerated,
		})
	}

	return localRunnersByPriority
}
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}

func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
type Running struct {
	Port     int
	Cmd      *exec.Cmd
	Cancel   context.CancelFunc
	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess

	*StatusWriter // captures error messages from the llama runner process
}

type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

type llama struct {
	api.Options
	ImageData []ImageData
	Running
}

var (
	errNvidiaSMI     = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
	errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNvidiaSMI
	}

	var freeMiB int64
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.Contains(line, "[Insufficient Permissions]") {
			return 0, fmt.Errorf("GPU support may not be enabled, check that you have installed GPU drivers and have the necessary permissions to run nvidia-smi")
		}

		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		freeMiB += vram
	}

	freeBytes := freeMiB * 1024 * 1024
	if freeBytes < 2*format.GigaByte {
		log.Printf("less than 2 GB VRAM available")
		return 0, errAvailableVRAM
	}

	return freeBytes, nil
}
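
// NumGPU decides how many layers to off-load to the GPU. An explicit
// opts.NumGPU always wins; otherwise, on Linux and Windows, free VRAM is
// divided by an approximate per-layer size and scaled by 75% to leave
// headroom. For example (illustrative numbers only): a 13 GB model file with
// 40 layers is roughly 325 MB per layer, so 10 GB of free VRAM fits at most
// 30 layers, and 75% of that (with integer division) off-loads 22 layers.
// On macOS the default of 1 simply enables Metal.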
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
		freeBytes, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNvidiaSMI) {
				log.Print(err.Error())
			}
			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		/*
			Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
			We can store the model weights and the kv cache in vram,
			to enable kv cache vram storage add two additional layers to the number of layers retrieved from the model file.
		*/
		bytesPerLayer := fileSizeBytes / numLayer

		// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
		layers := int(freeBytes/bytesPerLayer) * 3 / 4

		log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)

		return layers
	}
	// default to enable metal on macOS
	return 1
}
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh      chan error
	LastErrMsg string
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
	}

	return os.Stderr.Write(b)
}
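
// newLlama translates api.Options into llama.cpp server flags and starts the
// server as a subprocess. Runners are tried in priority order, each on a
// random ephemeral port; accelerated runners are skipped when no layers will
// be off-loaded, and the first runner that starts responding is returned.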
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
	fileInfo, err := os.Stat(model)
	if err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
		"--embedding",
	}

	if opts.MainGPU > 0 {
		params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
	}

	if opts.RopeFrequencyBase > 0 {
		params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
	}

	if opts.RopeFrequencyScale > 0 {
		params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if len(projectors) > 0 {
		// TODO: applying multiple projectors is not supported by the llama.cpp server yet
		params = append(params, "--mmproj", projectors[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	var runnerErr error

	// start the llama.cpp server with a retry in case the port is already in use
	for _, runner := range runners {
		if runner.Accelerated && numGPU == 0 {
			log.Printf("skipping accelerated runner because num_gpu=0")
			continue
		}

		if _, err := os.Stat(runner.Path); err != nil {
			log.Printf("llama runner not found: %v", err)
			continue
		}

		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		params := append(params, "--port", strconv.Itoa(port))

		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			params...,
		)

		var libraryPaths []string
		if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
			libraryPaths = append(libraryPaths, libraryPath)
		}
		libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))

		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
		cmd.Stdout = os.Stderr
		statusWriter := NewStatusWriter()
		cmd.Stderr = statusWriter

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error)}}

		log.Print("starting llama runner")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama runner: %v", err)
			continue
		}

		// monitor the llama runner process and signal when it exits
		go func() {
			err := llm.Cmd.Wait()
			// default to printing the exit message of the command process, it will probably just say 'exit status 1'
			errMsg := err.Error()
			// try to set a better error message if llama runner logs captured an error
			if statusWriter.LastErrMsg != "" {
				errMsg = statusWriter.LastErrMsg
			}
			log.Println(errMsg)
			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
			llm.exitOnce.Do(func() {
				close(llm.exitCh)
			})
		}()

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama runner: %v", err)
			llm.Close()

			// default the runnerErr to the error returned by the most recent llama runner process
			runnerErr = err

			// capture the error directly from the runner process, if any
			select {
			case runnerErr = <-statusWriter.ErrCh:
			default:
				// the runner process probably timed out
			}

			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	if runnerErr != nil {
		// this is the error returned from the llama runner process that failed most recently
		return nil, runnerErr
	}

	return nil, fmt.Errorf("failed to start a llama runner")
}
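
// waitForServer polls the runner with Ping every 200ms until it responds,
// the subprocess exits, or the three-minute timeout elapses.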
func waitForServer(llm *llama) error {
	start := time.Now()
	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama runner to start responding")
	for {
		select {
		case <-llm.exitCh:
			// failed to start subprocess
			return fmt.Errorf("llama runner process has terminated")
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				// timeout
				return fmt.Errorf("timed out waiting for llama runner to start")
			}

			if err := llm.Ping(context.Background()); err == nil {
				// success
				log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		}
	}
}
func (llm *llama) Close() {
	// signal the sub-process to terminate
	llm.Cancel()

	// wait for the command to exit to prevent race conditions with the next run
	<-llm.exitCh

	if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
		log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
	} else {
		log.Print("llama runner stopped successfully")
	}
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
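
// prediction matches the shape of a streamed response chunk from the
// llama.cpp server's /completion endpoint, as consumed by Predict below.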
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}

const maxBufferSize = 512 * format.KiloByte
const maxRetries = 6
type PredictOpts struct {
	Prompt string
	Format string
	Images []api.ImageData
}

type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}

// isRetryable checks if the line matches a condition that can be retried
func isRetryable(line []byte) bool {
	return bytes.Contains(line, []byte("slot unavailable"))
}
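
// Predict streams a completion from the llama.cpp server: it POSTs the prompt
// and sampling options to the /completion endpoint, forwards each streamed
// chunk to fn, and retries with exponential backoff (up to maxRetries) when
// the server reports "slot unavailable". A streamed line looks roughly like
// this (illustrative only, not an exact server transcript):
//
//	data: {"content":" world","stop":false}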
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
	imageData := llm.ImageData
	if len(predict.Images) > 0 {
		for cnt, i := range predict.Images {
			imageData = append(imageData, ImageData{Data: i, ID: cnt})
		}
	}
	log.Printf("loaded %d images", len(imageData))

	request := map[string]any{
		"prompt":            predict.Prompt,
		"stream":            true,
		"n_predict":         llm.NumPredict,
		"n_keep":            llm.NumKeep,
		"main_gpu":          llm.MainGPU,
		"temperature":       llm.Temperature,
		"top_k":             llm.TopK,
		"top_p":             llm.TopP,
		"tfs_z":             llm.TFSZ,
		"typical_p":         llm.TypicalP,
		"repeat_last_n":     llm.RepeatLastN,
		"repeat_penalty":    llm.RepeatPenalty,
		"presence_penalty":  llm.PresencePenalty,
		"frequency_penalty": llm.FrequencyPenalty,
		"mirostat":          llm.Mirostat,
		"mirostat_tau":      llm.MirostatTau,
		"mirostat_eta":      llm.MirostatEta,
		"penalize_nl":       llm.PenalizeNewline,
		"seed":              llm.Seed,
		"stop":              llm.Stop,
		"image_data":        imageData,
	}

	if predict.Format == "json" {
		request["grammar"] = jsonGrammar
	}

	retryDelay := 100 * time.Microsecond
	for retries := 0; retries < maxRetries; retries++ {
		if retries > 0 {
			time.Sleep(retryDelay) // wait before retrying
			retryDelay *= 2        // exponential backoff
		}

		// Handling JSON marshaling with special characters unescaped.
		buffer := &bytes.Buffer{}
		enc := json.NewEncoder(buffer)
		enc.SetEscapeHTML(false)

		if err := enc.Encode(request); err != nil {
			return fmt.Errorf("failed to marshal data: %v", err)
		}

		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
		if err != nil {
			return fmt.Errorf("error creating POST request: %v", err)
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return fmt.Errorf("POST predict: %v", err)
		}
		defer resp.Body.Close()

		if resp.StatusCode >= 400 {
			bodyBytes, err := io.ReadAll(resp.Body)
			if err != nil {
				return fmt.Errorf("failed reading llm error response: %w", err)
			}
			log.Printf("llm predict error: %s", bodyBytes)
			return fmt.Errorf("%s", bodyBytes)
		}

		scanner := bufio.NewScanner(resp.Body)
		// increase the buffer size to avoid running out of space
		buf := make([]byte, 0, maxBufferSize)
		scanner.Buffer(buf, maxBufferSize)

		retryNeeded := false
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				// This handles the request cancellation
				return ctx.Err()
			default:
				line := scanner.Bytes()
				if len(line) == 0 {
					continue
				}

				if isRetryable(line) {
					retryNeeded = true
					break
				}

				evt, ok := bytes.CutPrefix(line, []byte("data: "))
				if !ok {
					return fmt.Errorf("error parsing llm response stream: %s", line)
				}

				var p prediction
				if err := json.Unmarshal(evt, &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(PredictResult{
						Content: p.Content,
					})
				}

				if p.Stop {
					fn(PredictResult{
						Done:               true,
						PromptEvalCount:    p.Timings.PromptN,
						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
						EvalCount:          p.Timings.PredictedN,
						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
					})
					return nil
				}
			}
		}

		if err := scanner.Err(); err != nil {
			if strings.Contains(err.Error(), "unexpected EOF") {
				// this means the llama runner subprocess crashed
				llm.Close()
				if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
					return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
				}
				return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
			}
			return fmt.Errorf("error reading llm response: %v", err)
		}

		if !retryNeeded {
			return nil // success
		}
	}

	// should never reach here ideally
	return fmt.Errorf("max retries exceeded")
}
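
// The llama.cpp server also exposes /tokenize and /detokenize endpoints;
// Encode and Decode below wrap them with small JSON request/response types.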
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}
type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	return decoded.Content, nil
}
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}
	return nil
}