Browse Source

Add unit test of API routes (#1528)

Patrick Devine 1 year ago
parent
commit
630518f0d9
4 changed files with 122 additions and 30 deletions
  1. 1 6
      cmd/cmd.go
  2. 3 0
      go.mod
  3. 48 24
      server/routes.go
  4. 70 0
      server/routes_test.go

+ 1 - 6
cmd/cmd.go

@@ -1035,12 +1035,7 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 		return err
 	}
 
-	var origins []string
-	if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
-		origins = strings.Split(o, ",")
-	}
-
-	return server.Serve(ln, origins)
+	return server.Serve(ln)
 }
 
 func getImageData(filePath string) ([]byte, error) {

+ 3 - 0
go.mod

@@ -7,11 +7,14 @@ require (
 	github.com/gin-gonic/gin v1.9.1
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
+	github.com/stretchr/testify v1.8.3
 	golang.org/x/sync v0.3.0
 )
 
 require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/mattn/go-runewidth v0.0.14 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/rivo/uniseg v0.2.0 // indirect
 )
 

+ 48 - 24
server/routes.go

@@ -32,6 +32,10 @@ import (
 
 var mode string = gin.DebugMode
 
+type Server struct {
+	WorkDir string
+}
+
 func init() {
 	switch mode {
 	case gin.DebugMode:
@@ -800,27 +804,27 @@ var defaultAllowOrigins = []string{
 	"0.0.0.0",
 }
 
-func Serve(ln net.Listener, allowOrigins []string) error {
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		// clean up unused layers and manifests
-		if err := PruneLayers(); err != nil {
-			return err
-		}
+func NewServer() (*Server, error) {
+	workDir, err := os.MkdirTemp("", "ollama")
+	if err != nil {
+		return nil, err
+	}
 
-		manifestsPath, err := GetManifestPath()
-		if err != nil {
-			return err
-		}
+	return &Server{
+		WorkDir: workDir,
+	}, nil
+}
 
-		if err := PruneDirectory(manifestsPath); err != nil {
-			return err
-		}
+func (s *Server) GenerateRoutes() http.Handler {
+	var origins []string
+	if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
+		origins = strings.Split(o, ",")
 	}
 
 	config := cors.DefaultConfig()
 	config.AllowWildcard = true
 
-	config.AllowOrigins = allowOrigins
+	config.AllowOrigins = origins
 	for _, allowOrigin := range defaultAllowOrigins {
 		config.AllowOrigins = append(config.AllowOrigins,
 			fmt.Sprintf("http://%s", allowOrigin),
@@ -830,17 +834,11 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		)
 	}
 
-	workDir, err := os.MkdirTemp("", "ollama")
-	if err != nil {
-		return err
-	}
-	defer os.RemoveAll(workDir)
-
 	r := gin.Default()
 	r.Use(
 		cors.New(config),
 		func(c *gin.Context) {
-			c.Set("workDir", workDir)
+			c.Set("workDir", s.WorkDir)
 			c.Next()
 		},
 	)
@@ -868,8 +866,34 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		})
 	}
 
+	return r
+}
+
+func Serve(ln net.Listener) error {
+	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
+		// clean up unused layers and manifests
+		if err := PruneLayers(); err != nil {
+			return err
+		}
+
+		manifestsPath, err := GetManifestPath()
+		if err != nil {
+			return err
+		}
+
+		if err := PruneDirectory(manifestsPath); err != nil {
+			return err
+		}
+	}
+
+	s, err := NewServer()
+	if err != nil {
+		return err
+	}
+	r := s.GenerateRoutes()
+
 	log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
-	s := &http.Server{
+	srvr := &http.Server{
 		Handler: r,
 	}
 
@@ -881,7 +905,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		if loaded.runner != nil {
 			loaded.runner.Close()
 		}
-		os.RemoveAll(workDir)
+		os.RemoveAll(s.WorkDir)
 		os.Exit(0)
 	}()
 
@@ -892,7 +916,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		}
 	}
 
-	return s.Serve(ln)
+	return srvr.Serve(ln)
 }
 
 func waitForStream(c *gin.Context, ch chan interface{}) {

+ 70 - 0
server/routes_test.go

@@ -0,0 +1,70 @@
+package server
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func setupServer(t *testing.T) (*Server, error) {
+	t.Helper()
+
+	return NewServer()
+}
+
+func Test_Routes(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *http.Response)
+	}
+
+	testCases := []testCase{
+		{
+			Name:   "Version Handler",
+			Method: http.MethodGet,
+			Path:   "/api/version",
+			Setup: func(t *testing.T, req *http.Request) {
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, contentType, "application/json; charset=utf-8")
+				body, err := io.ReadAll(resp.Body)
+				assert.Nil(t, err)
+				assert.Equal(t, `{"version":"0.0.0"}`, string(body))
+			},
+		},
+	}
+
+	s, err := setupServer(t)
+	assert.Nil(t, err)
+
+	router := s.GenerateRoutes()
+
+	httpSrv := httptest.NewServer(router)
+	t.Cleanup(httpSrv.Close)
+
+	for _, tc := range testCases {
+		u := httpSrv.URL + tc.Path
+		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
+		assert.Nil(t, err)
+
+		if tc.Setup != nil {
+			tc.Setup(t, req)
+		}
+
+		resp, err := httpSrv.Client().Do(req)
+		assert.Nil(t, err)
+
+		if tc.Expected != nil {
+			tc.Expected(t, resp)
+		}
+	}
+
+}