routes.go

package server

import (
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math"
	"net"
	"net/http"
	"path"
	"runtime"
	"strings"
	"text/template"

	"github.com/gin-gonic/gin"
	"github.com/lithammer/fuzzysearch/fuzzy"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llama"
)
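
// templates holds every prompt template shipped with the server;
// template.Must panics at startup if any template fails to parse.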
var templates = template.Must(template.ParseGlob("templates/*.prompt"))
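
// generate handles POST /api/generate: it binds the request, loads the
// requested model, renders the prompt through the closest-matching
// template, and streams predicted tokens back as newline-delimited JSON.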
func generate(c *gin.Context) {
	// TODO: these should be request parameters
	gpulayers := 1
	tokens := 512
	threads := runtime.NumCPU()

	var req api.GenerateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}

	model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
	if err != nil {
		// Report the failure to the client instead of returning silently
		// with no response.
		c.JSON(http.StatusInternalServerError, gin.H{"message": fmt.Sprintf("failed to load model: %s", err)})
		return
	}
	defer model.Free()

	// Choose the template whose name most closely matches the base of
	// the prompt path. The loop variable is renamed to avoid shadowing
	// the text/template package.
	templateNames := make([]string, 0, len(templates.Templates()))
	for _, tmpl := range templates.Templates() {
		templateNames = append(templateNames, tmpl.Name())
	}

	match, _ := matchRankOne(path.Base(req.Prompt), templateNames)
	if tmpl := templates.Lookup(match); tmpl != nil {
		var sb strings.Builder
		if err := tmpl.Execute(&sb, req); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"message": fmt.Sprintf("prompt template failed: %s", err)})
			return
		}
		req.Prompt = sb.String()
	}
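
	// Tokens flow from the model to the HTTP response through this
	// channel: the goroutine below produces them, and the c.Stream
	// callback further down consumes them one at a time.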
	ch := make(chan string)
	go func() {
		defer close(ch)
		_, err := model.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
			ch <- token
			return true
		}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
		if err != nil {
			// A panic here would take down the whole server; log and let
			// the closed channel end the stream instead.
			log.Printf("prediction failed: %v", err)
		}
	}()
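
	// gin invokes this callback repeatedly, writing one JSON line per
	// token; returning false ends the response.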
	c.Stream(func(w io.Writer) bool {
		token, ok := <-ch
		if !ok {
			return false
		}

		resp := api.TokenResponse{
			Choices: []api.TokenResponseChoice{
				{
					Text: token,
				},
			},
		}

		bts, err := json.Marshal(resp)
		if err != nil {
			return false
		}

		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			return false
		}

		return true
	})
}
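
// Serve registers the API routes on a gin router and serves HTTP on the
// provided listener. A minimal caller might look like this (the address
// and port are illustrative placeholders, not defined by this file):
//
//	ln, err := net.Listen("tcp", "127.0.0.1:8080")
//	if err != nil {
//		log.Fatal(err)
//	}
//	log.Fatal(server.Serve(ln))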
func Serve(ln net.Listener) error {
	r := gin.Default()

	// gin requires route paths to begin with "/"; registering "api/pull"
	// would panic at startup.
	r.POST("/api/pull", func(c *gin.Context) {
		var req api.PullRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
			return
		}

		progressCh := make(chan string)
		go func() {
			defer close(progressCh)
			if err := pull(req.Model, progressCh); err != nil {
				// Writing to c from this goroutine would race with the
				// c.Stream call below; log the error and let the closed
				// channel end the stream.
				log.Printf("pull failed: %v", err)
				return
			}
		}()

		c.Stream(func(w io.Writer) bool {
			progress, ok := <-progressCh
			if !ok {
				return false
			}
			c.SSEvent("progress", progress)
			return true
		})
	})

	r.POST("/api/generate", generate)

	log.Printf("Listening on %s", ln.Addr())
	s := &http.Server{
		Handler: r,
	}
	return s.Serve(ln)
}
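
// matchRankOne returns the target with the smallest Levenshtein distance
// from source; a lower rank means a closer match.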
func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
	// Levenshtein distance is smaller for closer strings, so start at the
	// maximum possible distance and keep the minimum; comparing with <
	// (not >) picks the best match rather than the worst.
	bestRank = math.MaxInt
	for _, target := range targets {
		if rank := fuzzy.LevenshteinDistance(source, target); rank < bestRank {
			bestRank = rank
			bestMatch = target
		}
	}
	return
}
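
// Example client call (illustrative only: the JSON field names assume
// api.GenerateRequest tags its Model and Prompt fields as "model" and
// "prompt", and the model name and port are placeholders):
//
//	curl -X POST http://localhost:8080/api/generate \
//	    -d '{"model": "llama-7b", "prompt": "why is the sky blue?"}'
//
// Each line of the response body is one JSON-encoded api.TokenResponse.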