llama.go

package llm

import (
	"bytes"
	"context"
	_ "embed"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"os/exec"
	"path/filepath"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)
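
// jsonGrammar is a GBNF grammar for well-formed JSON, presumably forwarded
// to the llama runner to constrain sampling when PredictOpts.Format is
// "json".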
const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
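
// llamaModel identifies a model in the llama family; its hyperparameters are
// presumably filled in elsewhere when the model file is parsed.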
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}
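
// llamaModelType maps a layer count to the conventional parameter-size label
// for llama models (e.g. 32 layers -> 7B).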
func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}
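
// llamaHyperparameters mirrors the hyperparameter block of a llama model file.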
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32
	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
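
// Running tracks a llama runner subprocess along with the handles needed to
// talk to it and to shut it down.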
type Running struct {
	Port          int
	Cmd           *exec.Cmd
	Cancel        context.CancelFunc
	exitOnce      sync.Once
	exitCh        chan error // channel to receive the exit status of the subprocess
	*StatusWriter            // captures error messages from the llama runner process
}
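
// ImageData carries an image payload for multimodal requests; the ID is
// presumably how the prompt text refers back to a particular image.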
type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

var (
	errNvidiaSMI     = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
	errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
	payloadMissing   = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
)

// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh      chan error
	LastErrMsg string
}
func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}
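
// Write scans each chunk of runner output for "error:" or "CUDA error"
// markers, records and forwards the most recent match, and mirrors the chunk
// to stderr.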
func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
	}

	return os.Stderr.Write(b)
}
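
// prediction is a single streamed response object from the runner's
// completion endpoint; the field names match the llama.cpp server's JSON
// output.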
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}
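
// These limits are presumably applied where streamed predictions are read
// (maxBufferSize bounding the scanner buffer) and where requests to the
// runner are retried (maxRetries, retryDelay).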
const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
const retryDelay = 1 * time.Second
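
// PredictOpts are the caller-supplied options for a single prediction request.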
type PredictOpts struct {
	Prompt string
	Format string
	Images []api.ImageData
}
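
// PredictResult is one unit of streamed prediction output; the eval counts
// and durations are presumably taken from the runner's timings on the final
// (Done) response.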
type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}
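
// The request/response pairs below mirror the runner's HTTP API for
// tokenization, detokenization, and embedding.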
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
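
// extractDynamicLibs copies the embedded dynamic libraries matching glob out
// of libEmbed (presumably an embed.FS declared per platform elsewhere in the
// package) into workDir, skipping any that already exist on disk, and
// returns the destination paths.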
func extractDynamicLibs(workDir, glob string) ([]string, error) {
	files, err := fs.Glob(libEmbed, glob)
	if err != nil || len(files) == 0 {
		return nil, payloadMissing
	}

	libs := make([]string, len(files))
	for i, file := range files {
		srcFile, err := libEmbed.Open(file)
		if err != nil {
			return nil, fmt.Errorf("read payload %s: %v", file, err)
		}
		defer srcFile.Close()

		if err := os.MkdirAll(workDir, 0o755); err != nil {
			return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
		}

		destFile := filepath.Join(workDir, filepath.Base(file))
		libs[i] = destFile

		_, err = os.Stat(destFile)
		switch {
		case errors.Is(err, os.ErrNotExist):
			destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
			if err != nil {
				return nil, fmt.Errorf("write payload %s: %v", file, err)
			}
			defer destFile.Close()

			if _, err := io.Copy(destFile, srcFile); err != nil {
				return nil, fmt.Errorf("copy payload %s: %v", file, err)
			}
		case err != nil:
			return nil, fmt.Errorf("stat payload %s: %v", file, err)
		}
	}
	return libs, nil
}