llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"time"

	"github.com/jmorganca/ollama/api"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS
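
// osPath maps an embedded llama.cpp build directory to the location of its
// binaries for the current platform; Windows builds place them in a Release
// subdirectory.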
func osPath(llamaPath string) string {
	if runtime.GOOS == "windows" {
		return path.Join(llamaPath, "Release")
	}

	return llamaPath
}
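
// chooseRunner copies the embedded llama.cpp server binary (preferring the GPU
// build when one is usable, falling back to the CPU build otherwise) into a
// temporary directory and returns the path of the extracted executable.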
func chooseRunner(gpuPath, cpuPath string) string {
	tmpDir, err := os.MkdirTemp("", "llama-*")
	if err != nil {
		log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
	}

	llamaPath := osPath(gpuPath)
	if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
		llamaPath = osPath(cpuPath)
		if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
			log.Fatalf("llama.cpp executable not found")
		}
	}

	files := []string{"server"}
	switch runtime.GOOS {
	case "windows":
		files = []string{"server.exe"}
	case "darwin":
		if llamaPath == osPath(gpuPath) {
			files = append(files, "ggml-metal.metal")
		}
	case "linux":
		// check if there is a GPU available
		if _, err := CheckVRAM(); errors.Is(err, errNoGPU) {
			// this error was logged on start-up, so we don't need to log it again
			llamaPath = osPath(cpuPath)
		}
	}

	for _, f := range files {
		srcPath := path.Join(llamaPath, f)
		destPath := filepath.Join(tmpDir, f)

		srcFile, err := llamaCppEmbed.Open(srcPath)
		if err != nil {
			log.Fatalf("read llama.cpp %s: %v", f, err)
		}
		defer srcFile.Close()

		destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
		if err != nil {
			log.Fatalf("write llama.cpp %s: %v", f, err)
		}
		defer destFile.Close()

		if _, err := io.Copy(destFile, srcFile); err != nil {
			log.Fatalf("copy llama.cpp %s: %v", f, err)
		}
	}

	runPath := filepath.Join(tmpDir, "server")
	if runtime.GOOS == "windows" {
		runPath = filepath.Join(tmpDir, "server.exe")
	}

	return runPath
}
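
// ModelFamilyLlama identifies GGML models belonging to the llama family.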
const ModelFamilyLlama ModelFamily = "llama"

type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() ModelFamily {
	return ModelFamilyLlama
}

func (llm *llamaModel) ModelType() ModelType {
	switch llm.hyperparameters.NumLayer {
	case 26:
		return ModelType3B
	case 32:
		return ModelType7B
	case 40:
		return ModelType13B
	case 48:
		return ModelType34B
	case 60:
		return ModelType30B
	case 80:
		return ModelType65B
	}

	// TODO: find a better default
	return ModelType7B
}

func (llm *llamaModel) FileType() FileType {
	return llm.hyperparameters.FileType
}
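
// llamaHyperparameters holds the hyperparameter block read from a llama model
// file; NumLayer is also used above to infer the parameter-count class of the model.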
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType llamaFileType
}
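
// llamaFileType enumerates llama.cpp file types (quantization formats). The
// explicit offset before Q8_0 skips numeric values retired upstream, keeping
// these constants aligned with llama.cpp's own file-type numbering.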
type llamaFileType uint32

const (
	llamaFileTypeF32 llamaFileType = iota
	llamaFileTypeF16
	llamaFileTypeQ4_0
	llamaFileTypeQ4_1
	llamaFileTypeQ4_1_F16
	llamaFileTypeQ8_0 llamaFileType = iota + 2
	llamaFileTypeQ5_0
	llamaFileTypeQ5_1
	llamaFileTypeQ2_K
	llamaFileTypeQ3_K_S
	llamaFileTypeQ3_K_M
	llamaFileTypeQ3_K_L
	llamaFileTypeQ4_K_S
	llamaFileTypeQ4_K_M
	llamaFileTypeQ5_K_S
	llamaFileTypeQ5_K_M
	llamaFileTypeQ6_K
)
func (ft llamaFileType) String() string {
	switch ft {
	case llamaFileTypeF32:
		return "F32"
	case llamaFileTypeF16:
		return "F16"
	case llamaFileTypeQ4_0:
		return "Q4_0"
	case llamaFileTypeQ4_1:
		return "Q4_1"
	case llamaFileTypeQ4_1_F16:
		return "Q4_1_F16"
	case llamaFileTypeQ8_0:
		return "Q8_0"
	case llamaFileTypeQ5_0:
		return "Q5_0"
	case llamaFileTypeQ5_1:
		return "Q5_1"
	case llamaFileTypeQ2_K:
		return "Q2_K"
	case llamaFileTypeQ3_K_S:
		return "Q3_K_S"
	case llamaFileTypeQ3_K_M:
		return "Q3_K_M"
	case llamaFileTypeQ3_K_L:
		return "Q3_K_L"
	case llamaFileTypeQ4_K_S:
		return "Q4_K_S"
	case llamaFileTypeQ4_K_M:
		return "Q4_K_M"
	case llamaFileTypeQ5_K_S:
		return "Q5_K_S"
	case llamaFileTypeQ5_K_M:
		return "Q5_K_M"
	case llamaFileTypeQ6_K:
		return "Q6_K"
	default:
		return "Unknown"
	}
}
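
// Running tracks a launched llama.cpp server subprocess: the port it listens
// on, the running command, and the cancel function that stops it.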
type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc
}

type ModelRunner struct {
	Path string // path to the model runner executable
}

type llama struct {
	api.Options
	Running
}
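
// errNoGPU is returned by CheckVRAM when nvidia-smi cannot be run, which is
// treated as "no NVIDIA GPU available".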
var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the total VRAM in MiB reported by nvidia-smi, summed across
// all NVIDIA GPUs, on Linux machines
func CheckVRAM() (int, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var total int
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.Atoi(line)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		total += vram
	}

	return total, nil
}
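
// NumGPU returns the number of model layers to offload to the GPU. An explicit
// opts.NumGPU other than -1 wins; otherwise the count is estimated from the
// available VRAM on Linux, and defaults to 1 elsewhere so Metal stays enabled
// on macOS.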
func NumGPU(opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}

	n := 1 // default to enable metal on macOS
	if runtime.GOOS == "linux" {
		vram, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNoGPU) {
				log.Print(err.Error())
			}

			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		// TODO: this is a very rough heuristic, better would be to calculate this based on number of layers and context size
		switch {
		case vram < 500:
			log.Printf("WARNING: Low VRAM detected, disabling GPU")
			n = 0
		case vram < 1000:
			n = 4
		case vram < 2000:
			n = 8
		case vram < 4000:
			n = 12
		case vram < 8000:
			n = 16
		case vram < 12000:
			n = 24
		case vram < 16000:
			n = 32
		default:
			n = 48
		}

		log.Printf("%d MiB VRAM available, loading %d GPU layers", vram, n)
	}

	return n
}
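
// newLlama starts an embedded llama.cpp server for the given model and returns
// a client bound to it. It validates the model and runner paths, translates
// opts into server flags, and retries on a fresh random port if startup fails.
//
// A minimal usage sketch; the model path and the use of api.DefaultOptions()
// are illustrative assumptions, not values taken from this file:
//
//	runner := ModelRunner{Path: chooseRunner(gpuPath, cpuPath)} // gpuPath, cpuPath: hypothetical embed paths
//	llm, err := newLlama("/path/to/model.bin", nil, runner, api.DefaultOptions())
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer llm.Close()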
func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options) (*llama, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	if _, err := os.Stat(runner.Path); err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(opts)),
		"--embedding",
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	// start the llama.cpp server with a retry in case the port is already in use
	for try := 0; try < 3; try++ {
		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Stdout = os.Stderr
		cmd.Stderr = os.Stderr

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}

		log.Print("starting llama.cpp server")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama.cpp server: %v", err)
			continue
		}

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama.cpp server: %v", err)
			llm.Close()
			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	return nil, fmt.Errorf("max retries exceeded starting llama.cpp")
}
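
// waitForServer polls the llama.cpp server until it answers a ping, giving up
// after 45 seconds.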
func waitForServer(llm *llama) error {
	// wait for the server to start responding
	start := time.Now()
	expiresAt := time.Now().Add(45 * time.Second)
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama.cpp server to start responding")
	for range ticker.C {
		if time.Now().After(expiresAt) {
			return fmt.Errorf("llama.cpp server did not start within allotted time, retrying")
		}

		if err := llm.Ping(context.Background()); err == nil {
			break
		}
	}

	log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
	return nil
}
func (llm *llama) Close() {
	llm.Cancel()
	if err := llm.Cmd.Wait(); err != nil {
		log.Printf("llama.cpp server exited with error: %v", err)
	}
}
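
// SetOptions replaces the options used for subsequent requests to the running server.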
func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
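
// GenerationSettings describes the sampling and generation parameters reported
// by the llama.cpp server.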
type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}
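
// PredictRequest is the request body sent to the llama.cpp server's
// /completion endpoint.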
type PredictRequest struct {
	Stream           bool            `json:"stream"`
	NPredict         int             `json:"n_predict,omitempty"`
	TopK             int             `json:"top_k,omitempty"`
	TopP             float32         `json:"top_p,omitempty"`
	TfsZ             float32         `json:"tfs_z,omitempty"`
	TypicalP         float32         `json:"typical_p,omitempty"`
	RepeatLastN      int             `json:"repeat_last_n,omitempty"`
	Temperature      float32         `json:"temperature,omitempty"`
	RepeatPenalty    float32         `json:"repeat_penalty,omitempty"`
	PresencePenalty  float32         `json:"presence_penalty,omitempty"`
	FrequencyPenalty float32         `json:"frequency_penalty,omitempty"`
	Mirostat         int             `json:"mirostat,omitempty"`
	MirostatTau      float32         `json:"mirostat_tau,omitempty"`
	MirostatEta      float32         `json:"mirostat_eta,omitempty"`
	PenalizeNl       bool            `json:"penalize_nl,omitempty"`
	NKeep            int             `json:"n_keep,omitempty"`
	Seed             int             `json:"seed,omitempty"`
	Prompt           string          `json:"prompt,omitempty"`
	NProbs           int             `json:"n_probs,omitempty"`
	LogitBias        map[int]float32 `json:"logit_bias,omitempty"`
	IgnoreEos        bool            `json:"ignore_eos,omitempty"`
	Stop             []string        `json:"stop,omitempty"`
}
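
// Predict streams a completion for prompt from the llama.cpp server, invoking
// fn for each generated chunk and once more with Done set, the encoded context,
// and timing information when generation stops.
//
// A minimal usage sketch (the prompt text is illustrative):
//
//	err := llm.Predict(ctx, nil, "Why is the sky blue?", func(r api.GenerateResponse) {
//		fmt.Print(r.Response)
//	})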
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           nextContext.String(),
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Stop:             llm.Stop,
	}

	data, err := json.Marshal(predReq)
	if err != nil {
		return fmt.Errorf("error marshaling data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// Read data from the server-side event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]
				var p Prediction
				if err := json.Unmarshal([]byte(evt), &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.PromptN,
						PromptEvalDuration: parseDurationMs(p.PromptMS),
						EvalCount:          p.PredictedN,
						EvalDuration:       parseDurationMs(p.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}
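
// Encode tokenizes prompt via the llama.cpp server's /tokenize endpoint.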
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}
type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}
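
// Decode converts tokens back into text via the llama.cpp server's /detokenize
// endpoint; an empty token list decodes to the empty string without a request.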
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// the decoded content contains a leading whitespace; trim it
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
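
// Embedding returns the embedding vector for input from the llama.cpp server's
// /embedding endpoint. The server must have been started with --embedding,
// which newLlama always passes.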
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}

	return nil
}