forked from bepass-org/bepass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathws.go
179 lines (152 loc) · 4.65 KB
/
ws.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
// Package transport provides WebSocket tunneling functionality.
package transport
import (
"context"
"encoding/binary"
"github.com/bepass-org/bepass/config"
"github.com/bepass-org/bepass/dialer"
"github.com/bepass-org/bepass/logger"
"github.com/bepass-org/bepass/net/adapter/ws"
"net"
"strings"
"time"
"github.com/gorilla/websocket"
)
// EstablishedTunnel holds the shared state of one tunnel created by
// WSTunnel.PersistentDial for a single tunnel endpoint. Several callers can
// share one tunnel; each caller is identified by a 16-bit channel index.
type EstablishedTunnel struct {
// tunnelWriteChannel carries the packets that are to be written to the
// tunnel. PersistentDial hands this same channel to every caller that
// attaches to the tunnel.
tunnelWriteChannel chan UDPPacket
// bindWriteChannels maps a channel index to the channel on which packets
// read from the tunnel carrying that index are delivered.
bindWriteChannels map[uint16]chan UDPPacket
// channelIndex is the most recently assigned channel index. PersistentDial
// sets it to 1 when the tunnel is created and increments it for every
// additional caller that attaches to the same endpoint.
channelIndex uint16
}
// WSTunnel carries UDP packets over WebSocket connections and keeps track of
// the tunnels it has already established, keyed by endpoint.
type WSTunnel struct {
// BindAddress is not referenced by the methods visible in this file.
// NOTE(review): presumably the local address the UDP side binds to - confirm
// against the callers.
BindAddress string
// Dialer supplies the HttpDial, TLSDial and FragmentDial primitives that Dial
// plugs into the WebSocket dialer.
Dialer *dialer.Dialer
// ReadTimeout is the read deadline, in seconds, that PersistentDial applies
// to a freshly opened tunnel connection.
ReadTimeout int
// WriteTimeout is the write deadline, in seconds, that PersistentDial applies
// before each write to the tunnel connection.
WriteTimeout int
// LinkIdleTimeout is compared, in seconds, against the time elapsed since the
// last recorded tunnel activity.
LinkIdleTimeout int64
// EstablishedTunnels maps a tunnel endpoint to its shared tunnel state.
// PersistentDial writes to this map, so it must be initialised (non-nil)
// before PersistentDial is called.
EstablishedTunnels map[string]*EstablishedTunnel
// ShortClientID is prepended to every frame written to the tunnel.
ShortClientID string
}
// Dial opens a WebSocket connection to endpoint and returns it together with
// any dial error. The HTTP handshake response is discarded.
//
// Both dial hooks ignore the address the WebSocket library asks for and
// connect to config.G.WorkerIPPortAddress instead; the context passed to the
// hooks is not used.
func (w *WSTunnel) Dial(endpoint string) (*websocket.Conn, error) {
	// fragmentDial is the raw transport handed to TLSDial for secure endpoints.
	fragmentDial := func(network, _ string) (net.Conn, error) {
		return w.Dialer.FragmentDial(network, config.G.WorkerIPPortAddress)
	}

	// plainDial serves non-TLS endpoints.
	plainDial := func(_ context.Context, network, _ string) (net.Conn, error) {
		return w.Dialer.HttpDial(network, config.G.WorkerIPPortAddress)
	}

	// secureDial serves TLS endpoints; addr is forwarded to TLSDial unchanged.
	secureDial := func(_ context.Context, network, addr string) (net.Conn, error) {
		return w.Dialer.TLSDial(fragmentDial, network, addr)
	}

	wsDialer := websocket.Dialer{
		NetDialContext:    plainDial,
		NetDialTLSContext: secureDial,
	}

	conn, _, err := wsDialer.Dial(endpoint, nil)
	return conn, err
}
// PersistentDial attaches bindWriteChannel to the shared tunnel for
// tunnelEndpoint, creating the tunnel and its background worker on first use.
//
// It returns:
//   - the channel on which the caller submits packets to be written to the
//     tunnel (shared by every caller attached to the same endpoint);
//   - the channel index assigned to bindWriteChannel; packets read from the
//     tunnel that carry this index are delivered to bindWriteChannel;
//   - an error, which is always nil in the current implementation.
//
// The background worker dials the endpoint up to 10 times in total, serving
// each successful connection until it fails, and removes the tunnel from
// w.EstablishedTunnels when it gives up.
//
// NOTE(review): w.EstablishedTunnels, the tunnel's bindWriteChannels map and
// lastActivityStamp are accessed from several goroutines without
// synchronisation. Concurrent use needs external locking.
func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan UDPPacket) (chan UDPPacket, uint16, error) {
	if tunnel, ok := w.EstablishedTunnels[tunnelEndpoint]; ok {
		// A tunnel already exists: register one more receiver on it.
		tunnel.channelIndex++
		tunnel.bindWriteChannels[tunnel.channelIndex] = bindWriteChannel
		return tunnel.tunnelWriteChannel, tunnel.channelIndex, nil
	}

	tunnelWriteChannel := make(chan UDPPacket)
	tunnel := &EstablishedTunnel{
		tunnelWriteChannel: tunnelWriteChannel,
		bindWriteChannels:  map[uint16]chan UDPPacket{1: bindWriteChannel},
		channelIndex:       1,
	}
	w.EstablishedTunnels[tunnelEndpoint] = tunnel

	lastActivityStamp := time.Now().Unix()

	go func() {
		defer delete(w.EstablishedTunnels, tunnelEndpoint)

		// This check runs only once, right after lastActivityStamp was set, so
		// it can only fire when LinkIdleTimeout is negative or shorter than
		// the time it took this goroutine to start.
		// NOTE(review): it looks like it was meant to gate reconnect attempts
		// inside the loop - confirm the intent before moving it.
		if time.Now().Unix()-lastActivityStamp > w.LinkIdleTimeout {
			return
		}

		for limit := 0; limit < 10; limit++ {
			logger.Infof("connecting to %s\r\n", tunnelEndpoint)
			c, err := w.Dial(tunnelEndpoint)
			if err != nil {
				// Check the error before touching c: on failure the returned
				// connection is not usable and must not be wrapped.
				logger.Errorf("error dialing udp over tcp tunnel: %v\r\n", err)
				continue
			}
			conn := ws.New(c)

			// done is closed by the reader, doneR by the writer; each side
			// watches the other's channel so both stop together.
			done := make(chan struct{})
			doneR := make(chan struct{})

			// Write: drain tunnelWriteChannel into the connection.
			go func() {
				defer func() {
					close(doneR)
					// Closing also unblocks a reader stuck in conn.Read.
					_ = conn.Close()
				}()
				defer logger.Info("write closed")

				for {
					select {
					case <-done:
						return
					case rt := <-tunnelWriteChannel:
						deadline := time.Now().Add(time.Duration(w.WriteTimeout) * time.Second)
						if err := conn.SetWriteDeadline(deadline); err != nil {
							return
						}
						// Frame layout: client ID, 2-byte big-endian channel
						// index, payload.
						bs := make([]byte, 2)
						binary.BigEndian.PutUint16(bs, rt.Channel)
						if _, err := conn.Write(append([]byte(w.ShortClientID), append(bs, rt.Data...)...)); err != nil {
							logger.Info("write:", err)
							return
						}
						lastActivityStamp = time.Now().Unix()
					}
				}
			}()

			// Read: runs on the worker goroutine, so the next dial attempt
			// starts only after this connection's reader has returned.
			func() {
				defer func() {
					close(done)
					_ = conn.Close()
				}()

				// The deadline is set once per connection, not per read.
				// NOTE(review): whether the adapter extends it on activity is
				// not visible from here.
				deadline := time.Now().Add(time.Duration(w.ReadTimeout) * time.Second)
				if err := conn.SetReadDeadline(deadline); err != nil {
					return
				}
				defer logger.Info("read closed")

				for {
					select {
					case <-doneR:
						return
					default:
						// 1- unpack the message
						// 2- find the channel that the message should be written on
						// 3- write the message on that channel
						//
						// A fresh buffer per read is required: the packet sent
						// to the receiver aliases it.
						rawPacket := make([]byte, 32*1024)
						n, err := conn.Read(rawPacket)
						if n < 2 && err == nil {
							// Too short to carry a channel index; skip it.
							continue
						}
						if err != nil {
							if strings.Contains(err.Error(), "websocket: close") ||
								strings.Contains(err.Error(), "limit/o") {
								logger.Errorf("reading from udp over tcp error: %v\r\n", err)
								return
							}
							logger.Errorf("reading from udp over TCP tunnel packet size error: %v\r\n", err)
							return
						}

						// The first 2 bytes of a response are the channel index.
						pkt := UDPPacket{
							Channel: binary.BigEndian.Uint16(rawPacket[:2]),
							Data:    rawPacket[2:n],
						}
						// Use the captured tunnel rather than looking it up in
						// w.EstablishedTunnels again for every packet.
						if udpBindWriteChan, ok := tunnel.bindWriteChannels[pkt.Channel]; ok {
							udpBindWriteChan <- pkt
							lastActivityStamp = time.Now().Unix()
						}
					}
				}
			}()
		}
	}()

	return tunnelWriteChannel, 1, nil
}