jmorganca 11 months ago
parent
commit
25fd8fd045
1 changed files with 34 additions and 6 deletions
  1. 34 6
      llama/runner/main.go

+ 34 - 6
llama/runner/main.go

@@ -24,9 +24,28 @@ type Server struct {
 	model *llama.Model
 	lc    *llama.Context
 	batch *llama.Batch
+
+	queue chan Sequence
+	seqs  []*Sequence
+
+	// mu guards seqs
+	mu sync.Mutex
+}
+
+type Sequence struct {
+	prompt []llama.Token
+	out    chan string
 }
 
-var mu sync.Mutex
+func schedule(parallel int, queue <-chan Sequence) {
+	// Fill sequences from the queue
+
+	// once a sequence finishes, remove it from and add a new one from the queue
+}
+
+func process() {
+	// loop through the sequences, fill a batch, decode and sample tokens, responding to appropriate requests
+}
 
 func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 	var request Request
@@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Transfer-Encoding", "chunked")
 	w.WriteHeader(http.StatusOK)
 
-	enc := json.NewEncoder(w)
-
-	// main loop
 	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	fmt.Println("tokens", tokens)
+	seq := Sequence{prompt: tokens}
+	s.queue <- seq
 
-	batch := llama.NewBatch(512, 0, 1)
+	// listen for the sequence to finish
+	for {
+		str := <-seq.out
+		if err := json.NewEncoder(w).Encode(&Response{Token: str}); err != nil {
+			log.Println("Failed to encode result:", err)
+			return
+		}
+		w.(http.Flusher).Flush()
+	}
 
 	// prompt eval
 	for i, t := range tokens {
@@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 
 func main() {
 	mp := flag.String("model", "", "Path to model binary file")
+	parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
 	flag.Parse()
 
 	// load the model
@@ -105,6 +131,8 @@ func main() {
 	server := &Server{
 		model: model,
 		lc:    lc,
+		queue: make(chan Sequence, 256),
+		seqs:  make([]*Sequence, *parallel),
 	}
 
 	addr := "127.0.0.1:8080"