Browse Source

fix model name returned by `/api/generate` being different from the model name provided

Jeffrey Morgan 1 year ago
parent
commit
fa2f095bd9
2 changed files with 2 additions and 8 deletions
  1. 0 4
      llm/llama.go
  2. 2 4
      server/routes.go

+ 0 - 4
llm/llama.go

@@ -545,7 +545,6 @@ type prediction struct {
 const maxBufferSize = 512 * format.KiloByte
 
 type PredictOpts struct {
-	Model            string
 	Prompt           string
 	Format           string
 	CheckpointStart  time.Time
@@ -553,7 +552,6 @@ type PredictOpts struct {
 }
 
 type PredictResult struct {
-	Model              string
 	CreatedAt          time.Time
 	TotalDuration      time.Duration
 	LoadDuration       time.Duration
@@ -651,7 +649,6 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 			if p.Content != "" {
 				fn(PredictResult{
-					Model:     predict.Model,
 					CreatedAt: time.Now().UTC(),
 					Content:   p.Content,
 				})
@@ -659,7 +656,6 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 			if p.Stop {
 				fn(PredictResult{
-					Model:         predict.Model,
 					CreatedAt:     time.Now().UTC(),
 					TotalDuration: time.Since(predict.CheckpointStart),
 

+ 2 - 4
server/routes.go

@@ -260,7 +260,7 @@ func GenerateHandler(c *gin.Context) {
 			}
 
 			resp := api.GenerateResponse{
-				Model:     r.Model,
+				Model:     req.Model,
 				CreatedAt: r.CreatedAt,
 				Done:      r.Done,
 				Response:  r.Content,
@@ -288,7 +288,6 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Model:            model.Name,
 			Prompt:           prompt,
 			Format:           req.Format,
 			CheckpointStart:  checkpointStart,
@@ -985,7 +984,7 @@ func ChatHandler(c *gin.Context) {
 			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
-				Model:     r.Model,
+				Model:     req.Model,
 				CreatedAt: r.CreatedAt,
 				Done:      r.Done,
 				Metrics: api.Metrics{
@@ -1007,7 +1006,6 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Model:            model.Name,
 			Prompt:           prompt,
 			Format:           req.Format,
 			CheckpointStart:  checkpointStart,