llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/jmorganca/ollama/api"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

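// osPath returns the platform-specific build output directory for an embedded
// llama.cpp runner: Windows builds place their binaries in a Release subdirectory.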
func osPath(llamaPath string) string {
	if runtime.GOOS == "windows" {
		return path.Join(llamaPath, "Release")
	}
	return llamaPath
}

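// chooseRunner extracts the embedded llama.cpp server binary (preferring the GPU
// build and falling back to the CPU build) into a temporary directory and returns
// the path to the extracted executable.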
func chooseRunner(gpuPath, cpuPath string) string {
	tmpDir, err := os.MkdirTemp("", "llama-*")
	if err != nil {
		log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
	}

	llamaPath := osPath(gpuPath)
	if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
		llamaPath = osPath(cpuPath)
		if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
			log.Fatalf("llama.cpp executable not found")
		}
	}

	files := []string{"server"}
	switch runtime.GOOS {
	case "windows":
		files = []string{"server.exe"}
	case "darwin":
		if llamaPath == osPath(gpuPath) {
			files = append(files, "ggml-metal.metal")
		}
	case "linux":
		// check if there is a GPU available
		if _, err := CheckVRAM(); errors.Is(err, errNoGPU) {
			// this error was logged on start-up, so we don't need to log it again
			llamaPath = osPath(cpuPath)
		}
	}

	for _, f := range files {
		srcPath := path.Join(llamaPath, f)
		destPath := filepath.Join(tmpDir, f)

		srcFile, err := llamaCppEmbed.Open(srcPath)
		if err != nil {
			log.Fatalf("read llama.cpp %s: %v", f, err)
		}
		defer srcFile.Close()

		destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
		if err != nil {
			log.Fatalf("write llama.cpp %s: %v", f, err)
		}
		defer destFile.Close()

		if _, err := io.Copy(destFile, srcFile); err != nil {
			log.Fatalf("copy llama.cpp %s: %v", f, err)
		}
	}

	runPath := filepath.Join(tmpDir, "server")
	if runtime.GOOS == "windows" {
		runPath = filepath.Join(tmpDir, "server.exe")
	}

	return runPath
}

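// llamaModel reports metadata (model family, parameter size and quantization
// level) for llama-family models.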
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}

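// llamaModelType maps a layer count to the conventional parameter-size label
// for llama models.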
func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "Unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32
	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}

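// Running tracks a launched llama.cpp server subprocess: the port it listens on,
// the command handle, and the cancel function used to stop it.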
type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc
}

type ModelRunner struct {
	Path string // path to the model runner executable
}

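// llama serves predictions, tokenization and embeddings for a model by talking
// to a local llama.cpp server subprocess over HTTP.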
type llama struct {
	api.Options
	Running
}

var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the total VRAM in MiB across all GPUs, as reported by
// nvidia-smi, on Linux machines with NVIDIA GPUs
func CheckVRAM() (int, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var total int
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.Atoi(line)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		total += vram
	}

	return total, nil
}

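// NumGPU returns the number of model layers to offload to the GPU. An explicit
// opts.NumGPU takes precedence; otherwise it defaults to 1 to enable Metal on
// macOS and, on Linux, chooses a layer count from a rough VRAM heuristic
// (0 when no NVIDIA GPU is found).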
func NumGPU(opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}

	n := 1 // default to enable metal on macOS
	if runtime.GOOS == "linux" {
		vram, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNoGPU) {
				log.Print(err.Error())
			}

			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		// TODO: this is a very rough heuristic, better would be to calculate this based on number of layers and context size
		switch {
		case vram < 500:
			log.Printf("WARNING: Low VRAM detected, disabling GPU")
			n = 0
		case vram < 1000:
			n = 4
		case vram < 2000:
			n = 8
		case vram < 4000:
			n = 12
		case vram < 8000:
			n = 16
		case vram < 12000:
			n = 24
		case vram < 16000:
			n = 32
		default:
			n = 48
		}

		log.Printf("%d MB VRAM available, loading %d GPU layers", vram, n)
	}

	return n
}

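// newLlama launches the llama.cpp server for the given model with flags derived
// from opts, retrying on a fresh ephemeral port if startup fails, and returns a
// client bound to the running server.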
func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options) (*llama, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	if _, err := os.Stat(runner.Path); err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(opts)),
		"--embedding",
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	// start the llama.cpp server with a retry in case the port is already in use
	for try := 0; try < 3; try++ {
		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Stdout = os.Stderr
		cmd.Stderr = os.Stderr

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}

		log.Print("starting llama.cpp server")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama.cpp server: %v", err)
			continue
		}

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama.cpp server: %v", err)
			llm.Close()
			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	return nil, fmt.Errorf("max retries exceeded starting llama.cpp server")
}

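// waitForServer polls the freshly started llama.cpp server until it answers a
// ping, giving up after 45 seconds.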
func waitForServer(llm *llama) error {
	// wait for the server to start responding
	start := time.Now()
	expiresAt := time.Now().Add(45 * time.Second)
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama.cpp server to start responding")
	for range ticker.C {
		if time.Now().After(expiresAt) {
			return fmt.Errorf("llama.cpp server did not start within allotted time, retrying")
		}

		if err := llm.Ping(context.Background()); err == nil {
			break
		}
	}

	log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
	return nil
}

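// Close stops the llama.cpp server subprocess and waits for it to exit.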
func (llm *llama) Close() {
	llm.Cancel()
	if err := llm.Cmd.Wait(); err != nil {
		log.Printf("llama.cpp server exited with error: %v", err)
	}
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}

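// GenerationSettings describes the sampling settings the llama.cpp server
// reports for a completion.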
type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

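// Timings reports prompt and prediction token counts and processing times from
// the llama.cpp server.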
type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

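// Prediction is a single streamed completion event from the llama.cpp server's
// /completion endpoint.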
type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}

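// PredictRequest is the request body for the llama.cpp server's /completion
// endpoint.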
type PredictRequest struct {
	Stream           bool            `json:"stream"`
	NPredict         int             `json:"n_predict,omitempty"`
	TopK             int             `json:"top_k,omitempty"`
	TopP             float32         `json:"top_p,omitempty"`
	TfsZ             float32         `json:"tfs_z,omitempty"`
	TypicalP         float32         `json:"typical_p,omitempty"`
	RepeatLastN      int             `json:"repeat_last_n,omitempty"`
	Temperature      float32         `json:"temperature,omitempty"`
	RepeatPenalty    float32         `json:"repeat_penalty,omitempty"`
	PresencePenalty  float32         `json:"presence_penalty,omitempty"`
	FrequencyPenalty float32         `json:"frequency_penalty,omitempty"`
	Mirostat         int             `json:"mirostat,omitempty"`
	MirostatTau      float32         `json:"mirostat_tau,omitempty"`
	MirostatEta      float32         `json:"mirostat_eta,omitempty"`
	PenalizeNl       bool            `json:"penalize_nl,omitempty"`
	NKeep            int             `json:"n_keep,omitempty"`
	Seed             int             `json:"seed,omitempty"`
	Prompt           string          `json:"prompt,omitempty"`
	NProbs           int             `json:"n_probs,omitempty"`
	LogitBias        map[int]float32 `json:"logit_bias,omitempty"`
	IgnoreEos        bool            `json:"ignore_eos,omitempty"`
	Stop             []string        `json:"stop,omitempty"`
}

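// Predict streams a completion for prompt from the llama.cpp server. The previous
// context is decoded and prepended to the prompt, fn is invoked for each generated
// chunk, and a final response carries the updated context and timing statistics.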
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           nextContext.String(),
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Stop:             llm.Stop,
	}

	data, err := json.Marshal(predReq)
	if err != nil {
		return fmt.Errorf("error marshaling data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// Read data from the server-sent event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]
				var p Prediction
				if err := json.Unmarshal([]byte(evt), &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.PromptN,
						PromptEvalDuration: parseDurationMs(p.PromptMS),
						EvalCount:          p.PredictedN,
						EvalDuration:       parseDurationMs(p.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}

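// TokenizeRequest is the request body for the llama.cpp server's /tokenize
// endpoint; TokenizeResponse is the corresponding response.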
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

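// Encode converts prompt into model tokens using the llama.cpp server's
// /tokenize endpoint.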
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

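// Decode converts tokens back into text using the llama.cpp server's
// /detokenize endpoint.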
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// decoded content contains a leading whitespace
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

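// Embedding returns the embedding vector for input via the llama.cpp server's
// /embedding endpoint.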
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}

// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}

	return nil
}