Browse Source

fix stream errors

Michael Yang 1 year ago
parent
commit
1f27d7f1b8
3 changed files with 13 additions and 15 deletions
  1. 4 0
      api/client.go
  2. 0 7
      server/images.go
  3. 9 8
      server/routes.go

+ 4 - 0
api/client.go

@@ -131,6 +131,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 			return fmt.Errorf("unmarshal: %w", err)
 			return fmt.Errorf("unmarshal: %w", err)
 		}
 		}
 
 
+		if errorResponse.Error != "" {
+			return fmt.Errorf("stream: %s", errorResponse.Error)
+		}
+
 		if response.StatusCode >= 400 {
 		if response.StatusCode >= 400 {
 			return StatusError{
 			return StatusError{
 				StatusCode: response.StatusCode,
 				StatusCode: response.StatusCode,

+ 0 - 7
server/images.go

@@ -192,7 +192,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
 	fn("parsing modelfile")
 	fn("parsing modelfile")
 	commands, err := parser.Parse(mf)
 	commands, err := parser.Parse(mf)
 	if err != nil {
 	if err != nil {
-		fn(fmt.Sprintf("error: %v", err))
 		return err
 		return err
 	}
 	}
 
 
@@ -227,14 +226,12 @@ func CreateModel(name string, path string, fn func(status string)) error {
 				fn("creating model layer")
 				fn("creating model layer")
 				file, err := os.Open(fp)
 				file, err := os.Open(fp)
 				if err != nil {
 				if err != nil {
-					fn(fmt.Sprintf("couldn't find model '%s'", c.Args))
 					return fmt.Errorf("failed to open file: %v", err)
 					return fmt.Errorf("failed to open file: %v", err)
 				}
 				}
 				defer file.Close()
 				defer file.Close()
 
 
 				l, err := CreateLayer(file)
 				l, err := CreateLayer(file)
 				if err != nil {
 				if err != nil {
-					fn(fmt.Sprintf("couldn't create model layer: %v", err))
 					return fmt.Errorf("failed to create layer: %v", err)
 					return fmt.Errorf("failed to create layer: %v", err)
 				}
 				}
 				l.MediaType = "application/vnd.ollama.image.model"
 				l.MediaType = "application/vnd.ollama.image.model"
@@ -244,7 +241,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
 				for _, l := range mf.Layers {
 				for _, l := range mf.Layers {
 					newLayer, err := GetLayerWithBufferFromLayer(l)
 					newLayer, err := GetLayerWithBufferFromLayer(l)
 					if err != nil {
 					if err != nil {
-						fn(fmt.Sprintf("couldn't read layer: %v", err))
 						return err
 						return err
 					}
 					}
 					layers = append(layers, newLayer)
 					layers = append(layers, newLayer)
@@ -304,7 +300,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
 
 
 	err = SaveLayers(layers, fn, false)
 	err = SaveLayers(layers, fn, false)
 	if err != nil {
 	if err != nil {
-		fn(fmt.Sprintf("error saving layers: %v", err))
 		return err
 		return err
 	}
 	}
 
 
@@ -312,7 +307,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
 	fn("writing manifest")
 	fn("writing manifest")
 	err = CreateManifest(name, cfg, manifestLayers)
 	err = CreateManifest(name, cfg, manifestLayers)
 	if err != nil {
 	if err != nil {
-		fn(fmt.Sprintf("error creating manifest: %v", err))
 		return err
 		return err
 	}
 	}
 
 
@@ -610,7 +604,6 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
 
 
 	for _, layer := range layers {
 	for _, layer := range layers {
 		if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
 		if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
-			fn(api.ProgressResponse{Status: fmt.Sprintf("error downloading: %v", err), Digest: layer.Digest})
 			return err
 			return err
 		}
 		}
 	}
 	}

+ 9 - 8
server/routes.go

@@ -60,7 +60,7 @@ func generate(c *gin.Context) {
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
 		defer close(ch)
 		defer close(ch)
-		llm.Predict(req.Context, prompt, func(r api.GenerateResponse) {
+		fn := func(r api.GenerateResponse) {
 			r.Model = req.Model
 			r.Model = req.Model
 			r.CreatedAt = time.Now().UTC()
 			r.CreatedAt = time.Now().UTC()
 			if r.Done {
 			if r.Done {
@@ -68,7 +68,11 @@ func generate(c *gin.Context) {
 			}
 			}
 
 
 			ch <- r
 			ch <- r
-		})
+		}
+
+		if err := llm.Predict(req.Context, prompt, fn); err != nil {
+			ch <- gin.H{"error": err.Error()}
+		}
 	}()
 	}()
 
 
 	streamResponse(c, ch)
 	streamResponse(c, ch)
@@ -89,8 +93,7 @@ func pull(c *gin.Context) {
 		}
 		}
 
 
 		if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
 		if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
+			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
 
 
@@ -112,8 +115,7 @@ func push(c *gin.Context) {
 		}
 		}
 
 
 		if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
 		if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
+			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()
 
 
@@ -137,8 +139,7 @@ func create(c *gin.Context) {
 		}
 		}
 
 
 		if err := CreateModel(req.Name, req.Path, fn); err != nil {
 		if err := CreateModel(req.Name, req.Path, fn); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
-			return
+			ch <- gin.H{"error": err.Error()}
 		}
 		}
 	}()
 	}()