llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

type ModelRunner struct {
	Path string // path to the model runner executable
}
func chooseRunners(workDir, runnerType string) []ModelRunner {
	buildPath := path.Join("llama.cpp", runnerType, "build")
	var runners []string

	// set the runners based on the OS
	// IMPORTANT: the order of the runners in the array is the priority order
	switch runtime.GOOS {
	case "darwin":
		runners = []string{
			path.Join(buildPath, "metal", "bin", "ollama-runner"),
			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
		}
	case "linux":
		runners = []string{
			path.Join(buildPath, "cuda", "bin", "ollama-runner"),
			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
		}
	case "windows":
		// TODO: select windows GPU runner here when available
		runners = []string{
			path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe"),
		}
	default:
		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
		runners = []string{
			path.Join(buildPath, "cpu", "bin", "ollama-runner"),
		}
	}

	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
	for _, r := range runners {
		// find all the files in the runner's bin directory
		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r), "*"))
		if err != nil {
			// this is expected, ollama may be compiled without all runners packed in
			log.Printf("%s runner not found: %v", r, err)
			continue
		}

		for _, f := range files {
			runnerAvailable = true

			srcFile, err := llamaCppEmbed.Open(f)
			if err != nil {
				log.Fatalf("read llama runner %s: %v", f, err)
			}
			defer srcFile.Close()

			// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
			destPath := filepath.Join(workDir, filepath.Dir(f))
			if err := os.MkdirAll(destPath, 0o755); err != nil {
				log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
			}

			// create the path to the destination file, filepath.Base() converts the file path to the OS's format
			destFile := filepath.Join(destPath, filepath.Base(f))

			_, err = os.Stat(destFile)
			switch {
			case errors.Is(err, os.ErrNotExist):
				destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
				if err != nil {
					log.Fatalf("write llama runner %s: %v", f, err)
				}
				defer destFile.Close()

				if _, err := io.Copy(destFile, srcFile); err != nil {
					log.Fatalf("copy llama runner %s: %v", f, err)
				}
			case err != nil:
				log.Fatalf("stat llama runner %s: %v", f, err)
			}
		}
	}
	if !runnerAvailable {
		log.Fatalf("%s runner not found", runnerType)
	}

	// return the runners to try in priority order
	localRunnersByPriority := []ModelRunner{}
	for _, r := range runners {
		// clean the ModelRunner paths so that they match the OS we are running on
		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: filepath.Clean(path.Join(workDir, r))})
	}

	return localRunnersByPriority
}
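
// Illustrative sketch (not part of the original file): shows how an embedded
// runner path maps onto the on-disk path that chooseRunners extracts it to.
// The runner type "ggml" and the temp-based workDir are assumptions made up
// for this example.
func exampleRunnerDestPath() string {
	embedded := path.Join("llama.cpp", "ggml", "build", "cpu", "bin", "ollama-runner")
	workDir := filepath.Join(os.TempDir(), "ollama-example")

	// chooseRunners copies every embedded file into workDir while preserving
	// its directory structure, so the extracted binary ends up here
	return filepath.Join(workDir, filepath.Dir(embedded), filepath.Base(embedded))
}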
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}

func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32
	// NumEmbd is the size of the model's embedding layer.
	NumEmbd  uint32
	NumMult  uint32
	NumHead  uint32
	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32
	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc

	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess
	exitErr  error      // error returned by the subprocess
}

type llama struct {
	api.Options
	Running
}
var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var free int64
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		free += vram
	}

	return free, nil
}
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	if runtime.GOOS == "linux" {
		vramMib, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNoGPU) {
				log.Print(err.Error())
			}
			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		freeVramBytes := vramMib * 1024 * 1024 // 1 MiB = 1024^2 bytes

		// Calculate bytes per layer
		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
		bytesPerLayer := fileSizeBytes / numLayer

		// max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory
		layers := int(freeVramBytes/bytesPerLayer) * 95 / 100
		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)

		return layers
	}
	// default to enable metal on macOS
	return 1
}
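
// Illustrative sketch (not part of the original file): walks through the
// NumGPU heuristic with made-up numbers, a hypothetical 13 GiB model file
// split across 40 layers and 8 GiB of free VRAM, to show how the 95%
// headroom cap works out (22 layers for these values).
func exampleNumGPULayers() int {
	var (
		fileSizeBytes int64 = 13 * 1024 * 1024 * 1024 // assumed model size on disk
		numLayer      int64 = 40                      // assumed layer count
		freeVramBytes int64 = 8 * 1024 * 1024 * 1024  // assumed free VRAM
	)

	// same arithmetic as NumGPU: estimate bytes per layer, then keep roughly
	// 5% of VRAM free as headroom
	bytesPerLayer := fileSizeBytes / numLayer
	return int(freeVramBytes/bytesPerLayer) * 95 / 100
}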
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh chan error
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

func (w *StatusWriter) Write(b []byte) (int, error) {
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		err := fmt.Errorf("llama runner: %s", after)
		w.ErrCh <- err
	}

	return os.Stderr.Write(b)
}
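
// Illustrative sketch (not part of the original file): shows how a line of
// runner stderr containing "error:" is surfaced on ErrCh while still being
// forwarded to os.Stderr. The message text is made up.
func exampleStatusWriterCapture() error {
	w := NewStatusWriter()

	// stands in for a single line of subprocess stderr output
	_, _ = w.Write([]byte("error: failed to load model\n"))

	select {
	case err := <-w.ErrCh:
		// err wraps everything after "error:" as "llama runner: ..."
		return err
	default:
		return nil
	}
}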
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
	fileInfo, err := os.Stat(model)
	if err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", NumGPU(numLayers, fileInfo.Size(), opts)),
		"--embedding",
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	var runnerErr error

	// start the llama.cpp server with a retry in case the port is already in use
	for _, runner := range runners {
		if _, err := os.Stat(runner.Path); err != nil {
			log.Printf("llama runner not found: %v", err)
			continue
		}

		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
		cmd.Stdout = os.Stderr
		statusWriter := NewStatusWriter()
		cmd.Stderr = statusWriter

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error)}}

		log.Print("starting llama runner")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama runner: %v", err)
			continue
		}

		// monitor the llama runner process and signal when it exits
		go func() {
			err := llm.Cmd.Wait()
			llm.exitErr = err
			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
			llm.exitOnce.Do(func() {
				close(llm.exitCh)
			})
		}()

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama runner: %v", err)
			llm.Close()

			// default the runnerErr to the error returned by the most recent llama runner process
			runnerErr = err

			// capture the error directly from the runner process, if any
			select {
			case runnerErr = <-statusWriter.ErrCh:
			default:
				// the runner process probably timed out
			}

			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	if runnerErr != nil {
		// this is the error returned from the llama runner process that failed most recently
		return nil, runnerErr
	}

	return nil, fmt.Errorf("failed to start a llama runner")
}
func waitForServer(llm *llama) error {
	start := time.Now()
	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama runner to start responding")
	for {
		select {
		case <-llm.exitCh:
			// failed to start subprocess
			return fmt.Errorf("llama runner process has terminated")
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				// timeout
				return fmt.Errorf("timed out waiting for llama runner to start")
			}

			if err := llm.Ping(context.Background()); err == nil {
				// success
				log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		}
	}
}
func (llm *llama) Close() {
	// signal the sub-process to terminate
	llm.Cancel()

	// wait for the command to exit to prevent race conditions with the next run
	<-llm.exitCh
	err := llm.exitErr

	if err != nil {
		log.Printf("llama runner stopped with error: %v", err)
	} else {
		log.Print("llama runner stopped successfully")
	}
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}

type PredictRequest struct {
	Prompt           string   `json:"prompt"`
	Stream           bool     `json:"stream"`
	NPredict         int      `json:"n_predict"`
	NKeep            int      `json:"n_keep"`
	Temperature      float32  `json:"temperature"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	TfsZ             float32  `json:"tfs_z"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	PenalizeNl       bool     `json:"penalize_nl"`
	Seed             int      `json:"seed"`
	Stop             []string `json:"stop,omitempty"`
}

const maxBufferSize = 512 * 1000 // 512KB
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           nextContext.String(),
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Seed:             llm.Seed,
		Stop:             llm.Stop,
	}

	data, err := json.Marshal(predReq)
	if err != nil {
		return fmt.Errorf("error marshaling data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	// increase the buffer size to avoid running out of space
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// Read data from the server-side event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]

				var p Prediction
				if err := json.Unmarshal([]byte(evt), &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.PromptN,
						PromptEvalDuration: parseDurationMs(p.PromptMS),
						EvalCount:          p.PredictedN,
						EvalDuration:       parseDurationMs(p.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}
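
// Illustrative sketch (not part of the original file): the /completion
// endpoint streams server-sent events, one JSON object per "data: " line.
// This shows, with a made-up payload, how a single final event maps onto the
// Prediction struct that Predict unmarshals above.
func exampleParsePredictionEvent() (Prediction, error) {
	line := `data: {"content":"Hello","stop":true,"timings":{"prompt_n":8,"prompt_ms":40.0,"predicted_n":1,"predicted_ms":12.5}}`

	var p Prediction
	err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &p)
	return p, err
}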
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}
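
// Illustrative sketch (not part of the original file): the request and
// response bodies exchanged with the runner's /tokenize endpoint, using a
// made-up prompt and made-up token IDs.
func exampleTokenizePayloads() (string, string, error) {
	// what Encode sends: {"content":"Why is the sky blue?"}
	reqBody, err := json.Marshal(TokenizeRequest{Content: "Why is the sky blue?"})
	if err != nil {
		return "", "", err
	}

	// a response of this shape is what Encode unmarshals into TokenizeResponse
	respBody, err := json.Marshal(TokenizeResponse{Tokens: []int{1, 3750, 338, 278, 14744, 7254, 29973}})
	return string(reqBody), string(respBody), err
}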
type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// decoded content contains a leading whitespace
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}
	return nil
}