package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS
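
// ModelRunner identifies an ollama-runner executable and whether it is GPU accelerated.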
type ModelRunner struct {
	Path        string // path to the model runner executable
	Accelerated bool
}
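
// chooseRunners copies the embedded llama.cpp runner binaries for the given
// runner type and the current OS into workDir and returns them in priority order.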
func chooseRunners(workDir, runnerType string) []ModelRunner {
	buildPath := path.Join("llama.cpp", runnerType, "build")
	var runners []ModelRunner

	// set the runners based on the OS
	// IMPORTANT: the order of the runners in the array is the priority order
	switch runtime.GOOS {
	case "darwin":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "linux":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "windows":
		// TODO: select windows GPU runner here when available
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
		}
	default:
		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	}

	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
	for _, r := range runners {
		// find all the files in the runner's bin directory
		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
		if err != nil {
			// this is expected, ollama may be compiled without all runners packed in
			log.Printf("%s runner not found: %v", r.Path, err)
			continue
		}

		for _, f := range files {
			runnerAvailable = true

			srcFile, err := llamaCppEmbed.Open(f)
			if err != nil {
				log.Fatalf("read llama runner %s: %v", f, err)
			}
			defer srcFile.Close()

			// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
			destPath := filepath.Join(workDir, filepath.Dir(f))
			if err := os.MkdirAll(destPath, 0o755); err != nil {
				log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
			}

			// create the path to the destination file, filepath.Base() converts the file path to the OS's format
			destFile := filepath.Join(destPath, filepath.Base(f))

			_, err = os.Stat(destFile)
			switch {
			case errors.Is(err, os.ErrNotExist):
				destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
				if err != nil {
					log.Fatalf("write llama runner %s: %v", f, err)
				}
				defer destFile.Close()

				if _, err := io.Copy(destFile, srcFile); err != nil {
					log.Fatalf("copy llama runner %s: %v", f, err)
				}
			case err != nil:
				log.Fatalf("stat llama runner %s: %v", f, err)
			}
		}
	}
	if !runnerAvailable {
		log.Fatalf("%s runner not found", runnerType)
	}

	// return the runners to try in priority order
	localRunnersByPriority := []ModelRunner{}
	for _, r := range runners {
		// clean the ModelRunner paths so that they match the OS we are running on
		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
			Path:        filepath.Clean(path.Join(workDir, r.Path)),
			Accelerated: r.Accelerated,
		})
	}

	return localRunnersByPriority
}
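
// llamaModel reports metadata for a llama-family model based on its hyperparameters.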
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}
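
// llamaModelType maps a layer count to a human-readable parameter-size label.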
func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}
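
// llamaHyperparameters holds the hyperparameters that describe a llama model.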
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32
	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
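
// Running tracks a launched llama runner subprocess and its exit state.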
type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc

	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess
	exitErr  error      // error returned by the subprocess
}
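
// llama is an LLM backed by a running llama.cpp server subprocess.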
type llama struct {
	api.Options
	Running
}

var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var freeMiB int64
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		freeMiB += vram
	}

	freeBytes := freeMiB * 1024 * 1024
	if freeBytes < 2*format.GigaByte {
		log.Printf("less than 2 GB VRAM available, falling back to CPU only")
		freeBytes = 0
	}

	return freeBytes, nil
}
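
// NumGPU returns the number of layers to offload to the GPU. An explicit
// opts.NumGPU is used as-is; otherwise the count is estimated from the free
// VRAM on Linux, and defaults to 1 elsewhere to enable Metal on macOS.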
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	if runtime.GOOS == "linux" {
		freeBytes, err := CheckVRAM()
		if err != nil {
			if !errors.Is(err, errNoGPU) {
				log.Print(err.Error())
			}
			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		// Calculate bytes per layer
		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
		bytesPerLayer := fileSizeBytes / numLayer

		// max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
		layers := int(freeBytes/bytesPerLayer) * 92 / 100
		log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)

		return layers
	}

	// default to enable metal on macOS
	return 1
}

// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh chan error
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}
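
// Write forwards runner output to stderr and, when it contains an "error:"
// marker, also publishes the message on ErrCh so the caller can report it.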
func (w *StatusWriter) Write(b []byte) (int, error) {
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		w.ErrCh <- fmt.Errorf("llama runner: %s", bytes.TrimSpace(after))
	}
	return os.Stderr.Write(b)
}
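
// newLlama builds the llama.cpp server arguments from opts, then starts each
// candidate runner in priority order until one launches and responds.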
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
	fileInfo, err := os.Stat(model)
	if err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
		"--embedding",
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	var runnerErr error

	// start the llama.cpp server with a retry in case the port is already in use
	for _, runner := range runners {
		if runner.Accelerated && numGPU == 0 {
			log.Printf("skipping accelerated runner because num_gpu=0")
			continue
		}

		if _, err := os.Stat(runner.Path); err != nil {
			log.Printf("llama runner not found: %v", err)
			continue
		}

		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
		cmd.Stdout = os.Stderr
		statusWriter := NewStatusWriter()
		cmd.Stderr = statusWriter

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error)}}

		log.Print("starting llama runner")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama runner: %v", err)
			continue
		}

		// monitor the llama runner process and signal when it exits
		go func() {
			err := llm.Cmd.Wait()
			llm.exitErr = err
			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
			llm.exitOnce.Do(func() {
				close(llm.exitCh)
			})
		}()

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama runner: %v", err)
			llm.Close()

			// default the runnerErr to the error returned by the most recent llama runner process
			runnerErr = err

			// capture the error directly from the runner process, if any
			select {
			case runnerErr = <-statusWriter.ErrCh:
			default:
				// the runner process probably timed out
			}

			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	if runnerErr != nil {
		// this is the error returned from the llama runner process that failed most recently
		return nil, runnerErr
	}

	return nil, fmt.Errorf("failed to start a llama runner")
}
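
// waitForServer polls the runner until it responds to a ping, returning an
// error if the subprocess exits or the startup timeout elapses.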
func waitForServer(llm *llama) error {
	start := time.Now()
	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama runner to start responding")
	for {
		select {
		case <-llm.exitCh:
			// failed to start subprocess
			return fmt.Errorf("llama runner process has terminated")
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				// timeout
				return fmt.Errorf("timed out waiting for llama runner to start")
			}

			if err := llm.Ping(context.Background()); err == nil {
				// success
				log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		}
	}
}
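
// Close terminates the runner subprocess and waits for it to exit.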
func (llm *llama) Close() {
	// signal the sub-process to terminate
	llm.Cancel()

	// wait for the command to exit to prevent race conditions with the next run
	<-llm.exitCh
	err := llm.exitErr

	if err != nil {
		log.Printf("llama runner stopped with error: %v", err)
	} else {
		log.Print("llama runner stopped successfully")
	}
}
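
// SetOptions replaces the generation options used for subsequent requests.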
func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
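
// GenerationSettings mirrors the generation settings object reported by the llama.cpp server.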
type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}
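
// Prediction is a single decoded event from the llama.cpp /completion stream.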
type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}
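
// PredictRequest is the body of a llama.cpp /completion request.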
type PredictRequest struct {
	Prompt           string   `json:"prompt"`
	Stream           bool     `json:"stream"`
	NPredict         int      `json:"n_predict"`
	NKeep            int      `json:"n_keep"`
	Temperature      float32  `json:"temperature"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	TfsZ             float32  `json:"tfs_z"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	PenalizeNl       bool     `json:"penalize_nl"`
	Seed             int      `json:"seed"`
	Stop             []string `json:"stop,omitempty"`
}
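
// maxBufferSize bounds the bufio.Scanner buffer used to read the streamed predict response.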
const maxBufferSize = 512 * format.KiloByte
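
// Predict streams a completion for prompt (appended to the decoded previous
// context) from the runner, calling fn for each generated chunk and a final
// time with timing information and the updated context.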
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           nextContext.String(),
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Seed:             llm.Seed,
		Stop:             llm.Stop,
	}

	// Handling JSON marshaling with special characters unescaped.
	buffer := &bytes.Buffer{}
	enc := json.NewEncoder(buffer)
	enc.SetEscapeHTML(false)

	if err := enc.Encode(predReq); err != nil {
		return fmt.Errorf("failed to marshal data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	// increase the buffer size to avoid running out of space
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// Read data from the server-side event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]
				var p Prediction
				if err := json.Unmarshal([]byte(evt), &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.PromptN,
						PromptEvalDuration: parseDurationMs(p.PromptMS),
						EvalCount:          p.PredictedN,
						EvalDuration:       parseDurationMs(p.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}
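
// Encode tokenizes prompt via the runner's /tokenize endpoint.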
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}
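
// Decode converts tokens back to text via the runner's /detokenize endpoint.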
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// decoded content contains a leading whitespace
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
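
// Embedding computes an embedding for input via the runner's /embedding endpoint.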
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}

// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}
	return nil
}