-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprotovalidate.go
98 lines (79 loc) · 2.61 KB
/
protovalidate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
Copyright 2024 Yuchen Cheng.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package go_grpc_protovalidate
import (
"context"
"errors"
"github.com/bufbuild/protovalidate-go"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// validator wraps a *protovalidate.Validator for use by the gRPC interceptors
// in this file. Its own Validate method (taking a context and an arbitrary
// request value) shadows the embedded validator's Validate and reports
// constraint violations as gRPC status errors.
type validator struct {
*protovalidate.Validator
}
// Validate checks m against the protovalidate constraints of its message type
// and reports the outcome as a gRPC status error.
//
// It returns nil when m is valid. When m violates its constraints, it returns
// a codes.InvalidArgument status whose details contain an ErrorInfo with
// reason "INVALID_ARGUMENT" and a BadRequest listing every field violation.
// If m is not a proto.Message, or validation fails for a reason other than a
// constraint violation, it returns a codes.Internal status. In both cases the
// request is rejected instead of panicking or being passed to the handler
// unvalidated.
//
// ctx is currently unused.
func (v *validator) Validate(ctx context.Context, m any) error {
	msg, ok := m.(proto.Message)
	if !ok {
		// A non-proto value cannot be validated; report it rather than
		// panicking inside the interceptor.
		return status.Errorf(codes.Internal, "unsupported message type: %T", m)
	}

	err := v.Validator.Validate(msg)
	if err == nil {
		return nil
	}

	var valErr *protovalidate.ValidationError
	if !errors.As(err, &valErr) {
		// Validation itself failed. This is presumably a rule compilation or
		// runtime error from protovalidate. It must not be treated as success,
		// or unvalidated requests would reach the handler.
		return status.Error(codes.Internal, err.Error())
	}

	st := status.New(codes.InvalidArgument, err.Error())
	violations := make([]*errdetails.BadRequest_FieldViolation, 0, len(valErr.Violations))
	for _, fv := range valErr.Violations {
		violations = append(violations, &errdetails.BadRequest_FieldViolation{
			Field:       fv.GetFieldPath(),
			Description: fv.GetMessage(),
		})
	}

	detailed, detErr := st.WithDetails(
		&errdetails.ErrorInfo{Reason: "INVALID_ARGUMENT"},
		&errdetails.BadRequest{FieldViolations: violations},
	)
	if detErr != nil {
		// Attaching details failed; still report the violation without them.
		return st.Err()
	}
	return detailed.Err()
}
// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor that validates
// each incoming request before invoking the handler. If validation returns an
// error, that error is returned to the caller and the handler is not called.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
	val := &validator{Validator: evaluateOpts(opts).validator}
	return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		verr := val.Validate(ctx, req)
		if verr == nil {
			return handler(ctx, req)
		}
		return nil, verr
	}
}
// StreamServerInterceptor returns a grpc.StreamServerInterceptor that hands
// the handler a wrapped stream. Every message received through the wrapper's
// RecvMsg is validated before it is returned.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
	val := &validator{Validator: evaluateOpts(opts).validator}
	return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		wrapped := &serverStream{
			ServerStream: ss,
			ctx:          ss.Context(),
			validator:    val,
		}
		return handler(srv, wrapped)
	}
}
// serverStream wraps a grpc.ServerStream so that each message received via
// RecvMsg is validated before being handed back to the stream handler. The
// ctx field is passed to validator.Validate for every received message.
type serverStream struct {
grpc.ServerStream
ctx context.Context
validator *validator
}
// RecvMsg receives the next message from the wrapped stream into m and then
// validates it. A receive error is returned unchanged and m is not validated.
// Otherwise the result of validating m is returned.
func (ss *serverStream) RecvMsg(m any) error {
	recvErr := ss.ServerStream.RecvMsg(m)
	if recvErr != nil {
		return recvErr
	}
	return ss.validator.Validate(ss.ctx, m)
}