llama.go
package llm

import (
    "bufio"
    "bytes"
    "context"
    "embed"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "io/fs"
    "log"
    "math/rand"
    "net/http"
    "os"
    "os/exec"
    "path"
    "path/filepath"
    "runtime"
    "strconv"
    "strings"
    "time"

    "github.com/jmorganca/ollama/api"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

type ModelRunner struct {
    Path string // path to the model runner executable
}

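// chooseRunners extracts the embedded llama.cpp runner binaries for the
// current OS into workDir and returns them as ModelRunners in priority
// order (e.g. metal before cpu on darwin, cuda before cpu on linux).
// Runners that were not packed into this build are skipped; if none are
// found at all, the process exits.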
func chooseRunners(workDir, runnerType string) []ModelRunner {
    buildPath := path.Join("llama.cpp", runnerType, "build")
    var runners []string

    // set the runners based on the OS
    // IMPORTANT: the order of the runners in the array is the priority order
    switch runtime.GOOS {
    case "darwin":
        runners = []string{
            path.Join(buildPath, "metal", "bin", "ollama-runner"),
            path.Join(buildPath, "cpu", "bin", "ollama-runner"),
        }
    case "linux":
        runners = []string{
            path.Join(buildPath, "cuda", "bin", "ollama-runner"),
            path.Join(buildPath, "cpu", "bin", "ollama-runner"),
        }
    case "windows":
        // TODO: select windows GPU runner here when available
        runners = []string{
            path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe"),
        }
    default:
        log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
        runners = []string{
            path.Join(buildPath, "cpu", "bin", "ollama-runner"),
        }
    }

    runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
    for _, r := range runners {
        // find all the files in the runner's bin directory
        files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r), "*"))
        if err != nil {
            // this is expected, ollama may be compiled without all runners packed in
            log.Printf("%s runner not found: %v", r, err)
            continue
        }

        for _, f := range files {
            runnerAvailable = true

            srcFile, err := llamaCppEmbed.Open(f)
            if err != nil {
                log.Fatalf("read llama runner %s: %v", f, err)
            }
            defer srcFile.Close()

            // create the directory in case it does not exist; filepath.Join converts the path to the OS's format
            destPath := filepath.Join(workDir, filepath.Dir(f))
            if err := os.MkdirAll(destPath, 0o755); err != nil {
                log.Fatalf("create runner temp dir %s: %v", destPath, err)
            }

            // build the path to the destination file
            destFile := filepath.Join(destPath, filepath.Base(f))

            _, err = os.Stat(destFile)
            switch {
            case errors.Is(err, os.ErrNotExist):
                destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
                if err != nil {
                    log.Fatalf("write llama runner %s: %v", f, err)
                }
                defer destFile.Close()

                if _, err := io.Copy(destFile, srcFile); err != nil {
                    log.Fatalf("copy llama runner %s: %v", f, err)
                }
            case err != nil:
                log.Fatalf("stat llama runner %s: %v", f, err)
            }
        }
    }
    if !runnerAvailable {
        log.Fatalf("%s runner not found", runnerType)
    }

    // return the runners to try in priority order
    localRunnersByPriority := []ModelRunner{}
    for _, r := range runners {
        // clean the ModelRunner paths so that they match the OS we are running on
        localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: filepath.Clean(path.Join(workDir, r))})
    }

    return localRunnersByPriority
}

type llamaModel struct {
    hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
    return "llama"
}

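// llamaModelType maps a LLaMA-style layer count to a human-readable
// parameter-size label; unrecognized counts report "unknown".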
func llamaModelType(numLayer uint32) string {
    switch numLayer {
    case 26:
        return "3B"
    case 32:
        return "7B"
    case 40:
        return "13B"
    case 48:
        return "34B"
    case 60:
        return "30B"
    case 80:
        return "65B"
    default:
        return "unknown"
    }
}

func (llm *llamaModel) ModelType() string {
    return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
    return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
    return int64(llm.hyperparameters.NumLayer)
}

type llamaHyperparameters struct {
    // NumVocab is the size of the model's vocabulary.
    NumVocab uint32

    // NumEmbd is the size of the model's embedding layer.
    NumEmbd uint32

    NumMult uint32
    NumHead uint32

    // NumLayer is the number of layers in the model.
    NumLayer uint32
    NumRot   uint32

    // FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
    FileType uint32
}

type Running struct {
    Port   int
    Cmd    *exec.Cmd
    Cancel context.CancelFunc
}

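// llama couples the generation options for a model with the state of its
// running llama.cpp server subprocess.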
type llama struct {
    api.Options
    Running
}

var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the available VRAM in MiB, summed across GPUs, on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
    cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
    var stdout bytes.Buffer
    cmd.Stdout = &stdout
    err := cmd.Run()
    if err != nil {
        return 0, errNoGPU
    }

    var free int64
    scanner := bufio.NewScanner(&stdout)
    for scanner.Scan() {
        line := scanner.Text()
        vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
        if err != nil {
            return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
        }
        free += vram
    }

    return free, nil
}

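// NumGPU decides how many model layers to offload to the GPU. An explicit
// opts.NumGPU wins; otherwise, on Linux, it estimates bytes per layer as
// fileSizeBytes/numLayer and fills 95% of the free VRAM reported by
// nvidia-smi. On other platforms it returns 1, which enables Metal on macOS.
//
// Illustrative numbers, not taken from the source: a ~3.8 GB 7B model with
// 32 layers works out to roughly 120 MB per layer, so about 8 GB of free
// VRAM allows on the order of 60-70 layers after the 5% reserve; llama.cpp
// treats a value above the model's layer count as "offload everything".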
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
    if opts.NumGPU != -1 {
        return opts.NumGPU
    }

    if runtime.GOOS == "linux" {
        vramMib, err := CheckVRAM()
        if err != nil {
            if !errors.Is(err, errNoGPU) {
                log.Print(err.Error())
            }
            // nvidia driver not installed or no nvidia GPU found
            return 0
        }

        freeVramBytes := vramMib * 1024 * 1024 // 1 MiB = 1024^2 bytes

        // calculate bytes per layer
        // TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
        bytesPerLayer := fileSizeBytes / numLayer

        // max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory
        layers := int(freeVramBytes/bytesPerLayer) * 95 / 100
        log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)

        return layers
    }

    // default to enable metal on macOS
    return 1
}

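// newLlama launches a llama.cpp server subprocess for the given model. It
// walks the supplied runners in priority order, starting each one on a
// random ephemeral port, and returns once a runner answers a ping; runners
// that fail to start or to respond are skipped.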
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
    fileInfo, err := os.Stat(model)
    if err != nil {
        return nil, err
    }

    if len(adapters) > 1 {
        return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
    }

    params := []string{
        "--model", model,
        "--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
        "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
        "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
        "--batch-size", fmt.Sprintf("%d", opts.NumBatch),
        "--n-gpu-layers", fmt.Sprintf("%d", NumGPU(numLayers, fileInfo.Size(), opts)),
        "--embedding",
    }

    if opts.NumGQA > 0 {
        params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
    }

    if len(adapters) > 0 {
        // TODO: applying multiple adapters is not supported by the llama.cpp server yet
        params = append(params, "--lora", adapters[0])
    }

    if opts.NumThread > 0 {
        params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
    }

    if !opts.F16KV {
        params = append(params, "--memory-f32")
    }
    if opts.UseMLock {
        params = append(params, "--mlock")
    }
    if !opts.UseMMap {
        params = append(params, "--no-mmap")
    }
    if opts.UseNUMA {
        params = append(params, "--numa")
    }

    // start the llama.cpp server with a retry in case the port is already in use
    for _, runner := range runners {
        if _, err := os.Stat(runner.Path); err != nil {
            log.Printf("llama runner not found: %v", err)
            continue
        }

        port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
        ctx, cancel := context.WithCancel(context.Background())
        cmd := exec.CommandContext(
            ctx,
            runner.Path,
            append(params, "--port", strconv.Itoa(port))...,
        )
        cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
        cmd.Stdout = os.Stderr
        cmd.Stderr = os.Stderr

        llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}

        log.Print("starting llama runner")
        if err := llm.Cmd.Start(); err != nil {
            log.Printf("error starting the external llama runner: %v", err)
            continue
        }

        // monitor the command, it is blocking, so if it exits we need to capture that
        go func() {
            err := llm.Cmd.Wait() // this will block until the command exits
            if err != nil {
                log.Printf("llama runner exited with error: %v", err)
            } else {
                log.Printf("llama runner exited")
            }
        }()

        if err := waitForServer(llm); err != nil {
            log.Printf("error starting llama runner: %v", err)
            llm.Close()

            // try again
            continue
        }

        // server started successfully
        return llm, nil
    }

    return nil, fmt.Errorf("failed to start a llama runner")
}

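// waitForServer polls the freshly started runner every 200ms until it
// answers a ping, the subprocess exits, or a two-minute deadline passes.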
func waitForServer(llm *llama) error {
    // wait for the server to start responding
    start := time.Now()
    expiresAt := time.Now().Add(2 * time.Minute) // be generous with timeout, large models can take a while to load
    ticker := time.NewTicker(200 * time.Millisecond)
    defer ticker.Stop()

    log.Print("waiting for llama runner to start responding")
    for range ticker.C {
        if time.Now().After(expiresAt) {
            return fmt.Errorf("llama runner did not start within allotted time, retrying")
        }

        // check if the server process has terminated
        if llm.Cmd.ProcessState != nil && llm.Cmd.ProcessState.Exited() {
            return fmt.Errorf("llama runner process has terminated")
        }

        if err := llm.Ping(context.Background()); err == nil {
            break
        }
    }

    log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
    return nil
}

func (llm *llama) Close() {
    // signal the sub-process to terminate
    llm.Cancel()

    // wait for the command to exit to prevent race conditions with the next run
    if err := llm.Cmd.Wait(); err != nil {
        log.Printf("llama runner exited: %v", err)
    }
}

func (llm *llama) SetOptions(opts api.Options) {
    llm.Options = opts
}

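// GenerationSettings mirrors, as far as this code assumes, the generation
// settings object reported by the llama.cpp server's completion API.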
type GenerationSettings struct {
    FrequencyPenalty float64       `json:"frequency_penalty"`
    IgnoreEOS        bool          `json:"ignore_eos"`
    LogitBias        []interface{} `json:"logit_bias"`
    Mirostat         int           `json:"mirostat"`
    MirostatEta      float64       `json:"mirostat_eta"`
    MirostatTau      float64       `json:"mirostat_tau"`
    Model            string        `json:"model"`
    NCtx             int           `json:"n_ctx"`
    NKeep            int           `json:"n_keep"`
    NPredict         int           `json:"n_predict"`
    NProbs           int           `json:"n_probs"`
    PenalizeNl       bool          `json:"penalize_nl"`
    PresencePenalty  float64       `json:"presence_penalty"`
    RepeatLastN      int           `json:"repeat_last_n"`
    RepeatPenalty    float64       `json:"repeat_penalty"`
    Seed             uint32        `json:"seed"`
    Stop             []string      `json:"stop"`
    Stream           bool          `json:"stream"`
    Temp             float64       `json:"temp"`
    TfsZ             float64       `json:"tfs_z"`
    TopK             int           `json:"top_k"`
    TopP             float64       `json:"top_p"`
    TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
    PredictedN  int     `json:"predicted_n"`
    PredictedMS float64 `json:"predicted_ms"`
    PromptN     int     `json:"prompt_n"`
    PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
    Content string `json:"content"`
    Model   string `json:"model"`
    Prompt  string `json:"prompt"`
    Stop    bool   `json:"stop"`

    Timings `json:"timings"`
}

type PredictRequest struct {
    Prompt           string   `json:"prompt"`
    Stream           bool     `json:"stream"`
    NPredict         int      `json:"n_predict"`
    NKeep            int      `json:"n_keep"`
    Temperature      float32  `json:"temperature"`
    TopK             int      `json:"top_k"`
    TopP             float32  `json:"top_p"`
    TfsZ             float32  `json:"tfs_z"`
    TypicalP         float32  `json:"typical_p"`
    RepeatLastN      int      `json:"repeat_last_n"`
    RepeatPenalty    float32  `json:"repeat_penalty"`
    PresencePenalty  float32  `json:"presence_penalty"`
    FrequencyPenalty float32  `json:"frequency_penalty"`
    Mirostat         int      `json:"mirostat"`
    MirostatTau      float32  `json:"mirostat_tau"`
    MirostatEta      float32  `json:"mirostat_eta"`
    PenalizeNl       bool     `json:"penalize_nl"`
    Seed             int      `json:"seed"`
    Stop             []string `json:"stop,omitempty"`
}

const maxBufferSize = 512 * 1024 // 512KB

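// Predict streams a completion from the runner's /completion endpoint. The
// previous context is decoded back into text, concatenated with the new
// prompt, and the accumulated conversation is re-encoded into the context
// returned with the final response. Tokens arrive as server-sent events
// prefixed with "data: ", and fn is invoked once per content chunk.
//
// A minimal (hypothetical) caller might look like:
//
//    err := llm.Predict(ctx, prevContext, prompt, func(r api.GenerateResponse) {
//        fmt.Print(r.Response)
//    })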
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
    prevConvo, err := llm.Decode(ctx, prevContext)
    if err != nil {
        return err
    }

    var nextContext strings.Builder
    nextContext.WriteString(prevConvo)
    nextContext.WriteString(prompt)

    endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
    predReq := PredictRequest{
        Prompt:           nextContext.String(),
        Stream:           true,
        NPredict:         llm.NumPredict,
        NKeep:            llm.NumKeep,
        Temperature:      llm.Temperature,
        TopK:             llm.TopK,
        TopP:             llm.TopP,
        TfsZ:             llm.TFSZ,
        TypicalP:         llm.TypicalP,
        RepeatLastN:      llm.RepeatLastN,
        RepeatPenalty:    llm.RepeatPenalty,
        PresencePenalty:  llm.PresencePenalty,
        FrequencyPenalty: llm.FrequencyPenalty,
        Mirostat:         llm.Mirostat,
        MirostatTau:      llm.MirostatTau,
        MirostatEta:      llm.MirostatEta,
        PenalizeNl:       llm.PenalizeNewline,
        Seed:             llm.Seed,
        Stop:             llm.Stop,
    }

    data, err := json.Marshal(predReq)
    if err != nil {
        return fmt.Errorf("error marshaling data: %v", err)
    }

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
    if err != nil {
        return fmt.Errorf("error creating POST request: %v", err)
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return fmt.Errorf("POST predict: %v", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode >= 400 {
        bodyBytes, err := io.ReadAll(resp.Body)
        if err != nil {
            return fmt.Errorf("failed reading llm error response: %w", err)
        }
        log.Printf("llm predict error: %s", bodyBytes)
        return fmt.Errorf("%s", bodyBytes)
    }

    scanner := bufio.NewScanner(resp.Body)
    // increase the buffer size to avoid running out of space
    buf := make([]byte, 0, maxBufferSize)
    scanner.Buffer(buf, maxBufferSize)
    for scanner.Scan() {
        select {
        case <-ctx.Done():
            // this handles the request cancellation
            return ctx.Err()
        default:
            line := scanner.Text()
            if line == "" {
                continue
            }

            // read data from the server-sent event stream
            if strings.HasPrefix(line, "data: ") {
                evt := line[6:]
                var p Prediction
                if err := json.Unmarshal([]byte(evt), &p); err != nil {
                    return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
                }

                if p.Content != "" {
                    fn(api.GenerateResponse{Response: p.Content})
                    nextContext.WriteString(p.Content)
                }

                if p.Stop {
                    embd, err := llm.Encode(ctx, nextContext.String())
                    if err != nil {
                        return fmt.Errorf("encoding context: %v", err)
                    }

                    fn(api.GenerateResponse{
                        Done:               true,
                        Context:            embd,
                        PromptEvalCount:    p.PromptN,
                        PromptEvalDuration: parseDurationMs(p.PromptMS),
                        EvalCount:          p.PredictedN,
                        EvalDuration:       parseDurationMs(p.PredictedMS),
                    })

                    return nil
                }
            }
        }
    }

    if err := scanner.Err(); err != nil {
        return fmt.Errorf("error reading llm response: %v", err)
    }

    return nil
}

type TokenizeRequest struct {
    Content string `json:"content"`
}

type TokenizeResponse struct {
    Tokens []int `json:"tokens"`
}

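// Encode tokenizes prompt by POSTing it to the runner's /tokenize endpoint
// and returns the resulting token ids.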
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
    endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
    data, err := json.Marshal(TokenizeRequest{Content: prompt})
    if err != nil {
        return nil, fmt.Errorf("marshaling encode data: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
    if err != nil {
        return nil, fmt.Errorf("encode request: %w", err)
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("do encode request: %w", err)
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("read encode response: %w", err)
    }

    if resp.StatusCode >= 400 {
        log.Printf("llm encode error: %s", body)
        return nil, fmt.Errorf("%s", body)
    }

    var encoded TokenizeResponse
    if err := json.Unmarshal(body, &encoded); err != nil {
        return nil, fmt.Errorf("unmarshal encode response: %w", err)
    }

    return encoded.Tokens, nil
}

type DetokenizeRequest struct {
    Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
    Content string `json:"content"`
}

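// Decode converts previously returned token ids back into text via the
// runner's /detokenize endpoint; an empty token slice yields an empty string.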
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
    if len(tokens) == 0 {
        return "", nil
    }

    endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
    data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
    if err != nil {
        return "", fmt.Errorf("marshaling decode data: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
    if err != nil {
        return "", fmt.Errorf("decode request: %w", err)
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return "", fmt.Errorf("do decode request: %w", err)
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return "", fmt.Errorf("read decode response: %w", err)
    }

    if resp.StatusCode >= 400 {
        log.Printf("llm decode error: %s", body)
        return "", fmt.Errorf("%s", body)
    }

    var decoded DetokenizeResponse
    if err := json.Unmarshal(body, &decoded); err != nil {
        return "", fmt.Errorf("unmarshal decode response: %w", err)
    }

    // decoded content contains a leading whitespace; strip it
    decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

    return decoded.Content, nil
}

type EmbeddingRequest struct {
    Content string `json:"content"`
}

type EmbeddingResponse struct {
    Embedding []float64 `json:"embedding"`
}

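// Embedding returns the embedding vector for input by POSTing it to the
// runner's /embedding endpoint (the server must be started with --embedding,
// as newLlama does).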
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
    endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
    data, err := json.Marshal(EmbeddingRequest{Content: input})
    if err != nil {
        return nil, fmt.Errorf("error marshaling embed data: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
    if err != nil {
        return nil, fmt.Errorf("error creating embed request: %w", err)
    }
    req.Header.Set("Content-Type", "application/json")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("POST embedding: %w", err)
    }
    defer resp.Body.Close()

    body, err := io.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("error reading embed response: %w", err)
    }

    if resp.StatusCode >= 400 {
        log.Printf("llm embedding error: %s", body)
        return nil, fmt.Errorf("%s", body)
    }

    var embedding EmbeddingResponse
    if err := json.Unmarshal(body, &embedding); err != nil {
        return nil, fmt.Errorf("unmarshal embedding response: %w", err)
    }

    return embedding.Embedding, nil
}

// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
    req, err := http.NewRequestWithContext(ctx, http.MethodHead, fmt.Sprintf("http://127.0.0.1:%d", llm.Port), nil)
    if err != nil {
        return fmt.Errorf("ping request: %w", err)
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return fmt.Errorf("ping resp: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("unexpected ping status: %s", resp.Status)
    }

    return nil
}