1 Star 0 Fork 0

0x43/dubbo-getty

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
client.go 10.40 KB
一键复制 编辑 原始数据 按行查看 历史
0x43 提交于 2024-12-28 17:14 . gost1.13.1
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package getty
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
import (
"gitee.com/git4chen/gost/bytes"
"gitee.com/git4chen/gost/net"
gxsync "gitee.com/git4chen/gost/sync"
gxtime "gitee.com/git4chen/gost/time"
"github.com/gorilla/websocket"
perrors "github.com/pkg/errors"
)
// Default timing parameters for dialing and reconnecting. They are untyped
// constants; wherever one is used as a time.Duration it counts nanoseconds.
const (
reconnectInterval = 3e8 // 300ms; back-off step reConnect falls back to when the configured interval is 0
connectInterval = 5e8 // 500ms; pause between two dial attempts in the dial* loops
connectTimeout = 3e9 // 3s; timeout handed to the TCP and TLS dialers in dialTCP
maxTimes = 10 // upper bound of the back-off multiplier in reConnect
)
// Package-level state shared by all clients.
var (
sessionClientKey = "session-client-owner" // attribute key under which connect stores the owning client on a session; stop removes it
connectPingPackage = []byte("connect-ping") // payload dialUDP writes to probe a freshly dialed UDP connection
clientID = EndPointID(0) // counter incremented atomically in newClient to hand out endpoint IDs
)
// Client is the interface returned by the New*Client constructors. It embeds
// EndPoint and adds no methods of its own.
type Client interface {
EndPoint
}
// client is the concrete type behind Client. It dials the configured address
// and keeps the set of sessions it has created.
type client struct {
// options filled in by the ClientOption functions applied in init
ClientOptions
// endpoint ID, taken from the package-level clientID counter in newClient
endPointID EndPointID
// net
// embedded mutex: held while ssMap is read or modified; RunEventLoop also
// takes it while installing newSession
sync.Mutex
endPointType EndPointType // selects which dial* method dial() uses
newSession NewSessionCallback // callback installed by RunEventLoop; connect calls it for every dialed session
ssMap map[Session]struct{} // sessions created by this client; set to nil by stop()
// embedded Once: stop() runs its shutdown sequence through it
sync.Once
done chan struct{} // closed by stop(); IsClosed() reports whether it is closed
wg sync.WaitGroup // waited on by Close(); no Add call is visible in this file
}
// init applies every supplied option, in order, to the client's embedded
// ClientOptions.
func (c *client) init(opts ...ClientOption) {
	target := &c.ClientOptions
	for i := range opts {
		opts[i](target)
	}
}
// newClient allocates a client of endpoint type t, assigns it the next
// endpoint ID and applies opts.
//
// It panics when the resulting options carry a non-positive connection number
// or an empty server address.
func newClient(t EndPointType, opts ...ClientOption) *client {
	c := new(client)
	c.endPointID = atomic.AddInt32(&clientID, 1)
	c.endPointType = t
	c.done = make(chan struct{})
	c.init(opts...)

	if c.addr == "" || c.number <= 0 {
		panic(fmt.Sprintf("client type:%s, @connNum:%d, @serverAddr:%s", t, c.number, c.addr))
	}

	// Size the session set for the configured number of connections.
	c.ssMap = make(map[Session]struct{}, c.number)
	return c
}
// NewTCPClient returns a Client whose endpoint type is TCP_CLIENT, configured
// by opts. It panics under the same conditions as newClient.
func NewTCPClient(opts ...ClientOption) Client {
	tcpClient := newClient(TCP_CLIENT, opts...)
	return tcpClient
}
// NewUDPClient returns a Client whose endpoint type is UDP_CLIENT (a connected
// UDP client), configured by opts. It panics under the same conditions as
// newClient.
func NewUDPClient(opts ...ClientOption) Client {
	udpClient := newClient(UDP_CLIENT, opts...)
	return udpClient
}
// NewWSClient returns a Client whose endpoint type is WS_CLIENT, configured by
// opts. It panics when the configured server address does not start with
// "ws://".
func NewWSClient(opts ...ClientOption) Client {
	c := newClient(WS_CLIENT, opts...)
	if strings.HasPrefix(c.addr, "ws://") {
		return c
	}
	panic(fmt.Sprintf("the prefix @serverAddr:%s is not ws://", c.addr))
}
// NewWSSClient returns a Client whose endpoint type is WSS_CLIENT, configured
// by opts. It panics when no cert path is configured or when the server
// address does not start with "wss://"; the cert check comes first.
func NewWSSClient(opts ...ClientOption) Client {
	c := newClient(WSS_CLIENT, opts...)

	switch {
	case c.cert == "":
		panic(fmt.Sprintf("@cert:%s", c.cert))
	case !strings.HasPrefix(c.addr, "wss://"):
		panic(fmt.Sprintf("the prefix @serverAddr:%s is not wss://", c.addr))
	}

	return c
}
// ID returns the endpoint ID that was assigned to this client in newClient.
func (c *client) ID() EndPointID {
	id := c.endPointID
	return id
}
// EndPointType returns the endpoint type this client was created with.
func (c *client) EndPointType() EndPointType {
	kind := c.endPointType
	return kind
}
// dialTCP keeps dialing c.addr over TCP, or over TLS when sslEnabled is set,
// until a connection is established or the client is closed.
//
// A connection whose local and remote addresses are the same is closed and
// treated as a failure (errSelfConnect). A failure to build the TLS config is
// treated like a failed dial. Failed attempts are retried after
// connectInterval.
//
// It returns the new TCP session, or nil once the client has been closed.
func (c *client) dialTCP() Session {
	var (
		err  error
		conn net.Conn
	)

	for {
		if c.IsClosed() {
			return nil
		}

		// Start every attempt from a clean slate so that neither a stale
		// connection nor a stale error from the previous iteration can be
		// mistaken for the outcome of this one.
		conn, err = nil, nil

		if c.sslEnabled {
			sslConfig, buildTlsConfErr := c.tlsConfigBuilder.BuildTlsConfig()
			switch {
			case buildTlsConfErr != nil:
				// Without this, conn and err would both stay nil and the
				// self-connect check below would dereference a nil conn.
				err = buildTlsConfErr
			case sslConfig == nil:
				err = perrors.New("tls config builder returned a nil tls config")
			default:
				d := &net.Dialer{Timeout: connectTimeout}
				conn, err = tls.DialWithDialer(d, "tcp", c.addr, sslConfig)
			}
		} else {
			conn, err = net.DialTimeout("tcp", c.addr, connectTimeout)
		}

		if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
			conn.Close()
			err = errSelfConnect
		}
		if err == nil {
			return newTCPSession(conn, c)
		}

		log.Infof("net.DialTimeout(addr:%s, timeout:%v) = error:%+v", c.addr, connectTimeout, perrors.WithStack(err))
		<-gxtime.After(connectInterval)
	}
}
// dialUDP keeps trying to create a connected UDP socket to c.addr until one
// passes the probe below or the client is closed.
//
// Each attempt resolves the peer address if that has not succeeded yet, dials,
// writes connectPingPackage and then waits up to one second for a reply. A
// read timeout is tolerated; any other read error fails the attempt. Failed
// attempts are retried after connectInterval.
//
// It returns the new UDP session, or nil once the client has been closed.
func (c *client) dialUDP() Session {
	var (
		err       error
		conn      *net.UDPConn
		localAddr *net.UDPAddr
		peerAddr  *net.UDPAddr
		length    int
		bufp      *[]byte
		buf       []byte
	)

	// Scratch buffer for the probe read; handed back to the pool on return.
	bufp = gxbytes.GetBytes(128)
	defer gxbytes.PutBytes(bufp)
	buf = *bufp

	localAddr = &net.UDPAddr{IP: net.IPv4zero, Port: 0}

	for {
		if c.IsClosed() {
			return nil
		}

		// Resolve lazily and retry on failure: a discarded resolve error would
		// leave peerAddr nil for good and make every later dial fail.
		if peerAddr == nil {
			if peerAddr, err = net.ResolveUDPAddr("udp", c.addr); err != nil {
				peerAddr = nil
				log.Warnf("net.ResolveUDPAddr(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
				<-gxtime.After(connectInterval)
				continue
			}
		}

		conn, err = net.DialUDP("udp", localAddr, peerAddr)
		if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
			conn.Close()
			err = errSelfConnect
		}
		if err != nil {
			log.Warnf("net.DialUDP(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
			<-gxtime.After(connectInterval)
			continue
		}

		// check connection alive by write/read action
		conn.SetWriteDeadline(time.Now().Add(1e9))
		if length, err = conn.Write(connectPingPackage[:]); err != nil {
			conn.Close()
			log.Warnf("conn.Write(%s) = {length:%d, err:%+v}", string(connectPingPackage), length, perrors.WithStack(err))
			<-gxtime.After(connectInterval)
			continue
		}

		conn.SetReadDeadline(time.Now().Add(1e9))
		length, err = conn.Read(buf)
		// A timeout only means nothing came back within the deadline; it is
		// not treated as a failed connection.
		if netErr, ok := perrors.Cause(err).(net.Error); ok && netErr.Timeout() {
			err = nil
		}
		if err != nil {
			log.Infof("conn{%#v}.Read() = {length:%d, err:%+v}", conn, length, perrors.WithStack(err))
			conn.Close()
			<-gxtime.After(connectInterval)
			continue
		}

		return newUDPSession(conn, c)
	}
}
// dialWS keeps dialing the websocket address c.addr until a connection is
// established or the client is closed.
//
// A connection whose local and remote addresses are the same is closed and
// treated as a failure (errSelfConnect). The outcome of every attempt is
// logged exactly once, after the self-connect check. Failed attempts are
// retried after connectInterval.
//
// It returns the new websocket session, or nil once the client has been
// closed.
func (c *client) dialWS() Session {
	var (
		err    error
		dialer websocket.Dialer
		conn   *websocket.Conn
		ss     Session
	)

	dialer.EnableCompression = true

	for {
		if c.IsClosed() {
			return nil
		}

		conn, _, err = dialer.Dial(c.addr, nil)
		if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
			conn.Close()
			err = errSelfConnect
		}
		// Log after the self-connect check so the message carries the final
		// outcome of the attempt and a failure is not reported twice.
		log.Infof("websocket.dialer.Dial(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))

		if err == nil {
			ss = newWSSession(conn, c)
			// Apply the session's message size limit, if any, to the connection.
			if ss.(*session).maxMsgLen > 0 {
				conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
			}
			return ss
		}

		<-gxtime.After(connectInterval)
	}
}
// dialWSS keeps dialing the secure websocket address c.addr until a connection
// is established or the client is closed.
//
// When c.cert is set, the PEM file it names is read and all of its CERTIFICATE
// blocks are collected into one certificate chain. That chain is installed in
// the TLS config, and its last entry is parsed and added to the root pool.
//
// It panics when the cert file cannot be read, contains no CERTIFICATE block,
// or its last certificate cannot be parsed.
//
// A connection whose local and remote addresses are the same is closed and
// treated as a failure (errSelfConnect). Failed attempts are retried after
// connectInterval.
//
// It returns the new session named defaultWSSSessionName, or nil once the
// client has been closed.
func (c *client) dialWSS() Session {
	var (
		err      error
		certPool *x509.CertPool
		config   *tls.Config
		dialer   websocket.Dialer
		conn     *websocket.Conn
		ss       Session
	)

	dialer.EnableCompression = true

	config = &tls.Config{
		// NOTE(review): with InsecureSkipVerify set, crypto/tls accepts any
		// server certificate, so the RootCAs configured below do not take part
		// in verification. Kept as in the original to avoid changing which
		// servers the client will connect to — confirm whether this is intended.
		InsecureSkipVerify: true,
	}

	if c.cert != "" {
		certPEMBlock, readErr := os.ReadFile(c.cert)
		if readErr != nil {
			panic(fmt.Sprintf("os.ReadFile(cert:%s) = error:%+v", c.cert, perrors.WithStack(readErr)))
		}

		// Collect every CERTIFICATE block of the file into a single chain.
		var cert tls.Certificate
		for {
			var certDERBlock *pem.Block
			certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
			if certDERBlock == nil {
				break
			}
			if certDERBlock.Type == "CERTIFICATE" {
				cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
			}
		}
		// Fail with a clear message instead of an index-out-of-range panic
		// when the chain is indexed below.
		if len(cert.Certificate) == 0 {
			panic(fmt.Sprintf("no CERTIFICATE block found in cert file:%s", c.cert))
		}
		config.Certificates = []tls.Certificate{cert}
	}

	certPool = x509.NewCertPool()
	for i := range config.Certificates {
		chain := config.Certificates[i].Certificate
		// The last certificate of the chain is the one used as root.
		roots, parseErr := x509.ParseCertificates(chain[len(chain)-1])
		if parseErr != nil {
			panic(fmt.Sprintf("error parsing server's root cert: %+v\n", perrors.WithStack(parseErr)))
		}
		for _, root := range roots {
			certPool.AddCert(root)
		}
	}
	config.RootCAs = certPool
	dialer.TLSClientConfig = config

	for {
		if c.IsClosed() {
			return nil
		}

		conn, _, err = dialer.Dial(c.addr, nil)
		if err == nil && gxnet.IsSameAddr(conn.RemoteAddr(), conn.LocalAddr()) {
			conn.Close()
			err = errSelfConnect
		}
		if err == nil {
			ss = newWSSession(conn, c)
			// Apply the session's message size limit, if any, to the connection.
			if ss.(*session).maxMsgLen > 0 {
				conn.SetReadLimit(int64(ss.(*session).maxMsgLen))
			}
			ss.SetName(defaultWSSSessionName)
			return ss
		}

		log.Infof("websocket.dialer.Dial(addr:%s) = error:%+v", c.addr, perrors.WithStack(err))
		<-gxtime.After(connectInterval)
	}
}
// dial dispatches to the dial method that matches the client's endpoint type.
// It returns nil for an endpoint type it does not know; the dial methods
// themselves return nil once the client has been closed.
func (c *client) dial() Session {
	kind := c.endPointType

	if kind == TCP_CLIENT {
		return c.dialTCP()
	}
	if kind == UDP_CLIENT {
		return c.dialUDP()
	}
	if kind == WS_CLIENT {
		return c.dialWS()
	}
	if kind == WSS_CLIENT {
		return c.dialWSS()
	}

	return nil
}
// GetTaskPool returns the task pool held in the client's options.
func (c *client) GetTaskPool() gxsync.GenericTaskPool {
	pool := c.tPool
	return pool
}
// sessionNum drops every closed session from the session set and returns how
// many sessions remain. The whole operation runs under the client's mutex.
func (c *client) sessionNum() int {
	c.Lock()
	for ss := range c.ssMap {
		if !ss.IsClosed() {
			continue
		}
		delete(c.ssMap, ss)
	}
	alive := len(c.ssMap)
	c.Unlock()

	return alive
}
// connect dials until one session has been accepted by the newSession
// callback, then starts that session and records it in the session set.
//
// It returns without a session when dial reports that the client has been
// closed. A session rejected by the callback has its connection closed and a
// new dial is attempted. If the client is stopped while the session is being
// set up, the session is closed instead of being left running untracked.
func (c *client) connect() {
	for {
		ss := c.dial()
		if ss == nil {
			// client has been closed
			return
		}

		if err := c.newSession(ss); err != nil {
			// don't distinguish between tcp connection and websocket connection. Because
			// gorilla/websocket/conn.go:(Conn)Close also invoke net.Conn.Close()
			ss.Conn().Close()
			continue
		}

		ss.(*session).run()

		c.Lock()
		stopped := c.ssMap == nil
		if !stopped {
			c.ssMap[ss] = struct{}{}
		}
		c.Unlock()

		if stopped {
			// stop() has already emptied the session set, so nothing would
			// ever close this running session; close it here to avoid a leak.
			ss.Close()
			return
		}

		ss.SetAttribute(sessionClientKey, c)
		return
	}
}
// RunEventLoop installs newSession as the callback for freshly dialed sessions
// and then runs reConnect in the calling goroutine.
//
// There are two ways to maintain a connection pool. The first is a lazy pool
// like redigo's (https://github.com/gomodule/redigo/blob/master/redis/pool.go),
// where a connection's liveness is checked when it is borrowed (testOnBorrow).
// The second, used here, actively checks the connections at a regular
// interval. The active approach may cost slightly more CPU, but an active
// connection is available very quickly.
func (c *client) RunEventLoop(newSession NewSessionCallback) {
	func() {
		c.Lock()
		defer c.Unlock()
		c.newSession = newSession
	}()

	c.reConnect()
}
// reConnect calls connect until the number of live sessions reaches the
// configured connection number or the client is closed.
//
// After every connect it sleeps for attempts*step, where attempts is capped at
// maxTimes and step is the configured reconnect interval (reconnectInterval
// when none is configured).
func (c *client) reConnect() {
	limit := c.number
	step := c.reconnectInterval
	if step == 0 {
		step = reconnectInterval
	}

	attempts := 0
	for {
		if c.IsClosed() {
			log.Warnf("client{peer:%s} goroutine exit now.", c.addr)
			return
		}
		if c.sessionNum() >= limit {
			return
		}

		c.connect()

		// Grow the back-off multiplier, but never beyond maxTimes.
		if attempts < maxTimes {
			attempts++
		}
		<-gxtime.After(time.Duration(int64(attempts) * int64(step)))
	}
}
// stop shuts the client down: it closes the done channel, then, under the
// mutex, detaches and closes every tracked session and sets the session set to
// nil. Calls made after done has been closed return immediately; the embedded
// Once makes sure the shutdown sequence runs a single time.
func (c *client) stop() {
	select {
	case <-c.done:
		// already stopped
		return
	default:
	}

	c.Once.Do(func() {
		close(c.done)

		c.Lock()
		for ss := range c.ssMap {
			ss.RemoveAttribute(sessionClientKey)
			ss.Close()
		}
		c.ssMap = nil
		c.Unlock()
	})
}
// IsClosed reports, without blocking, whether the done channel has been
// closed.
func (c *client) IsClosed() bool {
	select {
	case <-c.done:
		return true
	default:
	}
	return false
}
// Close stops the client and then waits on the client's WaitGroup.
func (c *client) Close() {
	c.stop()

	// Block until everything registered on the WaitGroup has finished.
	c.wg.Wait()
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/git4chen/dubbo-getty.git
git@gitee.com:git4chen/dubbo-getty.git
git4chen
dubbo-getty
dubbo-getty
master

搜索帮助