Przeglądaj źródła

Fix show parameters (#2017)

Patrick Devine 1 rok temu
rodzic
commit
eef50accb4
2 zmienionych plików z 41 dodań i 20 usunięć
  1. 3 19
      server/routes.go
  2. 38 1
      server/routes_test.go

+ 3 - 19
server/routes.go

@@ -15,7 +15,6 @@ import (
 	"path/filepath"
 	"reflect"
 	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"syscall"
@@ -668,27 +667,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	cs := 30
 	for k, v := range model.Options {
 		switch val := v.(type) {
-		case string:
-			params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
-		case int:
-			params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
-		case float64:
-			params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
-		case bool:
-			params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
 		case []interface{}:
 			for _, nv := range val {
-				switch nval := nv.(type) {
-				case string:
-					params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
-				case int:
-					params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
-				case float64:
-					params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
-				case bool:
-					params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
-				}
+				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
 			}
+		default:
+			params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
 		}
 	}
 	resp.Parameters = strings.Join(params, "\n")

+ 38 - 1
server/routes_test.go

@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"sort"
 	"strings"
 	"testing"
 
@@ -50,7 +51,7 @@ func Test_Routes(t *testing.T) {
 	createTestModel := func(t *testing.T, name string) {
 		fname := createTestFile(t, "ollama-model")
 
-		modelfile := strings.NewReader(fmt.Sprintf("FROM %s", fname))
+		modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
 		commands, err := parser.Parse(modelfile)
 		assert.Nil(t, err)
 		fn := func(resp api.ProgressResponse) {
@@ -167,6 +168,42 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "beefsteak:latest", model.ShortName)
 			},
 		},
+		{
+			Name:   "Show Model Handler",
+			Method: http.MethodPost,
+			Path:   "/api/show",
+			Setup: func(t *testing.T, req *http.Request) {
+				createTestModel(t, "show-model")
+				showReq := api.ShowRequest{Model: "show-model"}
+				jsonData, err := json.Marshal(showReq)
+				assert.Nil(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, contentType, "application/json; charset=utf-8")
+				body, err := io.ReadAll(resp.Body)
+				assert.Nil(t, err)
+
+				var showResp api.ShowResponse
+				err = json.Unmarshal(body, &showResp)
+				assert.Nil(t, err)
+
+				var params []string
+				paramsSplit := strings.Split(showResp.Parameters, "\n")
+				for _, p := range paramsSplit {
+					params = append(params, strings.Join(strings.Fields(p), " "))
+				}
+				sort.Strings(params)
+				expectedParams := []string{
+					"seed 42",
+					"stop \"bar\"",
+					"stop \"foo\"",
+					"top_p 0.9",
+				}
+				assert.Equal(t, expectedParams, params)
+			},
+		},
 	}
 
 	s, err := setupServer(t)