llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)
//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

type ModelRunner struct {
	Path        string // path to the model runner executable
	Accelerated bool
}
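// chooseRunners extracts the embedded llama.cpp runner binaries for the
// current OS into workDir and returns the candidates to try, in priority
// order (accelerated runners first, CPU last).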
func chooseRunners(workDir, runnerType string) []ModelRunner {
	buildPath := path.Join("llama.cpp", runnerType, "build")
	var runners []ModelRunner

	// set the runners based on the OS
	// IMPORTANT: the order of the runners in the array is the priority order
	switch runtime.GOOS {
	case "darwin":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "linux":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "windows":
		// TODO: select windows GPU runner here when available
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
		}
	default:
		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	}

	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
	for _, r := range runners {
		// find all the files in the runner's bin directory
		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
		if err != nil {
			// this is expected, ollama may be compiled without all runners packed in
			log.Printf("%s runner not found: %v", r.Path, err)
			continue
		}

		for _, f := range files {
			runnerAvailable = true

			srcFile, err := llamaCppEmbed.Open(f)
			if err != nil {
				log.Fatalf("read llama runner %s: %v", f, err)
			}
			defer srcFile.Close()

			// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
			destPath := filepath.Join(workDir, filepath.Dir(f))
			if err := os.MkdirAll(destPath, 0o755); err != nil {
				log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
			}

			// create the path to the destination file, filepath.Base() converts the file path to the OS's format
			destFile := filepath.Join(destPath, filepath.Base(f))

			_, err = os.Stat(destFile)
			switch {
			case errors.Is(err, os.ErrNotExist):
				destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
				if err != nil {
					log.Fatalf("write llama runner %s: %v", f, err)
				}
				defer destFile.Close()

				if _, err := io.Copy(destFile, srcFile); err != nil {
					log.Fatalf("copy llama runner %s: %v", f, err)
				}
			case err != nil:
				log.Fatalf("stat llama runner %s: %v", f, err)
			}
		}
	}

	if !runnerAvailable {
		log.Fatalf("%s runner not found", runnerType)
	}

	// return the runners to try in priority order
	localRunnersByPriority := []ModelRunner{}
	for _, r := range runners {
		// clean the ModelRunner paths so that they match the OS we are running on
		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
			Path:        filepath.Clean(path.Join(workDir, r.Path)),
			Accelerated: r.Accelerated,
		})
	}

	return localRunnersByPriority
}
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}
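// llamaModelType maps a llama model's layer count to its approximate
// parameter-size label.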
func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
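// Running tracks the state of a launched llama runner subprocess.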
type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc

	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess

	*StatusWriter // captures error messages from the llama runner process
}

type llama struct {
	api.Options
	Running
}

var errNoGPU = errors.New("nvidia-smi command failed")
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var freeMiB int64
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		freeMiB += vram
	}

	freeBytes := freeMiB * 1024 * 1024
	if freeBytes < 2*format.GigaByte {
		log.Printf("less than 2 GB VRAM available, falling back to CPU only")
		freeBytes = 0
	}

	return freeBytes, nil
}
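// NumGPU decides how many model layers to offload to the GPU. An explicit
// opts.NumGPU (anything other than -1) is returned as-is; on Linux the count
// is estimated from free VRAM, and on other platforms it defaults to 1, which
// enables Metal on macOS.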
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}

	if runtime.GOOS == "linux" {
		freeBytes, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNoGPU) {
				log.Print(err.Error())
			}
			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		// Calculate bytes per layer
		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
		bytesPerLayer := fileSizeBytes / numLayer

		// max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
		layers := int(freeBytes/bytesPerLayer) * 92 / 100
		log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)

		return layers
	}

	// default to enable metal on macOS
	return 1
}
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh      chan error
	LastErrMsg string
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
	}

	return os.Stderr.Write(b)
}
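// newLlama assembles the llama.cpp server arguments from the model, adapters,
// and options, then tries each candidate runner in priority order until one
// starts and responds to requests.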
func newLlama(model string, adapters []string, runners []ModelRunner, ggml *GGML, opts api.Options) (*llama, error) {
	fileInfo, err := os.Stat(model)
	if err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	numGPU := NumGPU(ggml.NumLayers(), fileInfo.Size(), opts)
	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
		"--embedding",
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	var runnerErr error

	// start the llama.cpp server with a retry in case the port is already in use
	for _, runner := range runners {
		if runner.Accelerated && numGPU == 0 {
			log.Printf("skipping accelerated runner because num_gpu=0")
			continue
		}

		if _, err := os.Stat(runner.Path); err != nil {
			log.Printf("llama runner not found: %v", err)
			continue
		}

		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
		cmd.Stdout = os.Stderr
		statusWriter := NewStatusWriter()
		cmd.Stderr = statusWriter

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error), StatusWriter: statusWriter}}

		log.Print("starting llama runner")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama runner: %v", err)
			continue
		}

		// monitor the llama runner process and signal when it exits
		go func() {
			err := llm.Cmd.Wait()

			// default to printing the exit message of the command process, it will probably just say 'exit status 1'
			errMsg := "llama runner exited"
			if err != nil {
				errMsg = err.Error()
			}

			// try to set a better error message if llama runner logs captured an error
			if statusWriter.LastErrMsg != "" {
				errMsg = statusWriter.LastErrMsg
			}
			log.Println(errMsg)

			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
			llm.exitOnce.Do(func() {
				close(llm.exitCh)
			})
		}()

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama runner: %v", err)
			llm.Close()

			// default the runnerErr to the error returned by the most recent llama runner process
			runnerErr = err

			// capture the error directly from the runner process, if any
			select {
			case runnerErr = <-statusWriter.ErrCh:
			default:
				// the runner process probably timed out
			}

			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	if runnerErr != nil {
		// this is the error returned from the llama runner process that failed most recently
		// falcon and starcoder model families are not compatible with older versions of llama.cpp
		families := []string{"falcon", "starcoder"}
		if strings.Contains(runnerErr.Error(), "failed to load model") && slices.Contains(families, ggml.ModelFamily()) {
			return nil, fmt.Errorf("%v: %s", runnerErr, "this model may be incompatible with your version of Ollama. Please run `ollama pull` to get the latest version of this model.")
		}
		return nil, runnerErr
	}

	return nil, fmt.Errorf("failed to start a llama runner")
}
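// waitForServer polls the llama runner until it responds to a ping, the
// subprocess exits, or a three-minute timeout elapses.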
func waitForServer(llm *llama) error {
	start := time.Now()
	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama runner to start responding")
	for {
		select {
		case <-llm.exitCh:
			// failed to start subprocess
			return fmt.Errorf("llama runner process has terminated")
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				// timeout
				return fmt.Errorf("timed out waiting for llama runner to start")
			}

			if err := llm.Ping(context.Background()); err == nil {
				// success
				log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		}
	}
}
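// Close terminates the llama runner subprocess and waits for it to exit so the
// next run does not race against a dying process.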
func (llm *llama) Close() {
	// signal the sub-process to terminate
	llm.Cancel()

	// wait for the command to exit to prevent race conditions with the next run
	<-llm.exitCh

	if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
		log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
	} else {
		log.Print("llama runner stopped successfully")
	}
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
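// prediction is the subset of the llama.cpp server's streaming /completion
// response that Predict consumes.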
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}

const maxBufferSize = 512 * format.KiloByte
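// Predict streams a completion for prompt (prefixed with the decoded
// prevContext) from the llama runner, calling fn once per generated chunk and
// a final time with Done set, the updated context, and timing statistics.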
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	// Remove leading spaces from prevConvo if present
	prevConvo = strings.TrimPrefix(prevConvo, " ")

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	request := map[string]any{
		"prompt":            nextContext.String(),
		"stream":            true,
		"n_predict":         llm.NumPredict,
		"n_keep":            llm.NumKeep,
		"temperature":       llm.Temperature,
		"top_k":             llm.TopK,
		"top_p":             llm.TopP,
		"tfs_z":             llm.TFSZ,
		"typical_p":         llm.TypicalP,
		"repeat_last_n":     llm.RepeatLastN,
		"repeat_penalty":    llm.RepeatPenalty,
		"presence_penalty":  llm.PresencePenalty,
		"frequency_penalty": llm.FrequencyPenalty,
		"mirostat":          llm.Mirostat,
		"mirostat_tau":      llm.MirostatTau,
		"mirostat_eta":      llm.MirostatEta,
		"penalize_nl":       llm.PenalizeNewline,
		"seed":              llm.Seed,
		"stop":              llm.Stop,
	}

	// Handling JSON marshaling with special characters unescaped.
	buffer := &bytes.Buffer{}
	enc := json.NewEncoder(buffer)
	enc.SetEscapeHTML(false)

	if err := enc.Encode(request); err != nil {
		return fmt.Errorf("failed to marshal data: %v", err)
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	// increase the buffer size to avoid running out of space
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Bytes()
			if len(line) == 0 {
				continue
			}

			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
				var p prediction
				if err := json.Unmarshal(evt, &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.Timings.PromptN,
						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
						EvalCount:          p.Timings.PredictedN,
						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		if strings.Contains(err.Error(), "unexpected EOF") {
			// this means the llama runner subprocess crashed
			llm.Close()
			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
			}
			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
		}
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}
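// Encode tokenizes prompt by POSTing it to the llama runner's /tokenize
// endpoint.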
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}
type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}
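// Decode converts tokens back into text via the llama runner's /detokenize
// endpoint. An empty token slice decodes to the empty string without a request.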
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	return decoded.Content, nil
}
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
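// Embedding returns the embedding vector for input from the llama runner's
// /embedding endpoint. The runner must have been started with --embedding,
// which newLlama always passes.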
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}
	return nil
}