import.go 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. package build
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "github.com/ollama/ollama/x/build/internal/blobstore"
  7. "github.com/ollama/ollama/x/encoding/gguf"
  8. )
  9. func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
  10. return blobstore.ID{}, gguf.Info{}, 0, err
  11. }
  12. func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
  13. info, err := os.Stat(path)
  14. if err != nil {
  15. return importError(err)
  16. }
  17. if info.IsDir() {
  18. return s.importSafeTensor(path)
  19. } else {
  20. return s.importGGUF(path)
  21. }
  22. }
  23. func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
  24. f, err := os.Open(path)
  25. if err != nil {
  26. return importError(err)
  27. }
  28. defer f.Close()
  29. info, err := gguf.StatReader(f)
  30. if errors.Is(err, gguf.ErrBadMagic) {
  31. return importError(ErrUnsupportedModelFormat)
  32. }
  33. if err != nil {
  34. return importError(err)
  35. }
  36. if info.FileType == 0 {
  37. return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
  38. }
  39. id, size, err := s.st.Put(f)
  40. if err != nil {
  41. return importError(err)
  42. }
  43. return id, info, size, nil
  44. }
  45. func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
  46. path, err := convertSafeTensorToGGUF(path)
  47. if err != nil {
  48. return importError(err)
  49. }
  50. return s.importGGUF(path)
  51. }