Browse Source

Merge pull request #102 from jmorganca/session-id

Session
Michael Yang 1 year ago
parent
commit
db77dfe01f
5 changed files with 343 additions and 212 deletions
  1. 58 2
      api/types.go
  2. 16 8
      cmd/cmd.go
  3. 187 86
      llama/llama.go
  4. 9 98
      llama/utils.go
  5. 73 18
      server/routes.go

+ 58 - 2
api/types.go

@@ -1,7 +1,9 @@
 package api
 
 import (
+	"encoding/json"
 	"fmt"
+	"math"
 	"os"
 	"runtime"
 	"time"
@@ -28,6 +30,9 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
+	SessionID       int64    `json:"session_id"`
+	SessionDuration Duration `json:"session_duration,omitempty"`
+
 	Model   string `json:"model"`
 	Prompt  string `json:"prompt"`
 	Context []int  `json:"context,omitempty"`
@@ -81,6 +86,9 @@ type ListResponseModel struct {
 }
 
 type GenerateResponse struct {
+	SessionID        int64     `json:"session_id"`
+	SessionExpiresAt time.Time `json:"session_expires_at"`
+
 	Model     string    `json:"model"`
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
@@ -89,6 +97,9 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	SampleCount        int           `json:"sample_count,omitempty"`
+	SampleDuration     time.Duration `json:"sample_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
 	EvalCount          int           `json:"eval_count,omitempty"`
@@ -100,6 +111,19 @@ func (r *GenerateResponse) Summary() {
 		fmt.Fprintf(os.Stderr, "total duration:       %v\n", r.TotalDuration)
 	}
 
+	if r.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration:        %v\n", r.LoadDuration)
+	}
+
+	if r.SampleCount > 0 {
+		fmt.Fprintf(os.Stderr, "sample count:         %d token(s)\n", r.SampleCount)
+	}
+
+	if r.SampleDuration > 0 {
+		fmt.Fprintf(os.Stderr, "sample duration:      %s\n", r.SampleDuration)
+		fmt.Fprintf(os.Stderr, "sample rate:          %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+	}
+
 	if r.PromptEvalCount > 0 {
 		fmt.Fprintf(os.Stderr, "prompt eval count:    %d token(s)\n", r.PromptEvalCount)
 	}
@@ -127,6 +151,7 @@ type Options struct {
 
 	// Model options
 	NumCtx        int  `json:"num_ctx,omitempty"`
+	NumKeep       int  `json:"num_keep,omitempty"`
 	NumBatch      int  `json:"num_batch,omitempty"`
 	NumGPU        int  `json:"num_gpu,omitempty"`
 	MainGPU       int  `json:"main_gpu,omitempty"`
@@ -151,6 +176,7 @@ type Options struct {
 	Mirostat         int     `json:"mirostat,omitempty"`
 	MirostatTau      float32 `json:"mirostat_tau,omitempty"`
 	MirostatEta      float32 `json:"mirostat_eta,omitempty"`
+	PenalizeNewline  bool    `json:"penalize_newline,omitempty"`
 
 	NumThread int `json:"num_thread,omitempty"`
 }
@@ -162,14 +188,14 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
 		NumCtx:   2048,
-		NumBatch: 512,
+		NumBatch: 1024,
 		NumGPU:   1,
 		LowVRAM:  false,
 		F16KV:    true,
 		UseMMap:  true,
 		UseMLock: false,
 
-		RepeatLastN:      512,
+		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
 		FrequencyPenalty: 0.0,
 		PresencePenalty:  0.0,
@@ -181,7 +207,37 @@ func DefaultOptions() Options {
 		Mirostat:         0,
 		MirostatTau:      5.0,
 		MirostatEta:      0.1,
+		PenalizeNewline:  true,
 
 		NumThread: runtime.NumCPU(),
 	}
 }
+
+type Duration struct {
+	time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		if t < 0 {
+			t = math.MaxFloat64
+		}
+
+		d.Duration = time.Duration(t)
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}

+ 16 - 8
cmd/cmd.go

@@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
-var generateContextKey struct{}
+type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
@@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		var latest api.GenerateResponse
 
-		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
 		if !ok {
 			generateContext = []int{}
 		}
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-		fn := func(resp api.GenerateResponse) error {
+		generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
+		if !ok {
+			generateSession = 0
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
+		fn := func(response api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
 			}
 
-			latest = resp
+			latest = response
 
-			fmt.Print(resp.Response)
-
-			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+			fmt.Print(response.Response)
 			return nil
 		}
 
@@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 		if verbose {
 			latest.Summary()
 		}
+
+		ctx := cmd.Context()
+		ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+		ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
+		cmd.SetContext(ctx)
 	}
 
 	return nil

+ 187 - 86
llama/llama.go

@@ -1,8 +1,8 @@
 package llama
 
 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
-#cgo CXXFLAGS: -std=c++11
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>
@@ -21,6 +21,7 @@ struct llama_sample_options
 	int mirostat;
 	float mirostat_tau;
 	float mirostat_eta;
+	bool penalize_newline;
 };
 
 llama_token llama_sample(
@@ -37,6 +38,8 @@ llama_token llama_sample(
 		false,
 	};
 
+	struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
 	llama_sample_repetition_penalty(
 		ctx, &candidates_p,
 		last_tokens, n_last_tokens,
@@ -47,6 +50,10 @@ llama_token llama_sample(
 		last_tokens, n_last_tokens,
 		opts->frequency_penalty, opts->presence_penalty);
 
+	if (!opts->penalize_newline) {
+		candidates_p.data[llama_token_nl()] = newline;
+	}
+
 	if (opts->temperature <= 0) {
 		return llama_sample_token_greedy(ctx, &candidates_p);
 	}
@@ -82,29 +89,37 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"os"
 	"strings"
-	"time"
+	"sync"
 	"unicode/utf8"
 	"unsafe"
 
 	"github.com/jmorganca/ollama/api"
 )
 
-type llama struct {
+type LLM struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
 
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
 	api.Options
 }
 
-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := llama{Options: opts}
+	llm := LLM{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
 
@@ -144,27 +159,118 @@ func New(model string, opts api.Options) (*llama, error) {
 	return &llm, nil
 }
 
-func (llm *llama) Close() {
+func (llm *LLM) Close() {
+	llm.gc = true
+
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
 	defer C.llama_free_model(llm.model)
 	defer C.llama_free(llm.ctx)
 
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-	if input := llm.tokenize(prompt); input != nil {
-		embd := make([]C.llama_token, len(ctx))
-		for i := range ctx {
-			embd[i] = C.llama_token(ctx[i])
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	C.llama_reset_timings(llm.ctx)
+
+	tokens := make([]C.llama_token, len(ctx))
+	for i := range tokens {
+		tokens[i] = C.llama_token(ctx[i])
+	}
+
+	if len(tokens) == 0 {
+		tokens = llm.tokenize(" ")
+	}
+
+	llm.marshalPrompt(tokens, prompt)
+
+	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+	var b bytes.Buffer
+	for {
+		token, err := llm.next()
+		if llm.gc {
+			return nil
+		} else if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return err
+		}
+
+		b.WriteString(llm.detokenize(token))
+		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+			fn(api.GenerateResponse{Response: b.String()})
+			b.Reset()
 		}
+	}
 
-		return llm.generate(append(embd, input...), fn)
+	last := make([]int, 0, len(llm.last))
+	for _, i := range llm.last {
+		if i != 0 {
+			last = append(last, int(i))
+		}
 	}
 
-	return errors.New("llama: tokenize")
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		Context:            last,
+		SampleCount:        int(timings.n_sample),
+		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       parseDurationMs(float64(timings.t_eval_ms)),
+	})
+
+	return nil
 }
 
-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+	tokens := append(ctx, llm.tokenize(prompt)...)
+	if llm.NumKeep < 0 {
+		llm.NumKeep = len(tokens)
+	}
+
+	// min(llm.NumCtx - 4, llm.NumKeep)
+	if llm.NumCtx-4 < llm.NumKeep {
+		llm.NumKeep = llm.NumCtx - 4
+	}
+
+	if len(tokens) >= llm.NumCtx {
+		// truncate input
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := tokens[:llm.NumKeep]
+		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+		tokens = truncated
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+	} else {
+		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+		llm.last = append(llm.last, tokens...)
+	}
+
+	var i int
+	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+		// noop
+	}
+
+	llm.embd = tokens
+	if i == len(tokens) {
+		// evaluate at least one token to generate logits
+		i--
+	}
+
+	llm.cursor = i
+
+	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+	return tokens
+}
+
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
@@ -176,7 +282,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
 	return nil
 }
 
-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	var sb strings.Builder
 	for _, token := range tokens {
 		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,98 +291,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-	var opts C.struct_llama_sample_options
-	opts.repeat_penalty = C.float(llm.RepeatPenalty)
-	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-	opts.presence_penalty = C.float(llm.PresencePenalty)
-	opts.temperature = C.float(llm.Temperature)
-	opts.top_k = C.int(llm.TopK)
-	opts.top_p = C.float(llm.TopP)
-	opts.tfs_z = C.float(llm.TFSZ)
-	opts.typical_p = C.float(llm.TypicalP)
-	opts.mirostat = C.int(llm.Mirostat)
-	opts.mirostat_tau = C.float(llm.MirostatTau)
-	opts.mirostat_eta = C.float(llm.MirostatEta)
-
-	output := deque[C.llama_token]{capacity: llm.NumCtx}
-
-	context := deque[int]{capacity: llm.NumCtx / 2}
-	for _, in := range input {
-		context.PushLeft(int(in))
+func (llm *LLM) next() (C.llama_token, error) {
+	llm.mu.Lock()
+	defer llm.mu.Unlock()
+
+	if len(llm.embd) >= llm.NumCtx {
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := llm.embd[:llm.NumKeep]
+		truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
+
+		llm.embd = truncated
+		llm.cursor = llm.NumKeep
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
 	}
 
-	var b bytes.Buffer
-	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-			return errors.New("llama: eval")
+	for {
+		if llm.gc {
+			return 0, io.EOF
 		}
 
-		token, err := llm.sample(output, &opts)
-		if errors.Is(err, io.EOF) {
+		if llm.cursor >= len(llm.embd) {
 			break
-		} else if err != nil {
-			return err
 		}
 
-		b.WriteString(llm.detokenize(token))
-		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-			// call the callback
-			fn(api.GenerateResponse{
-				Response: b.String(),
-			})
-
-			output.PushLeft(token)
-			context.PushLeft(int(token))
-			b.Reset()
+		numEval := len(llm.embd) - llm.cursor
+		if numEval > llm.NumBatch {
+			numEval = llm.NumBatch
 		}
 
-		input = []C.llama_token{token}
-	}
-
-	dur := func(ms float64) time.Duration {
-		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-		if err != nil {
-			panic(err)
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+			return 0, fmt.Errorf("llama_eval: %d", retval)
 		}
 
-		return d
+		llm.cursor += numEval
 	}
 
-	timings := C.llama_get_timings(llm.ctx)
-	fn(api.GenerateResponse{
-		Done:               true,
-		Context:            context.Data(),
-		PromptEvalCount:    int(timings.n_p_eval),
-		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-		EvalCount:          int(timings.n_eval),
-		EvalDuration:       dur(float64(timings.t_eval_ms)),
-	})
-
-	return nil
-}
-
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-	numVocab := int(C.llama_n_vocab(llm.ctx))
+	var sampleOpts C.struct_llama_sample_options
+	sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+	sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+	sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+	sampleOpts.temperature = C.float(llm.Temperature)
+	sampleOpts.top_k = C.int(llm.TopK)
+	sampleOpts.top_p = C.float(llm.TopP)
+	sampleOpts.tfs_z = C.float(llm.TFSZ)
+	sampleOpts.typical_p = C.float(llm.TypicalP)
+	sampleOpts.mirostat = C.int(llm.Mirostat)
+	sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+	sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+	sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
+
+	numVocab := C.llama_n_vocab(llm.ctx)
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
-	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-	for i := 0; i < candidates.Cap(); i++ {
-		candidates.PushLeft(C.struct_llama_token_data{
+	// TODO: logit bias
+
+	candidates := make([]C.llama_token_data, numVocab)
+	for i := range logits {
+		candidates[i] = C.llama_token_data{
 			id:    C.int(i),
 			logit: logits[i],
 			p:     0,
-		})
+		}
 	}
 
+	repeatLastN := llm.RepeatLastN
+	if len(llm.last) < repeatLastN {
+		repeatLastN = len(llm.last)
+	}
+
+	if llm.NumCtx < repeatLastN {
+		repeatLastN = llm.NumCtx
+	}
+
+	lastN := llm.last[len(llm.last)-repeatLastN:]
+
 	token := C.llama_sample(
 		llm.ctx,
-		unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-		unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-		opts)
-	if token != C.llama_token_eos() {
-		return token, nil
+		unsafe.SliceData(candidates), C.size_t(len(candidates)),
+		unsafe.SliceData(lastN), C.size_t(len(lastN)),
+		&sampleOpts,
+	)
+
+	llm.last = append(llm.last, token)
+	llm.embd = append(llm.embd, token)
+
+	if token == C.llama_token_eos() {
+		return 0, io.EOF
 	}
 
-	return 0, io.EOF
+	return token, nil
 }

+ 9 - 98
llama/utils.go

@@ -1,104 +1,15 @@
 package llama
 
-type node[T any] struct {
-	t    T
-	next *node[T]
-	prev *node[T]
-}
-
-type deque[T any] struct {
-	head     *node[T]
-	tail     *node[T]
-	size     int
-	capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-	return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-	return d.size
-}
-
-func (d *deque[T]) Cap() int {
-	return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.PopLeft()
-	}
-
-	n := node[T]{t: t}
-	if d.head != nil {
-		n.next = d.head
-		d.head.prev = &n
-		d.head = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.Pop()
-	}
-
-	n := node[T]{t: t}
-	if d.tail != nil {
-		n.prev = d.tail
-		d.tail.next = &n
-		d.tail = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	head := d.head
-	d.head = head.next
-	if d.head != nil {
-		d.head.prev = nil
-	} else {
-		d.tail = nil
-	}
-
-	d.size--
-	return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	tail := d.tail
-	d.tail = tail.prev
-	if d.tail != nil {
-		d.tail.next = nil
-	} else {
-		d.head = nil
-	}
-
-	d.size--
-	return &tail.t
-}
+import (
+	"fmt"
+	"time"
+)
 
-func (d *deque[T]) Data() (data []T) {
-	for n := d.head; n != nil; n = n.next {
-		data = append(data, n.t)
+func parseDurationMs(ms float64) time.Duration {
+	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+	if err != nil {
+		panic(err)
 	}
 
-	return data
+	return dur
 }

+ 73 - 18
server/routes.go

@@ -11,6 +11,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 	"time"
 
 	"dario.cat/mergo"
@@ -21,8 +22,21 @@ import (
 	"github.com/jmorganca/ollama/llama"
 )
 
+var activeSession struct {
+	mu sync.Mutex
+
+	id  int64
+	llm *llama.LLM
+
+	expireAt    time.Time
+	expireTimer *time.Timer
+}
+
 func GenerateHandler(c *gin.Context) {
-	start := time.Now()
+	activeSession.mu.Lock()
+	defer activeSession.mu.Unlock()
+
+	checkpointStart := time.Now()
 
 	var req api.GenerateRequest
 	if err := c.ShouldBindJSON(&req); err != nil {
@@ -36,44 +50,85 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	opts := api.DefaultOptions()
-	if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	if req.SessionID == 0 || req.SessionID != activeSession.id {
+		if activeSession.llm != nil {
+			activeSession.llm.Close()
+			activeSession.llm = nil
+		}
 
-	if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		opts := api.DefaultOptions()
+		if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		llm, err := llama.New(model.ModelPath, opts)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		activeSession.id = time.Now().UnixNano()
+		activeSession.llm = llm
 	}
 
-	prompt, err := model.Prompt(req)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+	sessionDuration := req.SessionDuration
+	sessionID := activeSession.id
+
+	activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+	if activeSession.expireTimer == nil {
+		activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+			activeSession.mu.Lock()
+			defer activeSession.mu.Unlock()
+
+			if sessionID != activeSession.id {
+				return
+			}
+
+			if time.Now().Before(activeSession.expireAt) {
+				return
+			}
+
+			activeSession.llm.Close()
+			activeSession.llm = nil
+			activeSession.id = 0
+		})
 	}
+	activeSession.expireTimer.Reset(sessionDuration.Duration)
+
+	checkpointLoaded := time.Now()
 
-	llm, err := llama.New(model.ModelPath, opts)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	defer llm.Close()
 
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
 		fn := func(r api.GenerateResponse) {
+			activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+			activeSession.expireTimer.Reset(sessionDuration.Duration)
+
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
+			r.SessionID = activeSession.id
+			r.SessionExpiresAt = activeSession.expireAt.UTC()
 			if r.Done {
-				r.TotalDuration = time.Since(start)
+				r.TotalDuration = time.Since(checkpointStart)
+				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
 			ch <- r
 		}
 
-		if err := llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -223,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
 		return
 	}
 
-	c.JSON(http.StatusOK, api.ListResponse{models})
+	c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }
 
 func CopyModelHandler(c *gin.Context) {