routes.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package server
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "log"
  7. "net"
  8. "net/http"
  9. "path"
  10. "runtime"
  11. "strings"
  12. "text/template"
  13. "github.com/gin-gonic/gin"
  14. "github.com/lithammer/fuzzysearch/fuzzy"
  15. "github.com/jmorganca/ollama/api"
  16. "github.com/jmorganca/ollama/llama"
  17. )
  18. var templates = template.Must(template.ParseGlob("templates/*.prompt"))
  19. func generate(c *gin.Context) {
  20. // TODO: these should be request parameters
  21. gpulayers := 1
  22. tokens := 512
  23. threads := runtime.NumCPU()
  24. var req api.GenerateRequest
  25. if err := c.ShouldBindJSON(&req); err != nil {
  26. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  27. return
  28. }
  29. l, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
  30. if err != nil {
  31. fmt.Println("Loading the model failed:", err.Error())
  32. return
  33. }
  34. templateNames := make([]string, 0, len(templates.Templates()))
  35. for _, template := range templates.Templates() {
  36. templateNames = append(templateNames, template.Name())
  37. }
  38. match, _ := matchRankOne(path.Base(req.Prompt), templateNames)
  39. if template := templates.Lookup(match); template != nil {
  40. var sb strings.Builder
  41. if err := template.Execute(&sb, req); err != nil {
  42. fmt.Println("Prompt template failed:", err.Error())
  43. return
  44. }
  45. req.Prompt = sb.String()
  46. }
  47. ch := make(chan string)
  48. go func() {
  49. defer close(ch)
  50. _, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
  51. ch <- token
  52. return true
  53. }), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
  54. if err != nil {
  55. panic(err)
  56. }
  57. }()
  58. c.Stream(func(w io.Writer) bool {
  59. token, ok := <-ch
  60. if !ok {
  61. return false
  62. }
  63. resp := api.TokenResponse{
  64. Choices: []api.TokenResponseChoice{
  65. {
  66. Text: token,
  67. },
  68. },
  69. }
  70. bts, err := json.Marshal(resp)
  71. if err != nil {
  72. return false
  73. }
  74. bts = append(bts, '\n')
  75. if _, err := w.Write(bts); err != nil {
  76. return false
  77. }
  78. return true
  79. })
  80. }
  81. func Serve(ln net.Listener) error {
  82. r := gin.Default()
  83. r.POST("api/pull", func(c *gin.Context) {
  84. var req api.PullRequest
  85. if err := c.ShouldBindJSON(&req); err != nil {
  86. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  87. return
  88. }
  89. progressCh := make(chan string)
  90. go func() {
  91. defer close(progressCh)
  92. if err := pull(req.Model, progressCh); err != nil {
  93. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  94. return
  95. }
  96. }()
  97. c.Stream(func(w io.Writer) bool {
  98. progress, ok := <-progressCh
  99. if !ok {
  100. return false
  101. }
  102. c.SSEvent("progress", progress)
  103. return true
  104. })
  105. })
  106. r.POST("/api/generate", generate)
  107. log.Printf("Listening on %s", ln.Addr())
  108. s := &http.Server{
  109. Handler: r,
  110. }
  111. return s.Serve(ln)
  112. }
  113. func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
  114. for _, target := range targets {
  115. if rank := fuzzy.LevenshteinDistance(source, target); bestRank < rank {
  116. bestRank = rank
  117. bestMatch = target
  118. }
  119. }
  120. return
  121. }