소스 검색

refactor tensor read

Michael Yang 1 년 전
부모
커밋
cd22855ef8
1개의 변경된 파일61개의 추가작업 그리고 54개의 파일을 삭제
  1. 61 54
      llm/gguf.go

+ 61 - 54
llm/gguf.go

@@ -69,12 +69,65 @@ type tensor struct {
 	name   string
 	kind   uint32
 	offset uint64
-	size   uint64
 
 	// shape is the number of elements in each dimension
 	shape [4]uint64
 }
 
+func (t tensor) blockSize() uint64 {
+	switch {
+	case t.kind < 2:
+		return 1
+	case t.kind < 10:
+		return 32
+	default:
+		return 256
+	}
+}
+
+func (t tensor) typeSize() uint64 {
+	blockSize := t.blockSize()
+
+	switch t.kind {
+	case 0: // FP32
+		return 4
+	case 1: // FP16
+		return 2
+	case 2: // Q4_0
+		return 2 + blockSize/2
+	case 3: // Q4_1
+		return 2 + 2 + blockSize/2
+	case 6: // Q5_0
+		return 2 + 4 + blockSize/2
+	case 7: // Q5_1
+		return 2 + 2 + 4 + blockSize/2
+	case 8: // Q8_0
+		return 2 + blockSize
+	case 9: // Q8_1
+		return 4 + 4 + blockSize
+	case 10: // Q2_K
+		return blockSize/16 + blockSize/4 + 2 + 2
+	case 11: // Q3_K
+		return blockSize/8 + blockSize/4 + 12 + 2
+	case 12: // Q4_K
+		return 2 + 2 + 12 + blockSize/2
+	case 13: // Q5_K
+		return 2 + 2 + 12 + blockSize/8 + blockSize/2
+	case 14: // Q6_K
+		return blockSize/2 + blockSize/4 + blockSize/16 + 2
+	default:
+		return 0
+	}
+}
+
+func (t tensor) parameters() uint64 {
+	return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
+}
+
+func (t tensor) size() uint64 {
+	return t.parameters() * t.typeSize() / t.blockSize()
+}
+
 type ggufModel struct {
 	*containerGGUF
 
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 			shape[i] = llm.readU64(rso)
 		}
 
-		kind := llm.readU32(rso)
-		offset := llm.readU64(rso)
-
-		var blockSize uint64
-		switch {
-		case kind < 2:
-			blockSize = 1
-		case kind < 10:
-			blockSize = 32
-		default:
-			blockSize = 256
-		}
-
-		var typeSize uint64
-		switch kind {
-		case 0: // FP32
-			typeSize = 4
-		case 1: // FP16
-			typeSize = 2
-		case 2: // Q4_0
-			typeSize = 2 + blockSize/2
-		case 3: // Q4_1
-			typeSize = 2 + 2 + blockSize/2
-		case 6: // Q5_0
-			typeSize = 2 + 4 + blockSize/2
-		case 7: // Q5_1
-			typeSize = 2 + 2 + 4 + blockSize/2
-		case 8: // Q8_0
-			typeSize = 2 + blockSize
-		case 9: // Q8_1
-			typeSize = 4 + 4 + blockSize
-		case 10: // Q2_K
-			typeSize = blockSize/16 + blockSize/4 + 2 + 2
-		case 11: // Q3_K
-			typeSize = blockSize/8 + blockSize/4 + 12 + 2
-		case 12: // Q4_K
-			typeSize = 2 + 2 + 12 + blockSize/2
-		case 13: // Q5_K
-			typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
-		case 14: // Q6_K
-			typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
-		}
-
-		parameters := shape[0] * shape[1] * shape[2] * shape[3]
-		size := parameters * typeSize / blockSize
-
-		llm.tensors = append(llm.tensors, tensor{
+		tensor := tensor{
 			name:   name,
-			kind:   kind,
-			offset: offset,
-			size:   size,
+			kind:   llm.readU32(rso),
+			offset: llm.readU64(rso),
 			shape:  shape,
-		})
+		}
 
-		llm.parameters += parameters
+		llm.tensors = append(llm.tensors, tensor)
+		llm.parameters += tensor.parameters()
 	}
 
 	alignment, ok := llm.kv["general.alignment"].(uint32)
@@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
 
 	rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
 	for _, tensor := range llm.tensors {
-		padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
+		padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
 		rso.Seek(padded, io.SeekCurrent)
 	}