You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
	
	
		
			419 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
		
		
			
		
	
	
			419 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
| 
											5 months ago
										 | package tao | ||
|  | 
 | ||
|  | import ( | ||
|  | 	"context" | ||
|  | 	"crypto/rand" | ||
|  | 	"crypto/tls" | ||
|  | 	"fmt" | ||
|  | 	"net" | ||
|  | 	"sync" | ||
|  | 	"time" | ||
|  | 
 | ||
|  | 	"flx/cite/holmes" | ||
|  | 
 | ||
|  | 	log "github.com/sirupsen/logrus" | ||
|  | ) | ||
|  | 
 | ||
|  | func init() { | ||
|  | 	netIdentifier = NewAtomicInt64(0) | ||
|  | } | ||
|  | 
 | ||
|  | var ( | ||
|  | 	netIdentifier *AtomicInt64 | ||
|  | 	tlsWrapper    func(net.Conn) net.Conn | ||
|  | ) | ||
|  | 
 | ||
|  | type options struct { | ||
|  | 	tlsCfg     *tls.Config | ||
|  | 	codec      Codec | ||
|  | 	onConnect  onConnectFunc | ||
|  | 	onMessage  onMessageFunc | ||
|  | 	onClose    onCloseFunc | ||
|  | 	onError    onErrorFunc | ||
|  | 	workerSize int  // numbers of worker go-routines
 | ||
|  | 	bufferSize int  // size of buffered channel
 | ||
|  | 	reconnect  bool // for ClientConn use only
 | ||
|  | } | ||
|  | 
 | ||
|  | // ServerOption sets server options.
 | ||
|  | type ServerOption func(*options) | ||
|  | 
 | ||
|  | // ReconnectOption returns a ServerOption that will make ClientConn reconnectable.
 | ||
|  | func ReconnectOption() ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.reconnect = true | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // CustomCodecOption returns a ServerOption that will apply a custom Codec.
 | ||
|  | func CustomCodecOption(codec Codec) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.codec = codec | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // TLSCredsOption returns a ServerOption that will set TLS credentials for server
 | ||
|  | // connections.
 | ||
|  | func TLSCredsOption(config *tls.Config) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.tlsCfg = config | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // WorkerSizeOption returns a ServerOption that will set the number of go-routines
 | ||
|  | // in WorkerPool.
 | ||
|  | func WorkerSizeOption(workerSz int) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.workerSize = workerSz | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // BufferSizeOption returns a ServerOption that is the size of buffered channel,
 | ||
|  | // for example an indicator of BufferSize256 means a size of 256.
 | ||
|  | func BufferSizeOption(indicator int) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.bufferSize = indicator | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // OnConnectOption returns a ServerOption that will set callback to call when new
 | ||
|  | // client connected.
 | ||
|  | func OnConnectOption(cb func(WriteCloser) bool) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.onConnect = cb | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // OnMessageOption returns a ServerOption that will set callback to call when new
 | ||
|  | // message arrived.
 | ||
|  | func OnMessageOption(cb func(Message, WriteCloser)) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.onMessage = cb | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // OnCloseOption returns a ServerOption that will set callback to call when client
 | ||
|  | // closed.
 | ||
|  | func OnCloseOption(cb func(WriteCloser)) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.onClose = cb | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // OnErrorOption returns a ServerOption that will set callback to call when error
 | ||
|  | // occurs.
 | ||
|  | func OnErrorOption(cb func(WriteCloser)) ServerOption { | ||
|  | 	return func(o *options) { | ||
|  | 		o.onError = cb | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // Server  is a server to serve TCP requests.
 | ||
|  | type Server struct { | ||
|  | 	opts   options | ||
|  | 	ctx    context.Context | ||
|  | 	cancel context.CancelFunc | ||
|  | 	conns  *sync.Map | ||
|  | 	timing *TimingWheel | ||
|  | 	wg     *sync.WaitGroup | ||
|  | 	mu     sync.Mutex // guards following
 | ||
|  | 	lis    map[net.Listener]bool | ||
|  | 	// for periodically running function every duration.
 | ||
|  | 	interv time.Duration | ||
|  | 	sched  onScheduleFunc | ||
|  | } | ||
|  | 
 | ||
|  | // NewServer returns a new TCP server which has not started
 | ||
|  | // to serve requests yet.
 | ||
|  | func NewServer(opt ...ServerOption) *Server { | ||
|  | 	var opts options | ||
|  | 	for _, o := range opt { | ||
|  | 		o(&opts) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	if opts.codec == nil { | ||
|  | 		opts.codec = TypeLengthValueCodec{} | ||
|  | 	} | ||
|  | 	if opts.workerSize <= 0 { | ||
|  | 		opts.workerSize = defaultWorkersNum | ||
|  | 	} | ||
|  | 	if opts.bufferSize <= 0 { | ||
|  | 		opts.bufferSize = BufferSize256 | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// initiates go-routine pool instance
 | ||
|  | 	globalWorkerPool = newWorkerPool(opts.workerSize) | ||
|  | 
 | ||
|  | 	s := &Server{ | ||
|  | 		opts:  opts, | ||
|  | 		conns: &sync.Map{}, | ||
|  | 		wg:    &sync.WaitGroup{}, | ||
|  | 		lis:   make(map[net.Listener]bool), | ||
|  | 	} | ||
|  | 	s.ctx, s.cancel = context.WithCancel(context.Background()) | ||
|  | 	s.timing = NewTimingWheel(s.ctx) | ||
|  | 	return s | ||
|  | } | ||
|  | 
 | ||
|  | // ConnsSize returns connections size.
 | ||
|  | func (s *Server) ConnsSize() int { | ||
|  | 	var sz int | ||
|  | 	s.conns.Range(func(k, v interface{}) bool { | ||
|  | 		sz++ | ||
|  | 		return true | ||
|  | 	}) | ||
|  | 	return sz | ||
|  | } | ||
|  | 
 | ||
|  | // Sched sets a callback to invoke every duration.
 | ||
|  | func (s *Server) Sched(dur time.Duration, sched func(time.Time, WriteCloser)) { | ||
|  | 	s.mu.Lock() | ||
|  | 	defer s.mu.Unlock() | ||
|  | 	s.interv = dur | ||
|  | 	s.sched = onScheduleFunc(sched) | ||
|  | } | ||
|  | 
 | ||
|  | // Broadcast broadcasts message to all server connections managed.
 | ||
|  | func (s *Server) Broadcast(msg Message) { | ||
|  | 	s.conns.Range(func(k, v interface{}) bool { | ||
|  | 		c := v.(*ServerConn) | ||
|  | 		if err := c.Write(msg); err != nil { | ||
|  | 			holmes.Errorf("broadcast error %v, conn id %d", err, k.(int64)) | ||
|  | 			return false | ||
|  | 		} | ||
|  | 		return true | ||
|  | 	}) | ||
|  | } | ||
|  | 
 | ||
|  | // 测试
 | ||
|  | // func (s *Server) Broadcasts(msg []byte) {
 | ||
|  | // 	s.conns.Range(func(k, v interface{}) bool {
 | ||
|  | // 		c := v.(*ServerConn)
 | ||
|  | // 		if err := c.Write(msg); err != nil {
 | ||
|  | // 			holmes.Errorf("broadcast error %v, conn id %d", err, k.(int64))
 | ||
|  | // 			return false
 | ||
|  | // 		}
 | ||
|  | // 		return true
 | ||
|  | // 	})
 | ||
|  | // }
 | ||
|  | 
 | ||
|  | // Unicast unicasts message to a specified conn.
 | ||
|  | func (s *Server) Unicast(id int64, msg Message) error { | ||
|  | 	v, ok := s.conns.Load(id) | ||
|  | 	if ok { | ||
|  | 		return v.(*ServerConn).Write(msg) | ||
|  | 	} | ||
|  | 	return fmt.Errorf("conn id %d not found", id) | ||
|  | } | ||
|  | 
 | ||
|  | // Conn returns a server connection with specified ID.
 | ||
|  | func (s *Server) Conn(id int64) (*ServerConn, bool) { | ||
|  | 	v, ok := s.conns.Load(id) | ||
|  | 	if ok { | ||
|  | 		return v.(*ServerConn), ok | ||
|  | 	} | ||
|  | 	return nil, ok | ||
|  | } | ||
|  | 
 | ||
|  | // Start starts the TCP server, accepting new clients and creating service
 | ||
|  | // go-routine for each. The service go-routines read messages and then call
 | ||
|  | // the registered handlers to handle them. Start returns when failed with fatal
 | ||
|  | // errors, the listener willl be closed when returned.
 | ||
|  | func (s *Server) Start(l net.Listener) error { | ||
|  | 	defer func() { | ||
|  | 		if r := recover(); r != nil { | ||
|  | 			log.Error(" (s *Server) Start(l net.Listener) error:", r) | ||
|  | 			// 可以在这里处理错误,比如设置 tempresultmap 之类的
 | ||
|  | 		} | ||
|  | 	}() | ||
|  | 	s.mu.Lock() | ||
|  | 	if s.lis == nil { | ||
|  | 		s.mu.Unlock() | ||
|  | 		l.Close() | ||
|  | 		return ErrServerClosed | ||
|  | 	} | ||
|  | 	s.lis[l] = true | ||
|  | 	s.mu.Unlock() | ||
|  | 
 | ||
|  | 	defer func() { | ||
|  | 		s.mu.Lock() | ||
|  | 		if s.lis != nil && s.lis[l] { | ||
|  | 			l.Close() | ||
|  | 			delete(s.lis, l) | ||
|  | 		} | ||
|  | 		s.mu.Unlock() | ||
|  | 	}() | ||
|  | 
 | ||
|  | 	holmes.Infof("server start, net %s addr %s\n", l.Addr().Network(), l.Addr().String()) | ||
|  | 
 | ||
|  | 	s.wg.Add(1) | ||
|  | 	go s.timeOutLoop() | ||
|  | 
 | ||
|  | 	var tempDelay time.Duration | ||
|  | 	for { | ||
|  | 		rawConn, err := l.Accept() | ||
|  | 		if err != nil { | ||
|  | 			if ne, ok := err.(net.Error); ok && ne.Temporary() { | ||
|  | 				if tempDelay == 0 { | ||
|  | 					tempDelay = 5 * time.Millisecond | ||
|  | 				} else { | ||
|  | 					tempDelay *= 2 | ||
|  | 				} | ||
|  | 				if max := 1 * time.Second; tempDelay >= max { | ||
|  | 					tempDelay = max | ||
|  | 				} | ||
|  | 				holmes.Errorf("accept error %v, retrying in %d\n", err, tempDelay) | ||
|  | 				select { | ||
|  | 				case <-time.After(tempDelay): | ||
|  | 				case <-s.ctx.Done(): | ||
|  | 				} | ||
|  | 				continue | ||
|  | 			} | ||
|  | 			return err | ||
|  | 		} | ||
|  | 		tempDelay = 0 | ||
|  | 
 | ||
|  | 		// how many connections do we have ?
 | ||
|  | 		sz := s.ConnsSize() | ||
|  | 		if sz >= MaxConnections { | ||
|  | 			holmes.Warnf("max connections size %d, refuse\n", sz) | ||
|  | 			rawConn.Close() | ||
|  | 			continue | ||
|  | 		} | ||
|  | 
 | ||
|  | 		if s.opts.tlsCfg != nil { | ||
|  | 			rawConn = tls.Server(rawConn, s.opts.tlsCfg) | ||
|  | 		} | ||
|  | 
 | ||
|  | 		netid := netIdentifier.GetAndIncrement() | ||
|  | 		sc := NewServerConn(netid, s, rawConn) | ||
|  | 		sc.SetName(sc.rawConn.RemoteAddr().String()) | ||
|  | 
 | ||
|  | 		s.mu.Lock() | ||
|  | 		if s.sched != nil { | ||
|  | 			sc.RunEvery(s.interv, s.sched) | ||
|  | 		} | ||
|  | 		s.mu.Unlock() | ||
|  | 
 | ||
|  | 		s.conns.Store(netid, sc) | ||
|  | 		addTotalConn(1) | ||
|  | 
 | ||
|  | 		s.wg.Add(1) // this will be Done() in ServerConn.Close()
 | ||
|  | 		go func() { | ||
|  | 			defer func() { | ||
|  | 				if err := recover(); err != nil { | ||
|  | 					fmt.Println("错误信息: ", err) | ||
|  | 				} | ||
|  | 			}() | ||
|  | 			sc.Start() | ||
|  | 		}() | ||
|  | 
 | ||
|  | 		holmes.Infof("accepted client %s, id %d, total %d\n", sc.Name(), netid, s.ConnsSize()) | ||
|  | 		s.conns.Range(func(k, v interface{}) bool { | ||
|  | 			i := k.(int64) | ||
|  | 			c := v.(*ServerConn) | ||
|  | 			holmes.Infof("client(%d) %s", i, c.Name()) | ||
|  | 			return true | ||
|  | 		}) | ||
|  | 	} // for loop
 | ||
|  | } | ||
|  | 
 | ||
|  | // Stop gracefully closes the server, it blocked until all connections
 | ||
|  | // are closed and all go-routines are exited.
 | ||
|  | func (s *Server) Stop() { | ||
|  | 	// immediately stop accepting new clients
 | ||
|  | 	s.mu.Lock() | ||
|  | 	listeners := s.lis | ||
|  | 	s.lis = nil | ||
|  | 	s.mu.Unlock() | ||
|  | 
 | ||
|  | 	for l := range listeners { | ||
|  | 		l.Close() | ||
|  | 		holmes.Infof("stop accepting at address %s\n", l.Addr().String()) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	// close all connections
 | ||
|  | 	conns := map[int64]*ServerConn{} | ||
|  | 
 | ||
|  | 	s.conns.Range(func(k, v interface{}) bool { | ||
|  | 		i := k.(int64) | ||
|  | 		c := v.(*ServerConn) | ||
|  | 		conns[i] = c | ||
|  | 		return true | ||
|  | 	}) | ||
|  | 	// let GC do the cleanings
 | ||
|  | 	s.conns = nil | ||
|  | 
 | ||
|  | 	for _, c := range conns { | ||
|  | 		c.rawConn.Close() | ||
|  | 		holmes.Infof("close client %s\n", c.Name()) | ||
|  | 	} | ||
|  | 
 | ||
|  | 	s.mu.Lock() | ||
|  | 	s.cancel() | ||
|  | 	s.mu.Unlock() | ||
|  | 
 | ||
|  | 	s.wg.Wait() | ||
|  | 
 | ||
|  | 	holmes.Infoln("server stopped gracefully, bye.") | ||
|  | 	// os.Exit(0)
 | ||
|  | } | ||
|  | 
 | ||
|  | // Retrieve the extra data(i.e. net id), and then redispatch timeout callbacks
 | ||
|  | // to corresponding client connection, this prevents one client from running
 | ||
|  | // callbacks of other clients
 | ||
|  | func (s *Server) timeOutLoop() { | ||
|  | 	defer s.wg.Done() | ||
|  | 
 | ||
|  | 	for { | ||
|  | 		select { | ||
|  | 		case <-s.ctx.Done(): | ||
|  | 			return | ||
|  | 
 | ||
|  | 		case timeout := <-s.timing.TimeOutChannel(): | ||
|  | 			netID := timeout.Ctx.Value(netIDCtx).(int64) | ||
|  | 			if v, ok := s.conns.Load(netID); ok { | ||
|  | 				sc := v.(*ServerConn) | ||
|  | 				sc.timerCh <- timeout | ||
|  | 			} else { | ||
|  | 				holmes.Warnf("invalid client %d", netID) | ||
|  | 			} | ||
|  | 		} | ||
|  | 	} | ||
|  | } | ||
|  | 
 | ||
|  | // LoadTLSConfig returns a TLS configuration with the specified cert and key file.
 | ||
|  | func LoadTLSConfig(certFile, keyFile string, isSkipVerify bool) (*tls.Config, error) { | ||
|  | 	cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||
|  | 	if err != nil { | ||
|  | 		return nil, err | ||
|  | 	} | ||
|  | 	config := &tls.Config{ | ||
|  | 		Certificates:       []tls.Certificate{cert}, | ||
|  | 		InsecureSkipVerify: isSkipVerify, | ||
|  | 		CipherSuites: []uint16{ | ||
|  | 			tls.TLS_RSA_WITH_RC4_128_SHA, | ||
|  | 			tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, | ||
|  | 			tls.TLS_RSA_WITH_AES_128_CBC_SHA, | ||
|  | 			tls.TLS_RSA_WITH_AES_256_CBC_SHA, | ||
|  | 			tls.TLS_RSA_WITH_AES_128_GCM_SHA256, | ||
|  | 			tls.TLS_RSA_WITH_AES_256_GCM_SHA384, | ||
|  | 			tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, | ||
|  | 			tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, | ||
|  | 			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, | ||
|  | 			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, | ||
|  | 			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, | ||
|  | 			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, | ||
|  | 		}, | ||
|  | 	} | ||
|  | 	now := time.Now() | ||
|  | 	config.Time = func() time.Time { return now } | ||
|  | 	config.Rand = rand.Reader | ||
|  | 	return config, nil | ||
|  | } |