ggml_llama.go

package llm

import (
	"bufio"
	"bytes"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log"
	"math/rand"
	"net/http"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
)

const ModelFamilyLlama ModelFamily = "llama"

//go:embed llama.cpp/ggml/build/*/bin/*
var llamaCppEmbed embed.FS

var (
	ggmlGPU = path.Join("llama.cpp", "ggml", "build", "gpu", "bin")
	ggmlCPU = path.Join("llama.cpp", "ggml", "build", "cpu", "bin")
)

var (
	ggmlInit       sync.Once
	ggmlRunnerPath string
)

// osPath adjusts an embedded build path for the current platform: Windows
// builds place their binaries in a Release subdirectory.
func osPath(llamaPath string) string {
	if runtime.GOOS == "windows" {
		return path.Join(llamaPath, "Release")
	}

	return llamaPath
}

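// initGGML extracts the embedded llama.cpp server binary (plus the
// ggml-metal.metal shader on macOS) into a temporary directory exactly once,
// preferring the GPU build and falling back to the CPU build.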
func initGGML() {
	ggmlInit.Do(func() {
		tmpDir, err := os.MkdirTemp("", "llama-*")
		if err != nil {
			log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
		}

		llamaPath := osPath(ggmlGPU)
		if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
			llamaPath = osPath(ggmlCPU)
			if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
				log.Fatalf("llama.cpp executable not found")
			}
		}

		files := []string{"server"}
		switch runtime.GOOS {
		case "windows":
			files = []string{"server.exe"}
		case "darwin":
			files = append(files, "ggml-metal.metal")
		}

		for _, f := range files {
			srcPath := path.Join(llamaPath, f)
			destPath := filepath.Join(tmpDir, f)

			srcFile, err := llamaCppEmbed.Open(srcPath)
			if err != nil {
				log.Fatalf("read llama.cpp %s: %v", f, err)
			}
			defer srcFile.Close()

			destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
			if err != nil {
				log.Fatalf("write llama.cpp %s: %v", f, err)
			}
			defer destFile.Close()

			if _, err := io.Copy(destFile, srcFile); err != nil {
				log.Fatalf("copy llama.cpp %s: %v", f, err)
			}
		}

		ggmlRunnerPath = filepath.Join(tmpDir, "server")
		if runtime.GOOS == "windows" {
			ggmlRunnerPath = filepath.Join(tmpDir, "server.exe")
		}
	})
}

type ModelRunner struct {
	Path string // path to the model runner executable
}

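// ggmlRunner returns a ModelRunner pointing at the extracted llama.cpp server
// binary, extracting the embedded files first if that has not happened yet.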
func ggmlRunner() ModelRunner {
	initGGML()
	return ModelRunner{Path: ggmlRunnerPath}
}

type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() ModelFamily {
	return ModelFamilyLlama
}

func (llm *llamaModel) ModelType() ModelType {
	switch llm.hyperparameters.NumLayer {
	case 26:
		return ModelType3B
	case 32:
		return ModelType7B
	case 40:
		return ModelType13B
	case 48:
		return ModelType34B
	case 60:
		return ModelType30B
	case 80:
		return ModelType65B
	}

	// TODO: find a better default
	return ModelType7B
}

func (llm *llamaModel) FileType() FileType {
	return llm.hyperparameters.FileType
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType llamaFileType
}

type llamaFileType uint32

const (
	llamaFileTypeF32 llamaFileType = iota
	llamaFileTypeF16
	llamaFileTypeQ4_0
	llamaFileTypeQ4_1
	llamaFileTypeQ4_1_F16
	// values 5 and 6 are skipped: llama.cpp reserved them for the removed Q4_2 and Q4_3 formats
	llamaFileTypeQ8_0 llamaFileType = iota + 2
	llamaFileTypeQ5_0
	llamaFileTypeQ5_1
	llamaFileTypeQ2_K
	llamaFileTypeQ3_K_S
	llamaFileTypeQ3_K_M
	llamaFileTypeQ3_K_L
	llamaFileTypeQ4_K_S
	llamaFileTypeQ4_K_M
	llamaFileTypeQ5_K_S
	llamaFileTypeQ5_K_M
	llamaFileTypeQ6_K
)

func (ft llamaFileType) String() string {
	switch ft {
	case llamaFileTypeF32:
		return "F32"
	case llamaFileTypeF16:
		return "F16"
	case llamaFileTypeQ4_0:
		return "Q4_0"
	case llamaFileTypeQ4_1:
		return "Q4_1"
	case llamaFileTypeQ4_1_F16:
		return "Q4_1_F16"
	case llamaFileTypeQ8_0:
		return "Q8_0"
	case llamaFileTypeQ5_0:
		return "Q5_0"
	case llamaFileTypeQ5_1:
		return "Q5_1"
	case llamaFileTypeQ2_K:
		return "Q2_K"
	case llamaFileTypeQ3_K_S:
		return "Q3_K_S"
	case llamaFileTypeQ3_K_M:
		return "Q3_K_M"
	case llamaFileTypeQ3_K_L:
		return "Q3_K_L"
	case llamaFileTypeQ4_K_S:
		return "Q4_K_S"
	case llamaFileTypeQ4_K_M:
		return "Q4_K_M"
	case llamaFileTypeQ5_K_S:
		return "Q5_K_S"
	case llamaFileTypeQ5_K_M:
		return "Q5_K_M"
	case llamaFileTypeQ6_K:
		return "Q6_K"
	default:
		return "Unknown"
	}
}

type Running struct {
	Port   int
	Cmd    *exec.Cmd
	Cancel context.CancelFunc
}

type llama struct {
	api.Options
	Running
}

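// newLlama starts a llama.cpp server subprocess for the given model file and
// returns a handle to it, retrying on a fresh random port if startup fails.
//
// Illustrative usage sketch; the model path, prompt, and the use of
// api.DefaultOptions are assumptions, not part of this file:
//
//	opts := api.DefaultOptions()
//	llm, err := newLlama("/path/to/model.bin", nil, ggmlRunner(), opts)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer llm.Close()
//	err = llm.Predict(context.Background(), nil, "Why is the sky blue?", func(r api.GenerateResponse) {
//		fmt.Print(r.Response)
//	})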
func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options) (*llama, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	if _, err := os.Stat(runner.Path); err != nil {
		return nil, err
	}

	if len(adapters) > 1 {
		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
	}

	params := []string{
		"--model", model,
		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
		"--gqa", fmt.Sprintf("%d", opts.NumGQA),
		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
		"--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU),
		"--embedding",
	}

	if len(adapters) > 0 {
		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
		params = append(params, "--lora", adapters[0])
	}

	if opts.NumThread > 0 {
		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
	}

	if !opts.F16KV {
		params = append(params, "--memory-f32")
	}
	if opts.UseMLock {
		params = append(params, "--mlock")
	}
	if !opts.UseMMap {
		params = append(params, "--no-mmap")
	}
	if opts.UseNUMA {
		params = append(params, "--numa")
	}

	// start the llama.cpp server with a retry in case the port is already in use
	for try := 0; try < 3; try++ {
		port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		ctx, cancel := context.WithCancel(context.Background())
		cmd := exec.CommandContext(
			ctx,
			runner.Path,
			append(params, "--port", strconv.Itoa(port))...,
		)

		var stderr bytes.Buffer
		cmd.Stderr = &stderr

		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}

		if err := waitForServer(llm); err != nil {
			log.Printf("error starting llama.cpp server: %v", err)
			llm.Close()
			// try again
			continue
		}

		// server started successfully
		return llm, nil
	}

	return nil, fmt.Errorf("max retries exceeded starting llama.cpp")
}

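// waitForServer starts the command for llm and polls the server every 100ms
// until it answers a ping, the process exits, or 30 seconds elapse.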
func waitForServer(llm *llama) error {
	log.Print("starting llama.cpp server")
	var stderr bytes.Buffer
	llm.Cmd.Stderr = &stderr
	err := llm.Cmd.Start()
	if err != nil {
		return fmt.Errorf("error starting the external llama.cpp server: %w", err)
	}

	exitChan := make(chan error, 1)

	// the server is a long-running process; watch for it exiting so startup failures are caught
	go func() {
		err := llm.Cmd.Wait()
		log.Print(stderr.String())
		exitChan <- err
	}()

	// wait for the server to start responding
	start := time.Now()
	expiresAt := time.Now().Add(30 * time.Second)
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	log.Print("waiting for llama.cpp server to start responding")
	for {
		select {
		case <-ticker.C:
			if time.Now().After(expiresAt) {
				return fmt.Errorf("llama.cpp server did not start responding within 30 seconds, retrying")
			}
			if err := llm.Ping(context.Background()); err == nil {
				log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
				return nil
			}
		case err := <-exitChan:
			return fmt.Errorf("llama.cpp server exited unexpectedly: %w", err)
		}
	}
}

// Close stops the running llama.cpp server process.
func (llm *llama) Close() {
	llm.Running.Cmd.Cancel()
}

func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}

type Prediction struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

type GenerationSettings struct {
	FrequencyPenalty float64       `json:"frequency_penalty"`
	IgnoreEOS        bool          `json:"ignore_eos"`
	LogitBias        []interface{} `json:"logit_bias"`
	Mirostat         int           `json:"mirostat"`
	MirostatEta      float64       `json:"mirostat_eta"`
	MirostatTau      float64       `json:"mirostat_tau"`
	Model            string        `json:"model"`
	NCtx             int           `json:"n_ctx"`
	NKeep            int           `json:"n_keep"`
	NPredict         int           `json:"n_predict"`
	NProbs           int           `json:"n_probs"`
	PenalizeNl       bool          `json:"penalize_nl"`
	PresencePenalty  float64       `json:"presence_penalty"`
	RepeatLastN      int           `json:"repeat_last_n"`
	RepeatPenalty    float64       `json:"repeat_penalty"`
	Seed             uint32        `json:"seed"`
	Stop             []string      `json:"stop"`
	Stream           bool          `json:"stream"`
	Temp             float64       `json:"temp"`
	TfsZ             float64       `json:"tfs_z"`
	TopK             int           `json:"top_k"`
	TopP             float64       `json:"top_p"`
	TypicalP         float64       `json:"typical_p"`
}

type Timings struct {
	PredictedMS         float64 `json:"predicted_ms"`
	PredictedN          int     `json:"predicted_n"`
	PredictedPerSecond  float64 `json:"predicted_per_second"`
	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
	PromptMS            float64 `json:"prompt_ms"`
	PromptN             int     `json:"prompt_n"`
	PromptPerSecond     float64 `json:"prompt_per_second"`
	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
}

type PredictComplete struct {
	Content            string             `json:"content"`
	GenerationSettings GenerationSettings `json:"generation_settings"`
	Model              string             `json:"model"`
	Prompt             string             `json:"prompt"`
	Stop               bool               `json:"stop"`
	StoppedEOS         bool               `json:"stopped_eos"`
	StoppedLimit       bool               `json:"stopped_limit"`
	StoppedWord        bool               `json:"stopped_word"`
	StoppingWord       string             `json:"stopping_word"`
	Timings            Timings            `json:"timings"`
	TokensCached       int                `json:"tokens_cached"`
	TokensEvaluated    int                `json:"tokens_evaluated"`
	TokensPredicted    int                `json:"tokens_predicted"`
	Truncated          bool               `json:"truncated"`
}

type PredictRequest struct {
	Stream           bool            `json:"stream"`
	NPredict         int             `json:"n_predict,omitempty"`
	TopK             int             `json:"top_k,omitempty"`
	TopP             float32         `json:"top_p,omitempty"`
	TfsZ             float32         `json:"tfs_z,omitempty"`
	TypicalP         float32         `json:"typical_p,omitempty"`
	RepeatLastN      int             `json:"repeat_last_n,omitempty"`
	Temperature      float32         `json:"temperature,omitempty"`
	RepeatPenalty    float32         `json:"repeat_penalty,omitempty"`
	PresencePenalty  float32         `json:"presence_penalty,omitempty"`
	FrequencyPenalty float32         `json:"frequency_penalty,omitempty"`
	Mirostat         int             `json:"mirostat,omitempty"`
	MirostatTau      float32         `json:"mirostat_tau,omitempty"`
	MirostatEta      float32         `json:"mirostat_eta,omitempty"`
	PenalizeNl       bool            `json:"penalize_nl,omitempty"`
	NKeep            int             `json:"n_keep,omitempty"`
	Seed             int             `json:"seed,omitempty"`
	Prompt           string          `json:"prompt,omitempty"`
	NProbs           int             `json:"n_probs,omitempty"`
	LogitBias        map[int]float32 `json:"logit_bias,omitempty"`
	IgnoreEos        bool            `json:"ignore_eos,omitempty"`
	Stop             []string        `json:"stop,omitempty"`
}

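// Predict streams a completion for prompt from the llama.cpp server's
// /completion endpoint, calling fn once per generated chunk and a final time
// with timing statistics and the updated context when generation stops.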
func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
	// we need to find the trimmed prompt context before predicting so that we can return it to the client
	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
	if err != nil {
		return fmt.Errorf("marshaling prompt: %v", err)
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
	predReq := PredictRequest{
		Prompt:           trimmedPrompt,
		Stream:           true,
		NPredict:         llm.NumPredict,
		NKeep:            llm.NumKeep,
		Temperature:      llm.Temperature,
		TopK:             llm.TopK,
		TopP:             llm.TopP,
		TfsZ:             llm.TFSZ,
		TypicalP:         llm.TypicalP,
		RepeatLastN:      llm.RepeatLastN,
		RepeatPenalty:    llm.RepeatPenalty,
		PresencePenalty:  llm.PresencePenalty,
		FrequencyPenalty: llm.FrequencyPenalty,
		Mirostat:         llm.Mirostat,
		MirostatTau:      llm.MirostatTau,
		MirostatEta:      llm.MirostatEta,
		PenalizeNl:       llm.PenalizeNewline,
		Stop:             llm.Stop,
	}

	data, err := json.Marshal(predReq)
	if err != nil {
		return fmt.Errorf("error marshaling data: %v", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("error creating POST request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("POST predict: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		bodyBytes, err := io.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed reading llm error response: %w", err)
		}
		log.Printf("llm predict error: %s", bodyBytes)
		return fmt.Errorf("%s", bodyBytes)
	}

	scanner := bufio.NewScanner(resp.Body)
	genCtx := trimmedPrompt // start with the trimmed prompt
	for scanner.Scan() {
		select {
		case <-ctx.Done():
			// this handles request cancellation
			return ctx.Err()
		default:
			line := scanner.Text()
			if line == "" {
				continue
			}

			// read data from the server-sent event stream
			if strings.HasPrefix(line, "data: ") {
				evt := line[6:]
				var complete PredictComplete
				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
				}

				if complete.Timings.PredictedMS > 0 {
					genCtx += complete.Content
					embd, err := llm.Encode(ctx, genCtx)
					if err != nil {
						return fmt.Errorf("encoding context: %v", err)
					}

					fn(api.GenerateResponse{
						Done:               true,
						Context:            embd,
						PromptEvalCount:    int(complete.Timings.PromptN),
						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
						EvalCount:          int(complete.Timings.PredictedN),
						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
					})
					return nil
				}

				var pred Prediction
				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
				}
				genCtx += pred.Content
				fn(api.GenerateResponse{Response: pred.Content})
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading llm response: %v", err)
	}

	return nil
}

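// marshalPrompt prepends the previous context tokens to the encoded prompt,
// truncates the combined tokens so they fit within the context window while
// keeping the first NumKeep tokens, and decodes the result back into text.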
func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
	pEncode, err := llm.Encode(ctx, prompt)
	if err != nil {
		return "", fmt.Errorf("encoding prompt context: %w", err)
	}

	tokens := append(pCtx, pEncode...)
	if llm.NumKeep < 0 {
		llm.NumKeep = len(tokens)
	}

	// clamp NumKeep: NumKeep = min(llm.NumCtx-4, llm.NumKeep)
	if llm.NumCtx-4 < llm.NumKeep {
		llm.NumKeep = llm.NumCtx - 4
	}

	if len(tokens) >= llm.NumCtx {
		// truncate input
		numLeft := (llm.NumCtx - llm.NumKeep) / 2
		truncated := tokens[:llm.NumKeep]
		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
		tokens = truncated

		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
	}

	return llm.Decode(ctx, tokens)
}

type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

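// Encode tokenizes prompt using the llama.cpp server's /tokenize endpoint.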
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	data, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var encoded TokenizeResponse
	if err := json.Unmarshal(body, &encoded); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return encoded.Tokens, nil
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

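// Decode converts tokens back into text using the llama.cpp server's
// /detokenize endpoint.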
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	// decoded content contains a leading whitespace
	decoded.Content, _ = strings.CutPrefix(decoded.Content, " ")

	return decoded.Content, nil
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

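// Embedding returns the embedding vector for input using the llama.cpp
// server's /embedding endpoint.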
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}

// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Running.Port))
	if err != nil {
		return fmt.Errorf("ping resp: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected ping status: %s", resp.Status)
	}

	return nil
}