-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.go
274 lines (248 loc) · 7.32 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
package goRPC
import (
"encoding/json"
"errors"
"fmt"
"goRPC/Maincodec/codec"
"goRPC/serviceRegister"
"io"
"log"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
)
// MagicNumber marks a connection as speaking this RPC protocol. ServeConn
// compares it against the MagicNumber field of the Option the client sends.
const MagicNumber = 0x3bef5c

// Option is the preamble a client sends, JSON-encoded (see ServeConn), before
// any requests on a connection. Because it is decoded from JSON, the exported
// field names are part of the wire format.
type Option struct {
	MagicNumber    int           // marks this as an RPC connection; checked against MagicNumber
	CodecType      codec.Type    // codec the client chose for headers and bodies; looked up in codec.NewCodecFuncMap (default: gob)
	ConnectTimeout time.Duration // connection timeout; not read in this file (presumably used by the client when dialing — confirm)
	HandleTimeout  time.Duration // server-side limit on handling one request; 0 means no limit (see handleRequest)
}

// DefaultOption holds the default connection settings: the package's magic
// number, the gob codec, and a 10-second connect timeout. HandleTimeout is
// left at zero, i.e. no server-side handling limit.
var DefaultOption = &Option{
	MagicNumber:    MagicNumber,
	CodecType:      codec.GobType,
	ConnectTimeout: time.Second * 10, // connect timeout of 10 seconds
}
// Server dispatches RPC requests to registered services. ServiceMap maps each
// service name to its *serviceRegister.Service.
type Server struct {
	ServiceMap sync.Map
}

// NewServer returns a Server with no services registered.
func NewServer() *Server {
	return new(Server)
}

// DefaultServer is the Server behind the package-level Accept, Register and
// HandleHTTP functions.
var DefaultServer = NewServer()
// Accept accepts connections on lis and serves each one on its own goroutine.
// It logs and returns on the first error from lis.Accept.
func (server *Server) Accept(lis net.Listener) {
	for {
		conn, err := lis.Accept()
		if err == nil {
			go server.ServeConn(conn)
			continue
		}
		log.Println("rpc server: accept error:", err)
		return
	}
}
// Accept serves every connection accepted from lis using DefaultServer.
func Accept(lis net.Listener) {
	DefaultServer.Accept(lis)
}
// ServeConn serves a single connection and blocks until it is done. The
// client must first send a JSON-encoded Option. If that preamble cannot be
// decoded, carries the wrong MagicNumber, or names an unknown codec type, the
// problem is logged and the connection is closed. Otherwise the connection is
// handed to serveCodec using the codec the client selected. conn is always
// closed before ServeConn returns.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	defer func() { _ = conn.Close() }()

	var opt Option
	// NOTE(review): json.Decoder buffers its input and may read bytes past the
	// Option object, which the codec below then never sees. A safe fix needs
	// explicit framing of the preamble (a protocol change), since codec bytes
	// cannot be told apart from trailing JSON whitespace.
	if err := json.NewDecoder(conn).Decode(&opt); err != nil {
		log.Println("rpc server: options error: ", err)
		return
	}
	if opt.MagicNumber != MagicNumber {
		log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
		return
	}
	f := codec.NewCodecFuncMap[opt.CodecType]
	if f == nil {
		// Without this return, f(conn) below would panic on a nil func.
		log.Printf("rpc server: invalid codec type %s", opt.CodecType)
		return
	}
	server.serveCodec(f(conn), &opt)
}
// invalidRequest is the empty placeholder body sent with error responses.
var invalidRequest = struct{}{}

// serveCodec reads requests from cc until a request header cannot be read.
// Each successfully read request is handled on its own goroutine. A request
// whose header was read but which failed later (unknown service or method,
// undecodable body) is answered immediately with an error response. Once
// reading stops, serveCodec waits for in-flight handlers and then closes cc.
func (server *Server) serveCodec(cc codec.Codec, opt *Option) {
	var (
		writeMu  sync.Mutex     // shared by all handlers so responses never interleave
		inflight sync.WaitGroup // counts handlers that have not finished yet
	)
	for {
		req, err := server.readRequest(cc)
		if err == nil {
			inflight.Add(1)
			go server.handleRequest(cc, req, &writeMu, &inflight, opt.HandleTimeout)
			continue
		}
		if req == nil {
			// The header itself could not be read; nothing can be answered.
			break
		}
		req.h.Error = err.Error()
		server.sendResponse(cc, req.h, invalidRequest, &writeMu)
	}
	inflight.Wait()
	_ = cc.Close()
}
// request stores everything needed to serve one call: the decoded header,
// the argument and reply values, and the service and method that
// findService resolved from the header's ServiceMethod.
type request struct {
	h            *codec.Header               // request header as read from the codec
	argv, replyv reflect.Value               // argument (decoded from the body) and reply values of the call
	mtype        *serviceRegister.MethodType // method to invoke
	svc          *serviceRegister.Service    // service that owns mtype
}
// readRequestHeader decodes the next request header from cc. io.EOF and
// io.ErrUnexpectedEOF (typically the peer closing the connection) are
// returned without logging; any other read error is logged and returned.
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
	h := new(codec.Header)
	err := cc.ReadHeader(h)
	if err == nil {
		return h, nil
	}
	if err != io.EOF && err != io.ErrUnexpectedEOF {
		log.Println("rpc server: read header error:", err)
	}
	return nil, err
}
// readRequest reads one complete request (header plus argument body) from cc.
//
// It returns (nil, err) when the header cannot be read. If the header was
// read but the service/method lookup or body decoding fails, it returns the
// partially filled request together with the error, so the caller can still
// reply using req.h.
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	header, err := server.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}

	req := &request{h: header}
	if req.svc, req.mtype, err = server.findService(header.ServiceMethod); err != nil {
		return req, err
	}
	req.argv, req.replyv = req.mtype.NewArgv(), req.mtype.NewReplyv()

	// ReadBody decodes into a pointer, so value-typed arguments are passed by address.
	target := req.argv.Interface()
	if req.argv.Type().Kind() != reflect.Ptr {
		target = req.argv.Addr().Interface()
	}
	if err = cc.ReadBody(target); err != nil {
		log.Println("rpc server: read argv err:", err)
		return req, err
	}
	return req, nil
}
// sendResponse writes header h and body to cc while holding sending, so
// concurrent handlers on the same connection never interleave their output.
// A write failure is logged, not returned.
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
	sending.Lock()
	defer sending.Unlock()
	err := cc.Write(h, body)
	if err == nil {
		return
	}
	log.Println("rpc server: write response error:", err)
}
// handleRequest invokes the method described by req, writes exactly one
// response for it on cc, and finally calls wg.Done.
//
// The method runs on a separate goroutine. With timeout == 0 handleRequest
// waits for the call and its response without limit. Otherwise, if the call
// has not returned within timeout, a timeout error response is sent instead.
// The timed-out call cannot be cancelled and keeps running in the background,
// but its result is discarded and never written to cc.
//
// sending serializes writes to cc across concurrent handlers.
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
	defer wg.Done()

	// Only the first caller of respond (the worker or the timeout path) mutates
	// req.h and writes a response; later calls are no-ops. This prevents a
	// duplicate response, a write after serveCodec has closed cc, and a data
	// race on req.h.Error. An empty errMsg means success.
	var once sync.Once
	respond := func(body interface{}, errMsg string) {
		once.Do(func() {
			if errMsg != "" {
				req.h.Error = errMsg
			}
			server.sendResponse(cc, req.h, body, sending)
		})
	}

	// Buffered so the worker can always signal and exit, even after the
	// timeout path has stopped receiving; unbuffered channels leaked it.
	called := make(chan struct{}, 1)
	sent := make(chan struct{}, 1)
	go func() {
		err := req.svc.Call(req.mtype, req.argv, req.replyv)
		called <- struct{}{}
		if err != nil {
			respond(invalidRequest, err.Error())
		} else {
			respond(req.replyv.Interface(), "")
		}
		sent <- struct{}{}
	}()

	if timeout == 0 {
		<-called
		<-sent
		return
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		// The timer fired before the call returned: the handling timed out.
		respond(invalidRequest, fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout))
	case <-called:
		<-sent
	}
}
// Register builds a service from rcvr with serviceRegister.NewService and
// publishes it on the server under the service's name. It returns an error
// if a service with the same name is already registered.
func (server *Server) Register(rcvr interface{}) error {
	svc := serviceRegister.NewService(rcvr)
	_, loaded := server.ServiceMap.LoadOrStore(svc.Name, svc)
	if loaded {
		return errors.New("rpc: service already defined: " + svc.Name)
	}
	return nil
}
// Register publishes rcvr's methods on DefaultServer; see (*Server).Register.
func Register(rcvr interface{}) error {
	return DefaultServer.Register(rcvr)
}
// findService resolves a "Service.Method" string. Everything before the
// last '.' is the service name, looked up in ServiceMap; everything after it
// is the method name, looked up in that service's Method table.
// If only the method is missing, the service is still returned along with
// the error.
func (server *Server) findService(serviceMethod string) (svc *serviceRegister.Service, mtype *serviceRegister.MethodType, err error) {
	dot := strings.LastIndex(serviceMethod, ".")
	if dot < 0 {
		return nil, nil, errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
	}
	serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]

	found, ok := server.ServiceMap.Load(serviceName)
	if !ok {
		return nil, nil, errors.New("rpc server: can't find service " + serviceName)
	}
	svc = found.(*serviceRegister.Service)

	if mtype = svc.Method[methodName]; mtype == nil {
		return svc, nil, errors.New("rpc server: can't find method " + methodName)
	}
	return svc, mtype, nil
}
// connected is the status text written to a client after its CONNECT
// request has been hijacked by ServeHTTP.
const connected = "200 Connected to RPC"

// Paths at which HandleHTTP mounts the RPC endpoint and its debug page.
const (
	defaultRPCPath   = "/gorpc"
	defaultDebugPath = "/debug/gorpc"
)
// ServeHTTP implements http.Handler for RPC over HTTP.
//
// Only CONNECT is accepted; any other method gets a 405. For CONNECT, the
// underlying connection is hijacked, "HTTP/1.0 200 Connected to RPC" is
// written back, and the raw connection is then served by ServeConn until
// it is done.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}

	// Not every ResponseWriter supports hijacking (e.g. HTTP/2); an unchecked
	// type assertion would panic here.
	hj, ok := w.(http.Hijacker)
	if !ok {
		log.Print("rpc hijacking ", req.RemoteAddr, ": response writer does not support hijacking")
		http.Error(w, "500 connection hijacking not supported", http.StatusInternalServerError)
		return
	}
	conn, brw, err := hj.Hijack()
	if err != nil {
		log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}
	_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")

	// Hijack may hand back client bytes the HTTP server already buffered.
	// Read through that buffer (which then falls through to conn) so they
	// are not lost.
	var rwc io.ReadWriteCloser = conn
	if brw != nil && brw.Reader.Buffered() > 0 {
		rwc = struct {
			io.Reader
			io.Writer
			io.Closer
		}{brw.Reader, conn, conn}
	}
	server.ServeConn(rwc)
}
// HandleHTTP registers the server on http.DefaultServeMux: the RPC endpoint
// at defaultRPCPath and the debug page at defaultDebugPath. The caller must
// still run an HTTP server (e.g. http.Serve in a goroutine) to accept
// connections.
func (server *Server) HandleHTTP() {
	mux := http.DefaultServeMux
	mux.Handle(defaultRPCPath, server)
	mux.Handle(defaultDebugPath, DebugHTTP{server})
	log.Println("rpc server debug path:", defaultDebugPath)
}
// HandleHTTP registers DefaultServer's RPC and debug handlers on
// http.DefaultServeMux; see (*Server).HandleHTTP.
func HandleHTTP() { DefaultServer.HandleHTTP() }