llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
)

//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS

type ModelRunner struct {
	Path        string // path to the model runner executable
	Accelerated bool
}

func chooseRunners(workDir, runnerType string) []ModelRunner {
	buildPath := path.Join("llama.cpp", runnerType, "build")
	var runners []ModelRunner

	// set the runners based on the OS
	// IMPORTANT: the order of the runners in the array is the priority order
	switch runtime.GOOS {
	case "darwin":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "linux":
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	case "windows":
		// TODO: select windows GPU runner here when available
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
		}
	default:
		log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
		runners = []ModelRunner{
			{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
		}
	}

	runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
	for _, r := range runners {
		// find all the files in the runner's bin directory
		files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
		if err != nil {
			// this is expected, ollama may be compiled without all runners packed in
			log.Printf("%s runner not found: %v", r, err)
			continue
		}

		for _, f := range files {
			runnerAvailable = true

			srcFile, err := llamaCppEmbed.Open(f)
			if err != nil {
				log.Fatalf("read llama runner %s: %v", f, err)
			}
			defer srcFile.Close()

			// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
			destPath := filepath.Join(workDir, filepath.Dir(f))
			if err := os.MkdirAll(destPath, 0o755); err != nil {
				log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
			}

			// create the path to the destination file, filepath.Base() converts the file path to the OS's format
			destFile := filepath.Join(destPath, filepath.Base(f))

			_, err = os.Stat(destFile)
			switch {
			case errors.Is(err, os.ErrNotExist):
				destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
				if err != nil {
					log.Fatalf("write llama runner %s: %v", f, err)
				}
				defer destFile.Close()

				if _, err := io.Copy(destFile, srcFile); err != nil {
					log.Fatalf("copy llama runner %s: %v", f, err)
				}
			case err != nil:
				log.Fatalf("stat llama runner %s: %v", f, err)
			}
		}
	}
	if !runnerAvailable {
		log.Fatalf("%s runner not found", runnerType)
	}

	// return the runners to try in priority order
	localRunnersByPriority := []ModelRunner{}
	for _, r := range runners {
		// clean the ModelRunner paths so that they match the OS we are running on
		localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
			Path:        filepath.Clean(path.Join(workDir, r.Path)),
			Accelerated: r.Accelerated,
		})
	}

	return localRunnersByPriority
}
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}

func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32
	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}

type Running struct {
	Port     int
	Cmd      *exec.Cmd
	Cancel   context.CancelFunc
	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess
	exitErr  error      // error returned by the subprocess
}

type llama struct {
	api.Options
	Running
}

var errNoGPU = errors.New("nvidia-smi command failed")

// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
	cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	err := cmd.Run()
	if err != nil {
		return 0, errNoGPU
	}

	var free int64
	scanner := bufio.NewScanner(&stdout)
	for scanner.Scan() {
		line := scanner.Text()
		vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
		if err != nil {
			return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
		}

		free += vram
	}

	if free*1024*1024 < 2*1000*1000*1000 {
		log.Printf("less than 2 GB VRAM available, falling back to CPU only")
		free = 0
	}

	return free, nil
}
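
// Illustrative note (not part of the original file): with
// --query-gpu=memory.free --format=csv,noheader,nounits, nvidia-smi prints one
// line per GPU containing only the free memory in MiB, e.g.
//
//	8017
//	10240
//
// CheckVRAM above sums those lines, so a multi-GPU machine reports its total
// free VRAM.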
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
	if opts.NumGPU != -1 {
		return opts.NumGPU
	}
	if runtime.GOOS == "linux" {
		vramMib, err := CheckVRAM()
		if err != nil {
			if err.Error() != "nvidia-smi command failed" {
				log.Print(err.Error())
			}
			// nvidia driver not installed or no nvidia GPU found
			return 0
		}

		freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes

		// Calculate bytes per layer
		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
		bytesPerLayer := fileSizeBytes / numLayer

		// max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
		layers := int(freeVramBytes/bytesPerLayer) * 92 / 100
		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)

		return layers
	}
	// default to enable metal on macOS
	return 1
}
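
// numGPULayersExample is an illustrative sketch, not part of the original
// file: it restates the VRAM heuristic in NumGPU with concrete (hypothetical)
// numbers. For a 32-layer model file of ~3.8 GiB and 8 GiB of free VRAM,
// bytesPerLayer is roughly 119 MiB, so about (8192/119) * 92% ≈ 62 layers
// would be offloaded to the GPU.
func numGPULayersExample() int {
	const fileSizeBytes = int64(3800) * 1024 * 1024 // hypothetical ~3.8 GiB model file
	const numLayer = int64(32)                      // hypothetical layer count
	const freeVramBytes = int64(8192) * 1024 * 1024 // hypothetical 8 GiB of free VRAM

	bytesPerLayer := fileSizeBytes / numLayer
	return int(freeVramBytes/bytesPerLayer) * 92 / 100 // same 8% headroom as NumGPU
}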
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh chan error
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

func (w *StatusWriter) Write(b []byte) (int, error) {
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		w.ErrCh <- fmt.Errorf("llama runner: %s", bytes.TrimSpace(after))
	}
	return os.Stderr.Write(b)
}
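
// statusWriterExample is an illustrative sketch, not part of the original
// file: it shows how StatusWriter echoes runner output while surfacing any
// chunk containing "error:" on ErrCh. The sample message is hypothetical.
func statusWriterExample() {
	w := NewStatusWriter()
	_, _ = w.Write([]byte("error: failed to load model\n")) // also written through to os.Stderr
	select {
	case err := <-w.ErrCh:
		log.Printf("captured runner error: %v", err)
	default:
	}
}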
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
	fileInfo, err := os.Stat(model)
	if err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	numGPU := NumGPU(numLayers, fileInfo.Size(), opts)

	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--embedding",
	}

	if numGPU > 0 {
		params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", numGPU))
	}

	if opts.NumGQA > 0 {
		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	var runnerErr error

	// start the llama.cpp server with a retry in case the port is already in use
	for _, runner := range runners {
		if runner.Accelerated && numGPU == 0 {
			log.Printf("skipping accelerated runner because num_gpu=0")
			continue
		}

		if _, err := os.Stat(runner.Path); err != nil {
			log.Printf("llama runner not found: %v", err)
			continue
		}

		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)
		cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
		cmd.Stdout = os.Stderr
		statusWriter := NewStatusWriter()
		cmd.Stderr = statusWriter

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error)}}

		log.Print("starting llama runner")
		if err := llm.Cmd.Start(); err != nil {
			log.Printf("error starting the external llama runner: %v", err)
			continue
		}

		// monitor the llama runner process and signal when it exits
		go func() {
			err := llm.Cmd.Wait()
			llm.exitErr = err
			// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
			llm.exitOnce.Do(func() {
				close(llm.exitCh)
			})
		}()

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama runner: %v", err)
			llm.Close()

			// default the runnerErr to the error returned by the most recent llama runner process
			runnerErr = err

			// capture the error directly from the runner process, if any
			select {
			case runnerErr = <-statusWriter.ErrCh:
			default:
				// the runner process probably timed out
			}

			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	if runnerErr != nil {
		// this is the error returned from the llama runner process that failed most recently
		return nil, runnerErr
	}

	return nil, fmt.Errorf("failed to start a llama runner")
}
func waitForServer(llm *llama) error {
	start := time.Now()
	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama runner to start responding")
	for {
		select {
		case <-llm.exitCh:
			// failed to start subprocess
			return fmt.Errorf("llama runner process has terminated")
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				// timeout
				return fmt.Errorf("timed out waiting for llama runner to start")
			}

			if err := llm.Ping(context.Background()); err == nil {
				// success
				log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		}
	}
}
func (llm *llama) Close() {
	// signal the sub-process to terminate
	llm.Cancel()

	// wait for the command to exit to prevent race conditions with the next run
	<-llm.exitCh
	err := llm.exitErr

	if err != nil {
		log.Printf("llama runner stopped with error: %v", err)
	} else {
		log.Print("llama runner stopped successfully")
	}
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}

type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}

type PredictRequest struct {
	Prompt           string   `json:"prompt"`
	Stream           bool     `json:"stream"`
	NPredict         int      `json:"n_predict"`
	NKeep            int      `json:"n_keep"`
	Temperature      float32  `json:"temperature"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	TfsZ             float32  `json:"tfs_z"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	PenalizeNl       bool     `json:"penalize_nl"`
	Seed             int      `json:"seed"`
	Stop             []string `json:"stop,omitempty"`
}

const maxBufferSize = 512 * 1000 // 512KB
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
	prevConvo, err := llm.Decode(ctx, prevContext)
	if err != nil {
		return err
	}

	var nextContext strings.Builder
	nextContext.WriteString(prevConvo)
	nextContext.WriteString(prompt)

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           nextContext.String(),
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Seed:             llm.Seed,
		Stop:             llm.Stop,
	}

	data, err := json.Marshal(predReq)
	if err != nil {
		return fmt.Errorf("error marshaling data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	// increase the buffer size to avoid running out of space
	buf := make([]byte, 0, maxBufferSize)
	scanner.Buffer(buf, maxBufferSize)
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// This handles the request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// Read data from the server-side event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]
				var p Prediction
				if err := json.Unmarshal([]byte(evt), &p); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}

				if p.Content != "" {
					fn(api.GenerateResponse{Response: p.Content})
					nextContext.WriteString(p.Content)
				}

				if p.Stop {
					embd, err := llm.Encode(ctx, nextContext.String())
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    p.PromptN,
						PromptEvalDuration: parseDurationMs(p.PromptMS),
						EvalCount:          p.PredictedN,
						EvalDuration:       parseDurationMs(p.PredictedMS),
					})

					return nil
				}
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}
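
// Illustrative note (not part of the original file): each line consumed by the
// scanner in Predict is a server-sent event from the llama.cpp completion
// endpoint. Judging from the Prediction struct above, an intermediate event
// looks roughly like
//
//	data: {"content":" blue","stop":false}
//
// and the final event carries "stop":true together with the timing fields
// (prompt_n, prompt_ms, predicted_n, predicted_ms), which Predict converts
// into the closing api.GenerateResponse.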
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}
type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// decoded content contains a leading whitespace; trim it
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}
type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}
	return nil
}
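
// exampleUsage is an illustrative sketch, not part of the original file: one
// possible call sequence for this package, using only identifiers defined
// above. The work directory, runner type ("ggml"), model path, layer count,
// and prompt are hypothetical; api.DefaultOptions is assumed to come from the
// api package.
func exampleUsage(workDir string) error {
	runners := chooseRunners(workDir, "ggml")
	llm, err := newLlama("/path/to/model.bin", nil, runners, 32, api.DefaultOptions())
	if err != nil {
		return err
	}
	defer llm.Close()

	// stream generated tokens to stdout as they arrive
	return llm.Predict(context.Background(), nil, "Why is the sky blue?", func(r api.GenerateResponse) {
		fmt.Print(r.Response)
	})
}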