package main

import (
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "net"
    "net/http"
    "sync"

    "github.com/ollama/ollama/llama"
)
// Request is the JSON body accepted by the server.
type Request struct {
    Prompt string `json:"prompt"`
}

// Response is one streamed JSON object per generated token.
type Response struct {
    Token string `json:"token"`
}

type Server struct {
    model *llama.Model
    lc    *llama.Context
}

// mu serializes access to the shared llama context, which is not safe
// for concurrent use across requests.
var mu sync.Mutex
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
    var request Request
    if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
        http.Error(w, "Bad request", http.StatusBadRequest)
        return
    }

    // Tokenize the prompt before writing any headers so a failure can
    // still be reported as a proper HTTP error status.
    tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
    if err != nil {
        http.Error(w, "Failed to tokenize prompt", http.StatusInternalServerError)
        return
    }

    // Set the headers to indicate a chunked, streaming response.
    w.Header().Set("Content-Type", "application/json")
    w.Header().Set("Transfer-Encoding", "chunked")
    w.WriteHeader(http.StatusOK)
    enc := json.NewEncoder(w)

    // Prompt eval: queue every prompt token in the batch. The batch was
    // created with room for 512 tokens, so the prompt must fit within it.
    batch := llama.NewBatch(512, 0, 1)
    for i, t := range tokens {
        batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
    }

    // Main loop: decode the batch, sample a token, stream it to the
    // client, then feed the sampled token back in for the next step.
    for n := batch.NumTokens(); n < 2048; n++ {
        mu.Lock()
        err = s.lc.Decode(batch)
        if err != nil {
            mu.Unlock()
            log.Println("Failed to decode:", err)
            return
        }

        // Sample the highest-probability token.
        token := s.lc.SampleTokenGreedy(batch)
        mu.Unlock()

        // Stop on an end-of-generation token.
        if s.model.TokenIsEog(token) {
            break
        }

        // Convert the token to text and flush it to the client.
        str := s.model.TokenToPiece(token)
        if err := enc.Encode(&Response{Token: str}); err != nil {
            log.Println("Failed to encode result:", err)
            return
        }
        if f, ok := w.(http.Flusher); ok {
            f.Flush()
        }

        // The next decode sees only the token just sampled.
        batch.Clear()
        batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
    }
}
func main() {
    mp := flag.String("model", "", "Path to model binary file")
    flag.Parse()
    if *mp == "" {
        log.Fatal("missing required -model flag")
    }

    // Load the model and create an inference context.
    llama.BackendInit()
    params := llama.NewModelParams()
    model := llama.LoadModelFromFile(*mp, params)
    ctxParams := llama.NewContextParams()
    lc := llama.NewContextWithModel(model, ctxParams)
    if lc == nil {
        panic("Failed to create context")
    }

    server := &Server{
        model: model,
        lc:    lc,
    }

    // Serve every request with the streaming handler.
    addr := "127.0.0.1:8080"
    listener, err := net.Listen("tcp", addr)
    if err != nil {
        fmt.Println("Listen error:", err)
        return
    }
    defer listener.Close()

    httpServer := http.Server{
        Handler: http.HandlerFunc(server.stream),
    }
    log.Println("Server listening on", addr)
    if err := httpServer.Serve(listener); err != nil {
        log.Fatal("server error:", err)
    }
}
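
To exercise the server, a client POSTs a JSON body containing a prompt and decodes the stream of JSON objects as tokens arrive. Below is a minimal sketch of such a client, built as a separate program against the Request and Response shapes defined above; the prompt text, program structure, and error handling are illustrative assumptions, not part of the server code.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
)

type request struct {
    Prompt string `json:"prompt"`
}

type response struct {
    Token string `json:"token"`
}

func main() {
    // Build the JSON body the server's handler expects.
    // The prompt here is an arbitrary example.
    body, err := json.Marshal(request{Prompt: "Why is the sky blue?"})
    if err != nil {
        log.Fatal(err)
    }

    // POST to the address the server listens on (127.0.0.1:8080).
    resp, err := http.Post("http://127.0.0.1:8080/", "application/json", bytes.NewReader(body))
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    // The server streams one JSON object per token; json.Decoder
    // yields them as they arrive on the chunked response body.
    dec := json.NewDecoder(resp.Body)
    for {
        var r response
        if err := dec.Decode(&r); err != nil {
            break // io.EOF once the stream ends
        }
        fmt.Print(r.Token)
    }
    fmt.Println()
}

Start the server first, pointing the -model flag at a model file on disk, then run the client; tokens print incrementally as the server flushes each chunk.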