routes.go

package server

import (
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"runtime"

	"github.com/gin-gonic/gin"

	"github.com/jmorganca/ollama/api"
	llama "github.com/jmorganca/ollama/llama"
)

// Serve starts the HTTP API on the provided listener.
func Serve(ln net.Listener) error {
	r := gin.Default()

	// TODO: these should be request parameters
	gpulayers := 0
	tokens := 512
	threads := runtime.NumCPU()

	r.POST("/api/generate", func(c *gin.Context) {
		// TODO: set prompt from template
		fmt.Println("Generating text...")

		var req api.GenerateRequest
		if err := c.ShouldBindJSON(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
			return
		}

		fmt.Println(req)

		// Load the model referenced by the request.
		l, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
		if err != nil {
			fmt.Println("Loading the model failed:", err.Error())
			c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
			return
		}

		// Generate tokens in a separate goroutine, forwarding each one
		// over the channel as it is produced.
		ch := make(chan string)
		go func() {
			defer close(ch)
			_, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
				ch <- token
				return true
			}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
			if err != nil {
				panic(err)
			}
		}()

		// Stream tokens to the client as server-sent events until the
		// channel is closed.
		c.Stream(func(w io.Writer) bool {
			tok, ok := <-ch
			if !ok {
				return false
			}
			c.SSEvent("token", tok)
			return true
		})
	})

	log.Printf("Listening on %s", ln.Addr())

	s := &http.Server{
		Handler: r,
	}

	return s.Serve(ln)
}
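
// A minimal sketch of how a client might call this endpoint once the server is
// running. The host and port are assumptions (they depend on the listener the
// caller passes to Serve), and the JSON field names for the model path and
// prompt are assumed from api.GenerateRequest:
//
//	curl -N http://localhost:8080/api/generate \
//	  -d '{"model": "/path/to/model.bin", "prompt": "Why is the sky blue?"}'
//
// Each generated token arrives as a server-sent event named "token".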