// template.go (2.9 KB)

  1. package template
  2. import (
  3. "bytes"
  4. "embed"
  5. "encoding/json"
  6. "errors"
  7. "io"
  8. "math"
  9. "slices"
  10. "strings"
  11. "sync"
  12. "text/template"
  13. "text/template/parse"
  14. "github.com/agnivade/levenshtein"
  15. "golang.org/x/exp/maps"
  16. )
  17. //go:embed index.json
  18. var indexBytes []byte
  19. //go:embed *.gotmpl
  20. var templatesFS embed.FS
  21. var templatesOnce = sync.OnceValues(func() ([]*named, error) {
  22. var templates []*named
  23. if err := json.Unmarshal(indexBytes, &templates); err != nil {
  24. return nil, err
  25. }
  26. for _, t := range templates {
  27. bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
  28. if err != nil {
  29. return nil, err
  30. }
  31. // normalize line endings
  32. t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
  33. }
  34. return templates, nil
  35. })
  // named is one bundled template entry as decoded from index.json.
type named struct {
	// Name is the base name of the template file; the loader reads
	// Name + ".gotmpl".
	Name string `json:"name"`

	// Template is the template text that Named compares against when
	// looking up the closest bundled entry.
	Template string `json:"template"`

	// Bytes holds the contents of the matching .gotmpl file, filled in by
	// the loader after decoding.
	Bytes []byte
}

// Reader returns a new reader over the entry's Bytes, starting at offset 0.
// Every call yields an independent reader.
func (n named) Reader() io.Reader {
	r := bytes.NewReader(n.Bytes)
	return r
}
  44. func Named(s string) (*named, error) {
  45. templates, err := templatesOnce()
  46. if err != nil {
  47. return nil, err
  48. }
  49. var template *named
  50. score := math.MaxInt
  51. for _, t := range templates {
  52. if s := levenshtein.ComputeDistance(s, t.Template); s < score {
  53. score = s
  54. template = t
  55. }
  56. }
  57. if score < 100 {
  58. return template, nil
  59. }
  60. return nil, errors.New("no matching template found")
  61. }
  // Template is a parsed text/template that also remembers the source text
// it was parsed from.
type Template struct {
	*template.Template

	// raw is the unparsed template source; String returns it verbatim.
	raw string
}

// String returns the original, unparsed template source.
func (t *Template) String() string {
	source := t.raw
	return source
}
  69. var DefaultTemplate, _ = Parse("{{ .Prompt }}")
  70. func Parse(s string) (*Template, error) {
  71. t, err := template.New("").Option("missingkey=zero").Parse(s)
  72. if err != nil {
  73. return nil, err
  74. }
  75. return &Template{Template: t, raw: s}, nil
  76. }
  77. func (t *Template) Vars() []string {
  78. var vars []string
  79. for _, n := range t.Tree.Root.Nodes {
  80. vars = append(vars, parseNode(n)...)
  81. }
  82. set := make(map[string]struct{})
  83. for _, n := range vars {
  84. set[strings.ToLower(n)] = struct{}{}
  85. }
  86. vars = maps.Keys(set)
  87. slices.Sort(vars)
  88. return vars
  89. }
  // parseNode returns the identifiers of every field reference found in n
// and its descendants, in document order. Duplicates are not removed.
//
// For a field chain such as .Foo.Bar, every identifier in the chain
// ("Foo", "Bar") is returned. A root reference such as $.Foo is reported
// the same way as .Foo. Fields inside range/with bodies are reported too,
// even though they are relative to the new dot. Nodes that cannot hold field
// references (text, comments, ...) and nil nodes contribute nothing.
func parseNode(n parse.Node) []string {
	switch n := n.(type) {
	case *parse.ListNode:
		// ElseList and similar optional lists may be typed nil pointers.
		if n == nil {
			return nil
		}
		var names []string
		for _, child := range n.Nodes {
			names = append(names, parseNode(child)...)
		}
		return names
	case *parse.ActionNode:
		return parseNode(n.Pipe)
	case *parse.TemplateNode:
		// {{ template "name" pipeline }}: the pipeline is optional, and a nil
		// Pipe is handled by the *parse.PipeNode case.
		return parseNode(n.Pipe)
	case *parse.IfNode:
		return parseBranch(&n.BranchNode)
	case *parse.RangeNode:
		return parseBranch(&n.BranchNode)
	case *parse.WithNode:
		return parseBranch(&n.BranchNode)
	case *parse.PipeNode:
		if n == nil {
			return nil
		}
		var names []string
		for _, cmd := range n.Cmds {
			for _, arg := range cmd.Args {
				names = append(names, parseNode(arg)...)
			}
		}
		return names
	case *parse.FieldNode:
		return n.Ident
	case *parse.VariableNode:
		// "$" is the root data object, so $.Foo names the same field as a
		// top-level .Foo. Other variables ($x.Foo) refer to local values.
		if len(n.Ident) > 1 && n.Ident[0] == "$" {
			return n.Ident[1:]
		}
	}
	return nil
}

// parseBranch collects field identifiers from the condition pipeline and
// both bodies of an if/range/with node.
func parseBranch(b *parse.BranchNode) []string {
	names := parseNode(b.Pipe)
	names = append(names, parseNode(b.List)...)
	return append(names, parseNode(b.ElseList)...)
}