llama.go

package llm

import (
	"bytes"
	"context"
	_ "embed"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)
const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
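// jsonGrammar above is a GBNF grammar, the constrained-sampling format
// understood by llama.cpp, which restricts the runner's output to valid
// JSON. A sketch of how a caller might attach it to a runner request when
// PredictOpts.Format is "json" (the request map is illustrative, not the
// exact wire format):
//
//	if predict.Format == "json" {
//		request["grammar"] = jsonGrammar
//	}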
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}
func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}
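// The layer count is a reliable proxy for parameter count across llama
// checkpoints. For example, a model whose hyperparameters report 32 layers
// is labeled "7B":
//
//	m := llamaModel{hyperparameters: llamaHyperparameters{NumLayer: 32}}
//	m.ModelType() // "7B"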
func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
type Running struct {
	Port          int
	Cmd           *exec.Cmd
	Cancel        context.CancelFunc
	exitOnce      sync.Once
	exitCh        chan error // channel to receive the exit status of the subprocess
	*StatusWriter            // captures error messages from the llama runner process
}
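// A sketch of how exitOnce and exitCh might cooperate so the subprocess
// exit status is published exactly once, however many goroutines end up
// waiting on it (hypothetical usage, not part of this file):
//
//	go llm.exitOnce.Do(func() {
//		llm.exitCh <- llm.Cmd.Wait()
//	})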
type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}
var (
	errNvidiaSMI     = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
	errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
	payloadMissing   = errors.New("expected dynamic library payloads not included in this build of ollama")
)
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh      chan error
	LastErrMsg string
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}
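// A minimal sketch of wiring a StatusWriter to the runner subprocess so
// stderr is both mirrored and scanned for error lines (illustrative only):
//
//	statusWriter := NewStatusWriter()
//	cmd.Stderr = statusWriter
//
//	select {
//	case err := <-statusWriter.ErrCh:
//		// the runner emitted an error line
//	case <-time.After(time.Second):
//		// no error reported
//	}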
func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
	}

	return os.Stderr.Write(b)
}
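// For example, a stderr line such as "error: out of memory" sets LastErrMsg
// to "out of memory" and delivers "llama runner: out of memory" on ErrCh.
// Note that ErrCh is buffered with capacity 1, so only the first error is
// guaranteed to be delivered without blocking the write.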
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}
const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
const retryDelay = 1 * time.Second
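// A sketch of how maxRetries and retryDelay might drive a retry loop around
// a request to the runner (doRequest is a hypothetical helper; the loop
// shape is the point):
//
//	for i := 0; i < maxRetries; i++ {
//		if err := doRequest(ctx); err == nil {
//			break
//		}
//		time.Sleep(retryDelay)
//	}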
type PredictOpts struct {
	Prompt  string
	Format  string
	Images  []api.ImageData
	Options api.Options
}
type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}
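// A sketch of converting a streamed prediction into a PredictResult: the
// runner reports timings in milliseconds, so they are scaled up to
// time.Duration (the mapping below is illustrative):
//
//	result := PredictResult{
//		Content:            p.Content,
//		Done:               p.Stop,
//		PromptEvalCount:    p.Timings.PromptN,
//		PromptEvalDuration: time.Duration(p.Timings.PromptMS * float64(time.Millisecond)),
//		EvalCount:          p.Timings.PredictedN,
//		EvalDuration:       time.Duration(p.Timings.PredictedMS * float64(time.Millisecond)),
//	}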
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}
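// A sketch of how these request/response pairs map onto the runner's HTTP
// API (the /tokenize path follows llama.cpp's server convention and is an
// assumption here; json and http would need importing):
//
//	body, _ := json.Marshal(TokenizeRequest{Content: "why is the sky blue?"})
//	resp, err := http.Post(
//		fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port),
//		"application/json", bytes.NewBuffer(body),
//	)
//	// decode resp.Body into a TokenizeResponse to obtain the token ids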