@@ -24,9 +24,28 @@ type Server struct {
     model *llama.Model
     lc    *llama.Context
     batch *llama.Batch
+
+    queue chan Sequence
+    seqs  []*Sequence
+
+    // mu guards seqs
+    mu sync.Mutex
+}
+
+type Sequence struct {
+    prompt []llama.Token
+    out    chan string
 }
 
-var mu sync.Mutex
+func schedule(parallel int, queue <-chan Sequence) {
|
|
|
+ // Fill sequences from the queue
|
|
|
+
|
|
|
+ // once a sequence finishes, remove it from and add a new one from the queue
|
|
|
+}
|
|
|
+
|
|
|
+func process() {
|
|
|
+ // loop through the sequences, fill a batch, decode and sample tokens, responding to appropriate requests
|
|
|
+}
|
|
|
|
|
|
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
|
|
|
var request Request
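
The schedule stub above only records the intent. The sketch below is one assumed way it could be filled in, not part of the change itself: it turns schedule into a method on Server so it can reach seqs and mu (which also makes the parallel argument redundant, since len(s.seqs) already carries that number), it relies on process() resetting a slot to nil once its sequence has finished, and it polls with time.Sleep, so "time" would have to be imported.

// Sketch only: assumes a Server method rather than the free function in the
// diff, plus the convention that a finished sequence's slot is reset to nil.
func (s *Server) schedule() {
    for seq := range s.queue {
        seq := seq // copy: the slot below stores a pointer to this iteration's value
        for placed := false; !placed; {
            s.mu.Lock()
            for i := range s.seqs {
                if s.seqs[i] == nil {
                    s.seqs[i] = &seq
                    placed = true
                    break
                }
            }
            s.mu.Unlock()
            if !placed {
                time.Sleep(10 * time.Millisecond) // crude polling; a sync.Cond would avoid this
            }
        }
    }
}

Because the queue is buffered (256 entries in main below), handlers only block on s.queue <- seq once that backlog itself is full.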
@@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
     w.Header().Set("Transfer-Encoding", "chunked")
     w.WriteHeader(http.StatusOK)
 
-    enc := json.NewEncoder(w)
-
-    // main loop
     tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
     if err != nil {
         panic(err)
     }
 
-    fmt.Println("tokens", tokens)
+    seq := Sequence{prompt: tokens, out: make(chan string)}
+    s.queue <- seq
 
-    batch := llama.NewBatch(512, 0, 1)
+    // stream generated tokens until the sequence's output channel is closed
+    for str := range seq.out {
+        // each sampled piece of text goes to the client as one JSON chunk
+        if err := json.NewEncoder(w).Encode(&Response{Token: str}); err != nil {
+            log.Println("Failed to encode result:", err)
+            return
+        }
+        w.(http.Flusher).Flush()
+    }
 
     // prompt eval
     for i, t := range tokens {
@@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 
 func main() {
     mp := flag.String("model", "", "Path to model binary file")
+    parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
     flag.Parse()
 
     // load the model
@@ -105,6 +131,8 @@ func main() {
     server := &Server{
         model: model,
         lc:    lc,
+        queue: make(chan Sequence, 256),
+        seqs:  make([]*Sequence, *parallel),
     }
 
     addr := "127.0.0.1:8080"
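
The process stub is also left empty by this change. The sketch below is an assumption about the shape it might eventually take, under the same conventions as the schedule sketch (a Server method, nil marking a free slot, text streamed over each sequence's out channel and the channel closed on completion); the llama-binding specifics of filling the batch, decoding, and sampling are left as comments because those calls are not shown in this diff.

// Sketch only: the actual batch/decode/sample calls are elided; this shows just
// the bookkeeping around seqs and the per-sequence out channels.
func (s *Server) process() {
    for {
        s.mu.Lock()
        active := 0
        for i := range s.seqs {
            seq := s.seqs[i]
            if seq == nil {
                continue
            }
            active++
            // TODO: add this sequence's pending tokens to the batch under
            // sequence id i, decode, and sample the next piece of text.
            // (A real implementation would likely snapshot the active
            // sequences and release mu before decoding.)
            //
            // For every sampled piece of text, hand it to the waiting handler:
            //     seq.out <- piece
            // and once the sequence hits an end-of-sequence token or a length
            // limit, signal completion and free the slot:
            //     close(seq.out)
            //     s.seqs[i] = nil
        }
        s.mu.Unlock()

        if active == 0 {
            time.Sleep(10 * time.Millisecond) // nothing scheduled yet
        }
    }
}

With both loops filled in, main would presumably start them before serving, e.g. go server.schedule() and go server.process(); that wiring is also an assumption, since it is not part of this diff.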