llama.go

package llm

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"sync"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/format"
)

const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`

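// The grammar above is llama.cpp's GBNF notation for JSON: when a caller asks
// for JSON-formatted output, sampling can be constrained so the runner only
// emits tokens that keep the response valid JSON. As an illustrative sketch
// (the request wiring lives elsewhere in this package; the field name is an
// assumption based on llama.cpp server conventions), a completion payload
// might attach the grammar like so:
//
//	payload := map[string]any{
//		"prompt":  prompt,
//		"grammar": jsonGrammar, // hypothetical wiring, not defined in this file
//	}
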
type llamaModel struct {
	hyperparameters llamaHyperparameters
}

func (llm *llamaModel) ModelFamily() string {
	return "llama"
}

func llamaModelType(numLayer uint32) string {
	switch numLayer {
	case 26:
		return "3B"
	case 32:
		return "7B"
	case 40:
		return "13B"
	case 48:
		return "34B"
	case 60:
		return "30B"
	case 80:
		return "65B"
	default:
		return "unknown"
	}
}

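// For example, llamaModelType(32) returns "7B": the 7B LLaMA variants have 32
// transformer layers, so the layer count is a cheap proxy for parameter size.
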
func (llm *llamaModel) ModelType() string {
	return llamaModelType(llm.hyperparameters.NumLayer)
}

func (llm *llamaModel) FileType() string {
	return fileType(llm.hyperparameters.FileType)
}

func (llm *llamaModel) NumLayers() int64 {
	return int64(llm.hyperparameters.NumLayer)
}

type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32

	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}

type Running struct {
	Port     int
	Cmd      *exec.Cmd
	Cancel   context.CancelFunc
	exitOnce sync.Once
	exitCh   chan error // channel to receive the exit status of the subprocess

	*StatusWriter // captures error messages from the llama runner process
}

type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

type llama struct {
	api.Options
	ImageData []ImageData
	Running
}

var (
	errNvidiaSMI     = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
	errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
)

// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	ErrCh      chan error
	LastErrMsg string
}

func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
	}

	return os.Stderr.Write(b)
}

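// Usage sketch (assumes a runner command built elsewhere in this package;
// runnerPath and args are placeholders): attaching a StatusWriter as the
// subprocess's stderr surfaces runner and CUDA errors on ErrCh while still
// echoing the raw output to our own stderr via os.Stderr.Write above.
//
//	statusWriter := NewStatusWriter()
//	cmd := exec.CommandContext(ctx, runnerPath, args...)
//	cmd.Stderr = statusWriter
//	if err := cmd.Start(); err != nil {
//		// handle startup failure
//	}
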
type prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}

const maxBufferSize = 512 * format.KiloByte
const maxRetries = 6

type PredictOpts struct {
	Prompt string
	Format string
	Images []api.ImageData
}

type PredictResult struct {
	Content            string
	Done               bool
	PromptEvalCount    int
	PromptEvalDuration time.Duration
	EvalCount          int
	EvalDuration       time.Duration
}

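// Sketch of the glue between the two types above (this helper is not part of
// the excerpt; the name is an assumption): the runner reports timings in
// milliseconds as float64, while PredictResult exposes time.Duration, so a
// final prediction's Timings would be converted roughly like this.
func parseDurationMs(ms float64) time.Duration {
	// time.Duration counts nanoseconds, so scale milliseconds accordingly.
	return time.Duration(ms * float64(time.Millisecond))
}
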
type TokenizeRequest struct {
	Content string `json:"content"`
}

type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeRequest struct {
	Tokens []int `json:"tokens"`
}

type DetokenizeResponse struct {
	Content string `json:"content"`
}

type EmbeddingRequest struct {
	Content string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

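// Wire-format sketch for the request/response pairs above, using the JSON
// field names from the struct tags (the endpoint paths follow llama.cpp
// server conventions, and the token IDs and embedding values shown are
// purely illustrative):
//
//	POST /tokenize    {"content":"hello"}   ->  {"tokens":[12, 34]}
//	POST /detokenize  {"tokens":[12, 34]}   ->  {"content":"hello"}
//	POST /embedding   {"content":"hello"}   ->  {"embedding":[0.12, ...]}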