You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
419 lines
9.8 KiB
Go
419 lines
9.8 KiB
Go
package tao
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"flx/cite/holmes"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
func init() {
|
|
netIdentifier = NewAtomicInt64(0)
|
|
}
|
|
|
|
var (
|
|
netIdentifier *AtomicInt64
|
|
tlsWrapper func(net.Conn) net.Conn
|
|
)
|
|
|
|
type options struct {
|
|
tlsCfg *tls.Config
|
|
codec Codec
|
|
onConnect onConnectFunc
|
|
onMessage onMessageFunc
|
|
onClose onCloseFunc
|
|
onError onErrorFunc
|
|
workerSize int // numbers of worker go-routines
|
|
bufferSize int // size of buffered channel
|
|
reconnect bool // for ClientConn use only
|
|
}
|
|
|
|
// ServerOption sets server options.
|
|
type ServerOption func(*options)
|
|
|
|
// ReconnectOption returns a ServerOption that will make ClientConn reconnectable.
|
|
func ReconnectOption() ServerOption {
|
|
return func(o *options) {
|
|
o.reconnect = true
|
|
}
|
|
}
|
|
|
|
// CustomCodecOption returns a ServerOption that will apply a custom Codec.
|
|
func CustomCodecOption(codec Codec) ServerOption {
|
|
return func(o *options) {
|
|
o.codec = codec
|
|
}
|
|
}
|
|
|
|
// TLSCredsOption returns a ServerOption that will set TLS credentials for server
|
|
// connections.
|
|
func TLSCredsOption(config *tls.Config) ServerOption {
|
|
return func(o *options) {
|
|
o.tlsCfg = config
|
|
}
|
|
}
|
|
|
|
// WorkerSizeOption returns a ServerOption that will set the number of go-routines
|
|
// in WorkerPool.
|
|
func WorkerSizeOption(workerSz int) ServerOption {
|
|
return func(o *options) {
|
|
o.workerSize = workerSz
|
|
}
|
|
}
|
|
|
|
// BufferSizeOption returns a ServerOption that is the size of buffered channel,
|
|
// for example an indicator of BufferSize256 means a size of 256.
|
|
func BufferSizeOption(indicator int) ServerOption {
|
|
return func(o *options) {
|
|
o.bufferSize = indicator
|
|
}
|
|
}
|
|
|
|
// OnConnectOption returns a ServerOption that will set callback to call when new
|
|
// client connected.
|
|
func OnConnectOption(cb func(WriteCloser) bool) ServerOption {
|
|
return func(o *options) {
|
|
o.onConnect = cb
|
|
}
|
|
}
|
|
|
|
// OnMessageOption returns a ServerOption that will set callback to call when new
|
|
// message arrived.
|
|
func OnMessageOption(cb func(Message, WriteCloser)) ServerOption {
|
|
return func(o *options) {
|
|
o.onMessage = cb
|
|
}
|
|
}
|
|
|
|
// OnCloseOption returns a ServerOption that will set callback to call when client
|
|
// closed.
|
|
func OnCloseOption(cb func(WriteCloser)) ServerOption {
|
|
return func(o *options) {
|
|
o.onClose = cb
|
|
}
|
|
}
|
|
|
|
// OnErrorOption returns a ServerOption that will set callback to call when error
|
|
// occurs.
|
|
func OnErrorOption(cb func(WriteCloser)) ServerOption {
|
|
return func(o *options) {
|
|
o.onError = cb
|
|
}
|
|
}
|
|
|
|
// Server is a server to serve TCP requests.
|
|
type Server struct {
|
|
opts options
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
conns *sync.Map
|
|
timing *TimingWheel
|
|
wg *sync.WaitGroup
|
|
mu sync.Mutex // guards following
|
|
lis map[net.Listener]bool
|
|
// for periodically running function every duration.
|
|
interv time.Duration
|
|
sched onScheduleFunc
|
|
}
|
|
|
|
// NewServer returns a new TCP server which has not started
|
|
// to serve requests yet.
|
|
func NewServer(opt ...ServerOption) *Server {
|
|
var opts options
|
|
for _, o := range opt {
|
|
o(&opts)
|
|
}
|
|
|
|
if opts.codec == nil {
|
|
opts.codec = TypeLengthValueCodec{}
|
|
}
|
|
if opts.workerSize <= 0 {
|
|
opts.workerSize = defaultWorkersNum
|
|
}
|
|
if opts.bufferSize <= 0 {
|
|
opts.bufferSize = BufferSize256
|
|
}
|
|
|
|
// initiates go-routine pool instance
|
|
globalWorkerPool = newWorkerPool(opts.workerSize)
|
|
|
|
s := &Server{
|
|
opts: opts,
|
|
conns: &sync.Map{},
|
|
wg: &sync.WaitGroup{},
|
|
lis: make(map[net.Listener]bool),
|
|
}
|
|
s.ctx, s.cancel = context.WithCancel(context.Background())
|
|
s.timing = NewTimingWheel(s.ctx)
|
|
return s
|
|
}
|
|
|
|
// ConnsSize returns connections size.
|
|
func (s *Server) ConnsSize() int {
|
|
var sz int
|
|
s.conns.Range(func(k, v interface{}) bool {
|
|
sz++
|
|
return true
|
|
})
|
|
return sz
|
|
}
|
|
|
|
// Sched sets a callback to invoke every duration.
|
|
func (s *Server) Sched(dur time.Duration, sched func(time.Time, WriteCloser)) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.interv = dur
|
|
s.sched = onScheduleFunc(sched)
|
|
}
|
|
|
|
// Broadcast broadcasts message to all server connections managed.
|
|
func (s *Server) Broadcast(msg Message) {
|
|
s.conns.Range(func(k, v interface{}) bool {
|
|
c := v.(*ServerConn)
|
|
if err := c.Write(msg); err != nil {
|
|
holmes.Errorf("broadcast error %v, conn id %d", err, k.(int64))
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// 测试
|
|
// func (s *Server) Broadcasts(msg []byte) {
|
|
// s.conns.Range(func(k, v interface{}) bool {
|
|
// c := v.(*ServerConn)
|
|
// if err := c.Write(msg); err != nil {
|
|
// holmes.Errorf("broadcast error %v, conn id %d", err, k.(int64))
|
|
// return false
|
|
// }
|
|
// return true
|
|
// })
|
|
// }
|
|
|
|
// Unicast unicasts message to a specified conn.
|
|
func (s *Server) Unicast(id int64, msg Message) error {
|
|
v, ok := s.conns.Load(id)
|
|
if ok {
|
|
return v.(*ServerConn).Write(msg)
|
|
}
|
|
return fmt.Errorf("conn id %d not found", id)
|
|
}
|
|
|
|
// Conn returns a server connection with specified ID.
|
|
func (s *Server) Conn(id int64) (*ServerConn, bool) {
|
|
v, ok := s.conns.Load(id)
|
|
if ok {
|
|
return v.(*ServerConn), ok
|
|
}
|
|
return nil, ok
|
|
}
|
|
|
|
// Start starts the TCP server, accepting new clients and creating service
|
|
// go-routine for each. The service go-routines read messages and then call
|
|
// the registered handlers to handle them. Start returns when failed with fatal
|
|
// errors, the listener willl be closed when returned.
|
|
func (s *Server) Start(l net.Listener) error {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Error(" (s *Server) Start(l net.Listener) error:", r)
|
|
// 可以在这里处理错误,比如设置 tempresultmap 之类的
|
|
}
|
|
}()
|
|
s.mu.Lock()
|
|
if s.lis == nil {
|
|
s.mu.Unlock()
|
|
l.Close()
|
|
return ErrServerClosed
|
|
}
|
|
s.lis[l] = true
|
|
s.mu.Unlock()
|
|
|
|
defer func() {
|
|
s.mu.Lock()
|
|
if s.lis != nil && s.lis[l] {
|
|
l.Close()
|
|
delete(s.lis, l)
|
|
}
|
|
s.mu.Unlock()
|
|
}()
|
|
|
|
holmes.Infof("server start, net %s addr %s\n", l.Addr().Network(), l.Addr().String())
|
|
|
|
s.wg.Add(1)
|
|
go s.timeOutLoop()
|
|
|
|
var tempDelay time.Duration
|
|
for {
|
|
rawConn, err := l.Accept()
|
|
if err != nil {
|
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
|
if tempDelay == 0 {
|
|
tempDelay = 5 * time.Millisecond
|
|
} else {
|
|
tempDelay *= 2
|
|
}
|
|
if max := 1 * time.Second; tempDelay >= max {
|
|
tempDelay = max
|
|
}
|
|
holmes.Errorf("accept error %v, retrying in %d\n", err, tempDelay)
|
|
select {
|
|
case <-time.After(tempDelay):
|
|
case <-s.ctx.Done():
|
|
}
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
tempDelay = 0
|
|
|
|
// how many connections do we have ?
|
|
sz := s.ConnsSize()
|
|
if sz >= MaxConnections {
|
|
holmes.Warnf("max connections size %d, refuse\n", sz)
|
|
rawConn.Close()
|
|
continue
|
|
}
|
|
|
|
if s.opts.tlsCfg != nil {
|
|
rawConn = tls.Server(rawConn, s.opts.tlsCfg)
|
|
}
|
|
|
|
netid := netIdentifier.GetAndIncrement()
|
|
sc := NewServerConn(netid, s, rawConn)
|
|
sc.SetName(sc.rawConn.RemoteAddr().String())
|
|
|
|
s.mu.Lock()
|
|
if s.sched != nil {
|
|
sc.RunEvery(s.interv, s.sched)
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
s.conns.Store(netid, sc)
|
|
addTotalConn(1)
|
|
|
|
s.wg.Add(1) // this will be Done() in ServerConn.Close()
|
|
go func() {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
fmt.Println("错误信息: ", err)
|
|
}
|
|
}()
|
|
sc.Start()
|
|
}()
|
|
|
|
holmes.Infof("accepted client %s, id %d, total %d\n", sc.Name(), netid, s.ConnsSize())
|
|
s.conns.Range(func(k, v interface{}) bool {
|
|
i := k.(int64)
|
|
c := v.(*ServerConn)
|
|
holmes.Infof("client(%d) %s", i, c.Name())
|
|
return true
|
|
})
|
|
} // for loop
|
|
}
|
|
|
|
// Stop gracefully closes the server, it blocked until all connections
|
|
// are closed and all go-routines are exited.
|
|
func (s *Server) Stop() {
|
|
// immediately stop accepting new clients
|
|
s.mu.Lock()
|
|
listeners := s.lis
|
|
s.lis = nil
|
|
s.mu.Unlock()
|
|
|
|
for l := range listeners {
|
|
l.Close()
|
|
holmes.Infof("stop accepting at address %s\n", l.Addr().String())
|
|
}
|
|
|
|
// close all connections
|
|
conns := map[int64]*ServerConn{}
|
|
|
|
s.conns.Range(func(k, v interface{}) bool {
|
|
i := k.(int64)
|
|
c := v.(*ServerConn)
|
|
conns[i] = c
|
|
return true
|
|
})
|
|
// let GC do the cleanings
|
|
s.conns = nil
|
|
|
|
for _, c := range conns {
|
|
c.rawConn.Close()
|
|
holmes.Infof("close client %s\n", c.Name())
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.cancel()
|
|
s.mu.Unlock()
|
|
|
|
s.wg.Wait()
|
|
|
|
holmes.Infoln("server stopped gracefully, bye.")
|
|
// os.Exit(0)
|
|
}
|
|
|
|
// Retrieve the extra data(i.e. net id), and then redispatch timeout callbacks
|
|
// to corresponding client connection, this prevents one client from running
|
|
// callbacks of other clients
|
|
func (s *Server) timeOutLoop() {
|
|
defer s.wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return
|
|
|
|
case timeout := <-s.timing.TimeOutChannel():
|
|
netID := timeout.Ctx.Value(netIDCtx).(int64)
|
|
if v, ok := s.conns.Load(netID); ok {
|
|
sc := v.(*ServerConn)
|
|
sc.timerCh <- timeout
|
|
} else {
|
|
holmes.Warnf("invalid client %d", netID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// LoadTLSConfig returns a TLS configuration with the specified cert and key file.
|
|
func LoadTLSConfig(certFile, keyFile string, isSkipVerify bool) (*tls.Config, error) {
|
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
InsecureSkipVerify: isSkipVerify,
|
|
CipherSuites: []uint16{
|
|
tls.TLS_RSA_WITH_RC4_128_SHA,
|
|
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
|
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
|
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
|
},
|
|
}
|
|
now := time.Now()
|
|
config.Time = func() time.Time { return now }
|
|
config.Rand = rand.Reader
|
|
return config, nil
|
|
}
|