|
@@ -172,9 +172,6 @@ func (llm *LLM) Close() {
|
|
}
|
|
}
|
|
|
|
|
|
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
|
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
|
- llm.mu.Lock()
|
|
|
|
- defer llm.mu.Unlock()
|
|
|
|
-
|
|
|
|
C.llama_reset_timings(llm.ctx)
|
|
C.llama_reset_timings(llm.ctx)
|
|
|
|
|
|
tokens := make([]C.llama_token, len(ctx))
|
|
tokens := make([]C.llama_token, len(ctx))
|
|
@@ -193,12 +190,12 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
|
|
var b bytes.Buffer
|
|
var b bytes.Buffer
|
|
for {
|
|
for {
|
|
token, err := llm.next()
|
|
token, err := llm.next()
|
|
- if errors.Is(err, io.EOF) {
|
|
|
|
|
|
+ if llm.gc {
|
|
|
|
+ return nil
|
|
|
|
+ } else if errors.Is(err, io.EOF) {
|
|
break
|
|
break
|
|
} else if err != nil {
|
|
} else if err != nil {
|
|
return err
|
|
return err
|
|
- } else if llm.gc {
|
|
|
|
- return io.EOF
|
|
|
|
}
|
|
}
|
|
|
|
|
|
b.WriteString(llm.detokenize(token))
|
|
b.WriteString(llm.detokenize(token))
|
|
@@ -293,6 +290,9 @@ func (llm *LLM) detokenize(tokens ...C.llama_token) string {
|
|
}
|
|
}
|
|
|
|
|
|
func (llm *LLM) next() (C.llama_token, error) {
|
|
func (llm *LLM) next() (C.llama_token, error) {
|
|
|
|
+ llm.mu.Lock()
|
|
|
|
+ defer llm.mu.Unlock()
|
|
|
|
+
|
|
if len(llm.embd) >= llm.NumCtx {
|
|
if len(llm.embd) >= llm.NumCtx {
|
|
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
|
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
|
truncated := llm.embd[:llm.NumKeep]
|
|
truncated := llm.embd[:llm.NumKeep]
|
|
@@ -304,6 +304,10 @@ func (llm *LLM) next() (C.llama_token, error) {
|
|
}
|
|
}
|
|
|
|
|
|
for {
|
|
for {
|
|
|
|
+ if llm.gc {
|
|
|
|
+ return 0, io.EOF
|
|
|
|
+ }
|
|
|
|
+
|
|
if llm.cursor >= len(llm.embd) {
|
|
if llm.cursor >= len(llm.embd) {
|
|
break
|
|
break
|
|
}
|
|
}
|