123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- package sample
- import (
- "fmt"
- "math"
- "github.com/ollama/ollama/model"
- )
- type PushdownSampler struct {
- // stateful
- curNode *PDANode
- proc model.TextProcessor
- stateToNodeMap map[JSONState]*PDANode
- braceStack []rune
- }
- func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
- startNode, stateToNodeMap, err := BuildGraph(proc)
- if err != nil {
- panic(err)
- }
- err = PreComputeValidStates(stateToNodeMap, proc)
- if err != nil {
- panic(err)
- }
- // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
- // token, err := proc.Decode([]int32{int32(id)})
- // if err != nil {
- // panic(err)
- // }
- // fmt.Println("id", id, "node", node, "token", token)
- // }
- // time.Sleep(10 * time.Second)
- return &PushdownSampler{
- curNode: startNode,
- proc: proc,
- stateToNodeMap: stateToNodeMap,
- braceStack: []rune{},
- }
- }
- func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
- fmt.Println("sample:", s.curNode.State)
- switch s.curNode.State {
- case StateInObjectEnd:
- // force finish if no braces left
- if len(s.braceStack) == 0 {
- s.curNode = NewPDANode(StateTerminate)
- for i := range logits {
- if s.proc.Is(uint32(i), model.SpecialEOS) {
- logits[i] = 1.0
- } else {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- }
- valid, err := s.proc.Encode("}")
- if err != nil {
- return nil, err
- }
- for i := range logits {
- for _, token := range valid {
- if i != int(token) {
- logits[i] = math.NaN()
- }
- }
- }
- return logits, nil
- // return logits, nil
- case StateTerminate:
- for i := range logits {
- if s.proc.Is(uint32(i), model.SpecialEOS) {
- logits[i] = 1.0
- } else {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- // case StateInStringEnd:
- // return logits, nil
- default:
- fmt.Println("masking logits current state", s.curNode.State)
- logits, err := s.maskLogits(logits, s.curNode)
- if err != nil {
- return nil, err
- }
- return logits, nil
- }
- }
- func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
- fmt.Println("update state", s.curNode.State)
- // TODO: need to handle end states and entering object case
- if s.curNode.State == StateInObjectEnd {
- fmt.Println("in object end")
- if len(s.braceStack) > 0 {
- s.braceStack = s.braceStack[:len(s.braceStack)-1]
- return nil
- }
- s.curNode = NewPDANode(StateTerminate)
- // TODO: return here?
- }
- // need this cause there could be multiple transitions
- mappedString, err := s.proc.Decode(tokenSlice)
- if err != nil {
- return err
- }
- for _, r := range mappedString {
- if r == rune('{') {
- s.braceStack = append(s.braceStack, r)
- }
- if r == rune('}') {
- if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
- return fmt.Errorf("unmatched closing brace")
- }
- s.braceStack = s.braceStack[:len(s.braceStack)-1]
- }
- }
- for _, tokenID := range tokenSlice {
- // transition to the next node
- nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
- if !ok {
- return fmt.Errorf("invalid token: %q", mappedString)
- }
- fmt.Println("transitioning to", nextNode)
- s.curNode = s.stateToNodeMap[nextNode]
- }
- return nil
- }
- func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
- for i := range logits {
- _, exists := node.MaskTokenIDToNode[int32(i)]
- if !exists {
- logits[i] = math.NaN()
- }
- }
- return logits, nil
- }
- // TODO: add penalties for "\n" and similar characters inside strings
|