routes.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package server
  import (
  	"fmt"
  	"io"
  	"log"
  	"net"
  	"net/http"
  	"runtime"
  	"sync"
  	"time"

  	"github.com/gin-gonic/gin"
  	"github.com/jmorganca/ollama/api"
  	llama "github.com/jmorganca/ollama/llama"
  )
  13. func Serve(ln net.Listener) error {
  14. r := gin.Default()
  15. var l *llama.LLama
  16. gpulayers := 1
  17. tokens := 512
  18. threads := runtime.NumCPU()
  19. model := "/Users/pdevine/.cache/gpt4all/GPT4All-13B-snoozy.ggmlv3.q4_0.bin"
  20. r.POST("/api/load", func(c *gin.Context) {
  21. var err error
  22. l, err = llama.New(model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
  23. if err != nil {
  24. fmt.Println("Loading the model failed:", err.Error())
  25. }
  26. })
  27. r.POST("/api/unload", func(c *gin.Context) {
  28. })
  29. r.POST("/api/generate", func(c *gin.Context) {
  30. var req api.GenerateRequest
  31. if err := c.ShouldBindJSON(&req); err != nil {
  32. c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
  33. return
  34. }
  35. ch := make(chan string)
  36. go func() {
  37. defer close(ch)
  38. _, err := l.Predict(req.Prompt, llama.Debug, llama.SetTokenCallback(func(token string) bool {
  39. ch <- token
  40. return true
  41. }), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
  42. if err != nil {
  43. panic(err)
  44. }
  45. }()
  46. c.Stream(func(w io.Writer) bool {
  47. tok, ok := <-ch
  48. if !ok {
  49. return false
  50. }
  51. c.SSEvent("token", tok)
  52. return true
  53. })
  54. /*
  55. embeds, err := l.Embeddings(text)
  56. if err != nil {
  57. fmt.Printf("Embeddings: error %s \n", err.Error())
  58. }
  59. */
  60. })
  61. log.Printf("Listening on %s", ln.Addr())
  62. s := &http.Server{
  63. Handler: r,
  64. }
  65. return s.Serve(ln)
  66. }