123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- package shared
- import (
- "bytes"
- "fmt"
- "text/template"
- "github.com/envoyproxy/protoc-gen-validate/validate"
- pgs "github.com/lyft/protoc-gen-star"
- "google.golang.org/protobuf/proto"
- )
- type RuleContext struct {
- Field pgs.Field
- Rules proto.Message
- MessageRules *validate.MessageRules
- Typ string
- WrapperTyp string
- OnKey bool
- Index string
- AccessorOverride string
- }
- func rulesContext(f pgs.Field) (out RuleContext, err error) {
- out.Field = f
- var rules validate.FieldRules
- if _, err = f.Extension(validate.E_Rules, &rules); err != nil {
- return
- }
- var wrapped bool
- if out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(f.Type(), &rules); wrapped {
- out.WrapperTyp = out.Typ
- out.Typ = "wrapper"
- }
- if out.Typ == "error" {
- err = fmt.Errorf("unknown rule type (%T)", rules.Type)
- }
- return
- }
- func (ctx RuleContext) Key(name, idx string) (out RuleContext, err error) {
- rules, ok := ctx.Rules.(*validate.MapRules)
- if !ok {
- err = fmt.Errorf("cannot get Key RuleContext from %T", ctx.Field)
- return
- }
- out.Field = ctx.Field
- out.AccessorOverride = name
- out.Index = idx
- out.Typ, out.Rules, out.MessageRules, _ = resolveRules(ctx.Field.Type().Key(), rules.GetKeys())
- if out.Typ == "error" {
- err = fmt.Errorf("unknown rule type (%T)", rules)
- }
- return
- }
- func (ctx RuleContext) Elem(name, idx string) (out RuleContext, err error) {
- out.Field = ctx.Field
- out.AccessorOverride = name
- out.Index = idx
- var rules *validate.FieldRules
- switch r := ctx.Rules.(type) {
- case *validate.MapRules:
- rules = r.GetValues()
- case *validate.RepeatedRules:
- rules = r.GetItems()
- default:
- err = fmt.Errorf("cannot get Elem RuleContext from %T", ctx.Field)
- return
- }
- var wrapped bool
- if out.Typ, out.Rules, out.MessageRules, wrapped = resolveRules(ctx.Field.Type().Element(), rules); wrapped {
- out.WrapperTyp = out.Typ
- out.Typ = "wrapper"
- }
- if out.Typ == "error" {
- err = fmt.Errorf("unknown rule type (%T)", rules)
- }
- return
- }
- func (ctx RuleContext) Unwrap(name string) (out RuleContext, err error) {
- if ctx.Typ != "wrapper" {
- err = fmt.Errorf("cannot unwrap non-wrapper type %q", ctx.Typ)
- return
- }
- return RuleContext{
- Field: ctx.Field,
- Rules: ctx.Rules,
- MessageRules: ctx.MessageRules,
- Typ: ctx.WrapperTyp,
- AccessorOverride: name,
- }, nil
- }
- func Render(tpl *template.Template) func(ctx RuleContext) (string, error) {
- return func(ctx RuleContext) (string, error) {
- var b bytes.Buffer
- err := tpl.ExecuteTemplate(&b, ctx.Typ, ctx)
- return b.String(), err
- }
- }
- func resolveRules(typ interface{ IsEmbed() bool }, rules *validate.FieldRules) (ruleType string, rule proto.Message, messageRule *validate.MessageRules, wrapped bool) {
- switch r := rules.GetType().(type) {
- case *validate.FieldRules_Float:
- ruleType, rule, wrapped = "float", r.Float, typ.IsEmbed()
- case *validate.FieldRules_Double:
- ruleType, rule, wrapped = "double", r.Double, typ.IsEmbed()
- case *validate.FieldRules_Int32:
- ruleType, rule, wrapped = "int32", r.Int32, typ.IsEmbed()
- case *validate.FieldRules_Int64:
- ruleType, rule, wrapped = "int64", r.Int64, typ.IsEmbed()
- case *validate.FieldRules_Uint32:
- ruleType, rule, wrapped = "uint32", r.Uint32, typ.IsEmbed()
- case *validate.FieldRules_Uint64:
- ruleType, rule, wrapped = "uint64", r.Uint64, typ.IsEmbed()
- case *validate.FieldRules_Sint32:
- ruleType, rule, wrapped = "sint32", r.Sint32, false
- case *validate.FieldRules_Sint64:
- ruleType, rule, wrapped = "sint64", r.Sint64, false
- case *validate.FieldRules_Fixed32:
- ruleType, rule, wrapped = "fixed32", r.Fixed32, false
- case *validate.FieldRules_Fixed64:
- ruleType, rule, wrapped = "fixed64", r.Fixed64, false
- case *validate.FieldRules_Sfixed32:
- ruleType, rule, wrapped = "sfixed32", r.Sfixed32, false
- case *validate.FieldRules_Sfixed64:
- ruleType, rule, wrapped = "sfixed64", r.Sfixed64, false
- case *validate.FieldRules_Bool:
- ruleType, rule, wrapped = "bool", r.Bool, typ.IsEmbed()
- case *validate.FieldRules_String_:
- ruleType, rule, wrapped = "string", r.String_, typ.IsEmbed()
- case *validate.FieldRules_Bytes:
- ruleType, rule, wrapped = "bytes", r.Bytes, typ.IsEmbed()
- case *validate.FieldRules_Enum:
- ruleType, rule, wrapped = "enum", r.Enum, false
- case *validate.FieldRules_Repeated:
- ruleType, rule, wrapped = "repeated", r.Repeated, false
- case *validate.FieldRules_Map:
- ruleType, rule, wrapped = "map", r.Map, false
- case *validate.FieldRules_Any:
- ruleType, rule, wrapped = "any", r.Any, false
- case *validate.FieldRules_Duration:
- ruleType, rule, wrapped = "duration", r.Duration, false
- case *validate.FieldRules_Timestamp:
- ruleType, rule, wrapped = "timestamp", r.Timestamp, false
- case nil:
- if ft, ok := typ.(pgs.FieldType); ok && ft.IsRepeated() {
- return "repeated", &validate.RepeatedRules{}, rules.Message, false
- } else if ok && ft.IsMap() && ft.Element().IsEmbed() {
- return "map", &validate.MapRules{}, rules.Message, false
- } else if typ.IsEmbed() {
- return "message", rules.GetMessage(), rules.GetMessage(), false
- }
- return "none", nil, nil, false
- default:
- ruleType, rule, wrapped = "error", nil, false
- }
- return ruleType, rule, rules.Message, wrapped
- }
|