瀏覽代碼

on disk copy

Josh Yan 10 月之前
父節點
當前提交
a993a3a85c
共有 4 個文件被更改,包括 77 次插入2 次删除
  1. 9 0
      api/client.go
  2. 5 0
      api/types.go
  3. 58 2
      cmd/cmd.go
  4. 5 0
      server/routes.go

+ 9 - 0
api/client.go

@@ -403,3 +403,12 @@ func (c *Client) IsLocal() bool {
 
 	return false
 }
+
+// EnvConfig returns the environment configuration for the server.
+func (c *Client) ServerConfig(ctx context.Context) (*ServerConfig, error) {
+	var config ServerConfig
+	if err := c.do(ctx, http.MethodGet, "/api/config", nil, &config); err != nil {
+		return nil, err
+	}
+	return &config, nil
+}

+ 5 - 0
api/types.go

@@ -451,6 +451,11 @@ type ModelDetails struct {
 	QuantizationLevel string   `json:"quantization_level"`
 }
 
+// EnvConfig is the configuration for the environment.
+type ServerConfig struct {
+	ModelDir string `json:"model_dir"`
+}
+
 func (m *Metrics) Summary() {
 	if m.TotalDuration > 0 {
 		fmt.Fprintf(os.Stderr, "total duration:       %v\n", m.TotalDuration)

+ 58 - 2
cmd/cmd.go

@@ -287,7 +287,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	// Resolve server to IP
 	// Check if server is local
 	if client.IsLocal() {
-		err := createBlobLocal(cmd, client, digest)
+		err := createBlobLocal(cmd.Context(), client, path, digest)
 		if err == nil {
 			return digest, nil
 		}
@@ -299,10 +299,66 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	return digest, nil
 }
 
-func createBlobLocal(cmd *cobra.Command, client *api.Client, digest string) error {
+func createBlobLocal(ctx context.Context, client *api.Client, path string, digest string) error {
 	// This function should be called if the server is local
 	// It should find the model directory, copy the blob over, and return the digest
 
+	// Get the model directory
+	config, err := client.ServerConfig(ctx)
+	if err != nil {
+		return err
+	}
+
+	modelDir := config.ModelDir
+
+	// Get blob destination
+	digest = strings.ReplaceAll(digest, ":", "-")
+	dest := filepath.Join(modelDir, "blobs", digest)
+	dirPath := filepath.Dir(dest)
+	if digest == "" {
+		dirPath = dest
+	}
+
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	// Check blob exists
+	_, err = os.Stat(dest)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop
+	case err != nil:
+		return err
+	default:
+		// blob already exists
+		return nil
+	}
+
+	// Copy blob over
+	sourceFile, err := os.Open(path)
+	if err != nil {
+		return fmt.Errorf("could not open source file: %v", err)
+	}
+	defer sourceFile.Close()
+
+	destFile, err := os.Create(dest)
+	if err != nil {
+		return fmt.Errorf("could not create destination file: %v", err)
+	}
+	defer destFile.Close()
+
+	_, err = io.Copy(destFile, sourceFile)
+	if err != nil {
+		return fmt.Errorf("error copying file: %v", err)
+	}
+
+	err = destFile.Sync()
+	if err != nil {
+		return fmt.Errorf("error flushing file: %v", err)
+	}
+
+	return nil
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {

+ 5 - 0
server/routes.go

@@ -1069,6 +1069,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
 	r.GET("/api/ps", s.ProcessHandler)
+	r.GET("/api/config", s.ConfigHandler)
 
 	// Compatibility endpoints
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
@@ -1422,3 +1423,7 @@ func handleScheduleError(c *gin.Context, name string, err error) {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
 }
+
+func (s *Server) ConfigHandler(c *gin.Context) {
+	c.JSON(http.StatusOK, api.ServerConfig{ModelDir: envconfig.ModelsDir})
+}