context.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package shared
  2. import (
  3. "bytes"
  4. "fmt"
  5. "text/template"
  6. "github.com/envoyproxy/protoc-gen-validate/validate"
  7. pgs "github.com/lyft/protoc-gen-star"
  8. "google.golang.org/protobuf/proto"
  9. )
  10. type RuleContext struct {
  11. Field pgs.Field
  12. Rules proto.Message
  13. MessageRules *validate.MessageRules
  14. Typ string
  15. WrapperTyp string
  16. OnKey bool
  17. Index string
  18. AccessorOverride string
  19. }
  20. func rulesContext(f pgs.Field) (out RuleContext, err error) {
  21. out.Field = f
  22. var rules validate.FieldRules
  23. if _, err = f.Extension(validate.E_Rules, &rules); err != nil {
  24. return
  25. }
  26. var wrapped bool
  27. if out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(f.Type(), &rules); wrapped {
  28. out.WrapperTyp = out.Typ
  29. out.Typ = "wrapper"
  30. }
  31. if out.Typ == "error" {
  32. err = fmt.Errorf("unknown rule type (%T)", rules.Type)
  33. }
  34. return
  35. }
  // Key derives the RuleContext for the keys of the map field described by ctx.
  // ctx.Rules must be a *validate.MapRules; its keys rules are resolved against
  // the map's key type. name and idx are stored as AccessorOverride and Index,
  // and OnKey is set so consumers of the context can tell a key from a value.
  //
  // On a non-map context a zero RuleContext and an error naming the dynamic
  // type of ctx.Rules are returned. An unrecognised key rule kind yields Typ
  // "error" plus an "unknown rule type" error naming the oneof member.
  func (ctx RuleContext) Key(name, idx string) (out RuleContext, err error) {
  	mapRules, ok := ctx.Rules.(*validate.MapRules)
  	if !ok {
  		// Report the rules' type: that is what made the assertion fail
  		// (the field's concrete type is the same for every field).
  		return out, fmt.Errorf("cannot get Key RuleContext from %T", ctx.Rules)
  	}
  	keyRules := mapRules.GetKeys()

  	out.Field = ctx.Field
  	out.AccessorOverride = name
  	out.Index = idx
  	// Key is the only constructor of key contexts, so it must flag them.
  	out.OnKey = true

  	// Protobuf map keys cannot be message types, so a key is never a wrapper
  	// and the wrapped flag is deliberately ignored.
  	out.Typ, out.Rules, out.MessageRules, _ = resolveRules(ctx.Field.Type().Key(), keyRules)
  	if out.Typ == "error" {
  		// Name the unrecognised oneof member, not the enclosing *MapRules.
  		err = fmt.Errorf("unknown rule type (%T)", keyRules.GetType())
  	}
  	return out, err
  }
  // Elem derives the RuleContext for the elements of a repeated field (rules
  // from RepeatedRules.items) or the values of a map field (rules from
  // MapRules.values). name and idx are stored as AccessorOverride and Index.
  // When resolveRules reports the element rules as wrapped, Typ becomes
  // "wrapper" and the scalar kind is kept in WrapperTyp for Unwrap.
  //
  // On a context whose Rules are neither map nor repeated rules, a zero
  // RuleContext and an error naming the dynamic type of ctx.Rules are returned.
  // An unrecognised element rule kind yields Typ "error" plus an
  // "unknown rule type" error naming the oneof member.
  func (ctx RuleContext) Elem(name, idx string) (out RuleContext, err error) {
  	var elemRules *validate.FieldRules
  	switch r := ctx.Rules.(type) {
  	case *validate.MapRules:
  		elemRules = r.GetValues()
  	case *validate.RepeatedRules:
  		elemRules = r.GetItems()
  	default:
  		// Report the rules' type: that is what made the lookup fail
  		// (the field's concrete type is the same for every field).
  		return out, fmt.Errorf("cannot get Elem RuleContext from %T", ctx.Rules)
  	}

  	out.Field = ctx.Field
  	out.AccessorOverride = name
  	out.Index = idx

  	var wrapped bool
  	out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(ctx.Field.Type().Element(), elemRules)
  	if wrapped {
  		out.WrapperTyp, out.Typ = out.Typ, "wrapper"
  	}

  	if out.Typ == "error" {
  		// Name the unrecognised oneof member, not the enclosing *FieldRules.
  		err = fmt.Errorf("unknown rule type (%T)", elemRules.GetType())
  	}
  	return out, err
  }
  // Unwrap turns a "wrapper" context into a context for the wrapped scalar:
  // Field, Rules and MessageRules are reused, Typ becomes ctx.WrapperTyp and
  // AccessorOverride is set to name. Any other Typ yields a zero RuleContext
  // and an error.
  //
  // NOTE(review): Index and OnKey are not carried over, so an unwrapped
  // repeated item or map value loses them — confirm templates do not need them.
  func (ctx RuleContext) Unwrap(name string) (out RuleContext, err error) {
  	if ctx.Typ == "wrapper" {
  		out.Field = ctx.Field
  		out.Rules = ctx.Rules
  		out.MessageRules = ctx.MessageRules
  		out.Typ = ctx.WrapperTyp
  		out.AccessorOverride = name
  		return out, nil
  	}
  	return out, fmt.Errorf("cannot unwrap non-wrapper type %q", ctx.Typ)
  }
  88. func Render(tpl *template.Template) func(ctx RuleContext) (string, error) {
  89. return func(ctx RuleContext) (string, error) {
  90. var b bytes.Buffer
  91. err := tpl.ExecuteTemplate(&b, ctx.Typ, ctx)
  92. return b.String(), err
  93. }
  94. }
  95. func resolveRules(typ interface{ IsEmbed() bool }, rules *validate.FieldRules) (ruleType string, rule proto.Message, messageRule *validate.MessageRules, wrapped bool) {
  96. switch r := rules.GetType().(type) {
  97. case *validate.FieldRules_Float:
  98. ruleType, rule, wrapped = "float", r.Float, typ.IsEmbed()
  99. case *validate.FieldRules_Double:
  100. ruleType, rule, wrapped = "double", r.Double, typ.IsEmbed()
  101. case *validate.FieldRules_Int32:
  102. ruleType, rule, wrapped = "int32", r.Int32, typ.IsEmbed()
  103. case *validate.FieldRules_Int64:
  104. ruleType, rule, wrapped = "int64", r.Int64, typ.IsEmbed()
  105. case *validate.FieldRules_Uint32:
  106. ruleType, rule, wrapped = "uint32", r.Uint32, typ.IsEmbed()
  107. case *validate.FieldRules_Uint64:
  108. ruleType, rule, wrapped = "uint64", r.Uint64, typ.IsEmbed()
  109. case *validate.FieldRules_Sint32:
  110. ruleType, rule, wrapped = "sint32", r.Sint32, false
  111. case *validate.FieldRules_Sint64:
  112. ruleType, rule, wrapped = "sint64", r.Sint64, false
  113. case *validate.FieldRules_Fixed32:
  114. ruleType, rule, wrapped = "fixed32", r.Fixed32, false
  115. case *validate.FieldRules_Fixed64:
  116. ruleType, rule, wrapped = "fixed64", r.Fixed64, false
  117. case *validate.FieldRules_Sfixed32:
  118. ruleType, rule, wrapped = "sfixed32", r.Sfixed32, false
  119. case *validate.FieldRules_Sfixed64:
  120. ruleType, rule, wrapped = "sfixed64", r.Sfixed64, false
  121. case *validate.FieldRules_Bool:
  122. ruleType, rule, wrapped = "bool", r.Bool, typ.IsEmbed()
  123. case *validate.FieldRules_String_:
  124. ruleType, rule, wrapped = "string", r.String_, typ.IsEmbed()
  125. case *validate.FieldRules_Bytes:
  126. ruleType, rule, wrapped = "bytes", r.Bytes, typ.IsEmbed()
  127. case *validate.FieldRules_Enum:
  128. ruleType, rule, wrapped = "enum", r.Enum, false
  129. case *validate.FieldRules_Repeated:
  130. ruleType, rule, wrapped = "repeated", r.Repeated, false
  131. case *validate.FieldRules_Map:
  132. ruleType, rule, wrapped = "map", r.Map, false
  133. case *validate.FieldRules_Any:
  134. ruleType, rule, wrapped = "any", r.Any, false
  135. case *validate.FieldRules_Duration:
  136. ruleType, rule, wrapped = "duration", r.Duration, false
  137. case *validate.FieldRules_Timestamp:
  138. ruleType, rule, wrapped = "timestamp", r.Timestamp, false
  139. case nil:
  140. if ft, ok := typ.(pgs.FieldType); ok && ft.IsRepeated() {
  141. return "repeated", &validate.RepeatedRules{}, rules.Message, false
  142. } else if ok && ft.IsMap() && ft.Element().IsEmbed() {
  143. return "map", &validate.MapRules{}, rules.Message, false
  144. } else if typ.IsEmbed() {
  145. return "message", rules.GetMessage(), rules.GetMessage(), false
  146. }
  147. return "none", nil, nil, false
  148. default:
  149. ruleType, rule, wrapped = "error", nil, false
  150. }
  151. return ruleType, rule, rules.Message, wrapped
  152. }