|
@@ -94,7 +94,7 @@ func (s *Server) allNil() bool {
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
// findStop reports whether sequence contains any of the given stop strings.
// Stops are checked in slice order; the first one found inside sequence is
// returned alongside true. When none is present it returns false and "".
func findStop(sequence string, stops []string) (bool, string) {
	for i := 0; i < len(stops); i++ {
		candidate := stops[i]
		if !strings.Contains(sequence, candidate) {
			continue
		}

		return true, candidate
	}

	return false, ""
}
|
|
|
|
|
|
// containsStopSuffix reports whether sequence ends with a non-empty prefix of
// any stop string, up to and including the complete stop string itself.
// Prefixes are taken byte-wise. An empty stop string never matches.
func containsStopSuffix(sequence string, stops []string) bool {
	for s := 0; s < len(stops); s++ {
		candidate := stops[s]

		// Try every non-empty prefix of the stop, longest first; the order
		// does not affect the boolean result.
		for end := len(candidate); end > 0; end-- {
			if prefix := candidate[:end]; strings.HasSuffix(sequence, prefix) {
				return true
			}
		}
	}

	return false
}
|
|
|
|
|
|
// truncateStop cuts the concatenation of pieces at the first occurrence of
// stop and returns what precedes it, split along the original piece
// boundaries. The final returned piece is shortened when the stop begins
// partway through it. If stop does not occur, pieces is returned unchanged.
func truncateStop(pieces []string, stop string) []string {
	full := strings.Join(pieces, "")

	cut := strings.Index(full, stop)
	if cut < 0 {
		return pieces
	}

	// remaining is the number of bytes before the stop that have not yet
	// been handed out to a returned piece.
	remaining := cut

	var kept []string
	for _, piece := range pieces {
		if remaining <= 0 {
			break
		}

		// The stop starts inside this piece: keep only the leading part.
		if len(piece) > remaining {
			piece = piece[:remaining]
		}

		kept = append(kept, piece)
		remaining -= len(piece)
	}

	return kept
}
|
|
|
+
|
|
|
func (s *Server) run(ctx context.Context) {
|
|
|
batch := llama.NewBatch(512, 0, s.parallel)
|
|
|
defer batch.Free()
|
|
|
|
|
|
// build up stop sequences as we recognize them
|
|
|
// TODO (jmorganca): simplify this
|
|
|
- sofar := make([][]string, s.parallel)
|
|
|
+ pieces := make([][]string, s.parallel)
|
|
|
|
|
|
for {
|
|
|
select {
|
|
@@ -214,50 +251,41 @@ func (s *Server) run(ctx context.Context) {
|
|
|
|
|
|
close(seq.responses)
|
|
|
seq.samplingCtx.Free()
|
|
|
- sofar[i] = []string{}
|
|
|
+ pieces[i] = []string{}
|
|
|
s.seqs[i] = nil
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
seq.tokens = []int{token}
|
|
|
|
|
|
- // recognize stop sequences
|
|
|
- // TODO (jmorganca): add tests around this
|
|
|
- // TODO (jmorganca): send back parital piece
|
|
|
-
|
|
|
- sequence := strings.Join(append(sofar[i], piece), "")
|
|
|
- if ok, stop := contains(sequence, seq.stop); ok {
|
|
|
+ pieces[i] = append(pieces[i], piece)
|
|
|
+ sequence := strings.Join(pieces[i], "")
|
|
|
+ if ok, stop := findStop(sequence, seq.stop); ok {
|
|
|
slog.Info("hit stop token", "stop", seq.stop)
|
|
|
- for _, p := range sofar[i] {
|
|
|
+
|
|
|
+ truncated := truncateStop(pieces[i], stop)
|
|
|
+
|
|
|
+ for _, p := range truncated {
|
|
|
seq.responses <- p
|
|
|
}
|
|
|
|
|
|
- piece, _, _ := strings.Cut(piece, stop)
|
|
|
- seq.responses <- piece
|
|
|
-
|
|
|
s.lc.KvCacheSeqRm(i, 0, -1)
|
|
|
close(seq.responses)
|
|
|
seq.samplingCtx.Free()
|
|
|
- sofar[i] = []string{}
|
|
|
+ pieces[i] = []string{}
|
|
|
s.seqs[i] = nil
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- if overlap(sequence, seq.stop) {
|
|
|
- slog.Info("overlap", "sequence", sequence)
|
|
|
- // partial stop, don't send
|
|
|
+ if containsStopSuffix(sequence, seq.stop) {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- slog.Info("sending", "sofar", sofar[i])
|
|
|
-
|
|
|
- sofar[i] = append(sofar[i], piece)
|
|
|
-
|
|
|
- for _, p := range sofar[i] {
|
|
|
+ for _, p := range pieces[i] {
|
|
|
seq.responses <- p
|
|
|
}
|
|
|
|
|
|
- sofar[i] = []string{}
|
|
|
+ pieces[i] = []string{}
|
|
|
}
|
|
|
|
|
|
batch.Clear()
|