|
@@ -10,6 +10,7 @@ import (
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/ollama/ollama/cache"
|
|
|
"github.com/ollama/ollama/ml"
|
|
@@ -27,6 +28,7 @@ var args struct {
|
|
|
}
|
|
|
|
|
|
func temp() error {
|
|
|
+ start := time.Now()
|
|
|
flag.IntVar(&args.n, "n", 10, "number of samples")
|
|
|
flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
|
|
|
flag.StringVar(&args.image, "image", "", "path to image file")
|
|
@@ -104,9 +106,11 @@ func temp() error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
|
|
- var stringBuffer string
|
|
|
+ pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
|
|
+
|
|
|
var offset int
|
|
|
+ var stringBuffer string
|
|
|
+ var firstTokenTime time.Duration
|
|
|
for range args.n {
|
|
|
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
|
|
if err != nil {
|
|
@@ -118,15 +122,21 @@ func temp() error {
|
|
|
for i, f32 := range f32s {
|
|
|
f64s[i] = float64(f32)
|
|
|
}
|
|
|
+ sampleTime := time.Now()
|
|
|
+ samplers := []sample.Sampler{
|
|
|
+ pushdownSampler,
|
|
|
+ // sample.Weighed(),
|
|
|
+ // sample.TopP(0.9),
|
|
|
+ // sample.Weighed(),
|
|
|
+ sample.Greedy(),
|
|
|
+ }
|
|
|
|
|
|
- // do sampling
|
|
|
- // []ints back
|
|
|
- // ints map to sampled logits
|
|
|
- f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
|
|
|
-
|
|
|
+ f64s, err = sample.Sample(f64s, samplers...)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ finishTime := time.Now()
|
|
|
+ fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
|
|
|
|
|
var outputIDs []int32
|
|
|
for _, f64 := range f64s {
|
|
@@ -134,7 +144,6 @@ func temp() error {
|
|
|
outputIDs = append(outputIDs, int32(f64))
|
|
|
}
|
|
|
}
|
|
|
- pdaSampler.UpdateState(outputIDs)
|
|
|
|
|
|
if len(outputIDs) == 0 {
|
|
|
break
|
|
@@ -147,14 +156,29 @@ func temp() error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- // fmt.Print(s)
|
|
|
+ if firstTokenTime == 0 {
|
|
|
+ firstTokenTime = time.Since(start)
|
|
|
+ fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
|
|
|
+ }
|
|
|
+
|
|
|
+ // fmt.Printf("--- token: %q\n", s)
|
|
|
+ // fmt.Printf("--- outputIDs: %v\n", outputIDs)
|
|
|
stringBuffer += s
|
|
|
fmt.Println("--- stringBuffer", stringBuffer)
|
|
|
+
|
|
|
+ err = pushdownSampler.UpdateState(outputIDs)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
inputIDs = append(inputIDs, outputIDs...)
|
|
|
if args.cache {
|
|
|
offset = len(inputIDs) - 1
|
|
|
}
|
|
|
}
|
|
|
+ fmt.Println("\n------ Output: ------")
|
|
|
+ fmt.Println(stringBuffer)
|
|
|
+ fmt.Println("--------------------")
|
|
|
|
|
|
return nil
|
|
|
}
|