Skip to content

Commit

Permalink
fix race problem in udp when passing the pointer to channel
Browse files Browse the repository at this point in the history
  • Loading branch information
uoosef committed Aug 29, 2023
1 parent ab592b9 commit a5b51b8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
7 changes: 3 additions & 4 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type UdpBind struct {
SocksWriter io.Writer
SocksReq *socks5.Request
AssociateBind *net.UDPConn
RecvChan chan *UDPPacket
RecvChan chan UDPPacket
}

type UDPConf struct {
Expand Down Expand Up @@ -132,7 +132,7 @@ func (t *Transport) TunnelUDP(w io.Writer, req *socks5.Request) error {
return err
}

bindWriteChannel := make(chan *UDPPacket)
bindWriteChannel := make(chan UDPPacket)
tunnelWriteChannel, channelIndex, err := t.Tunnel.PersistentDial(tunnelEndpoint, bindWriteChannel)
if err != nil {
t.Logger.Errorf("Unable to get or create tunnel for udpBindWriteChannel %v\r\n", err)
Expand Down Expand Up @@ -165,15 +165,14 @@ func (t *Transport) TunnelUDP(w io.Writer, req *socks5.Request) error {
if err != nil {
continue
}
tunnelWriteChannel <- &UDPPacket{
tunnelWriteChannel <- UDPPacket{
Channel: channelIndex,
Data: pk.Data,
}
}
}()
for {
datagram := <-udpBind.RecvChan
fmt.Print("injam")
pkb, err := statute.NewDatagram(req.RawDestAddr.String(), datagram.Data)
if err != nil {
continue
Expand Down
18 changes: 6 additions & 12 deletions transport/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"bepass/wsconnadapter"
"context"
"encoding/binary"
"fmt"
"github.com/gorilla/websocket"
"golang.org/x/net/proxy"
"net"
Expand All @@ -15,8 +14,8 @@ import (
)

type EstablishedTunnel struct {
tunnelWriteChannel chan *UDPPacket
bindWriteChannels map[uint16]chan *UDPPacket
tunnelWriteChannel chan UDPPacket
bindWriteChannels map[uint16]chan UDPPacket
channelIndex uint16
}

Expand Down Expand Up @@ -56,18 +55,18 @@ func (w *WSTunnel) Dial(endpoint string) (*websocket.Conn, error) {
return conn, err
}

func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan *UDPPacket) (chan *UDPPacket, uint16, error) {
func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan UDPPacket) (chan UDPPacket, uint16, error) {
if tunnel, ok := w.EstablishedTunnels[tunnelEndpoint]; ok {
tunnel.channelIndex = tunnel.channelIndex + 1
tunnel.bindWriteChannels[tunnel.channelIndex] = bindWriteChannel
return tunnel.tunnelWriteChannel, tunnel.channelIndex, nil
}

tunnelWriteChannel := make(chan *UDPPacket)
tunnelWriteChannel := make(chan UDPPacket)

w.EstablishedTunnels[tunnelEndpoint] = &EstablishedTunnel{
tunnelWriteChannel: tunnelWriteChannel,
bindWriteChannels: make(map[uint16]chan *UDPPacket),
bindWriteChannels: make(map[uint16]chan UDPPacket),
channelIndex: 1,
}

Expand Down Expand Up @@ -162,15 +161,10 @@ func (w *WSTunnel) PersistentDial(tunnelEndpoint string, bindWriteChannel chan *
continue
}

fmt.Println(rawPacket[:n])

// first 2 packets of response is channel id

channelID := binary.BigEndian.Uint16(rawPacket[:2])

fmt.Println(channelID)

pkt := &UDPPacket{
pkt := UDPPacket{
channelID,
rawPacket[2:n],
}
Expand Down

0 comments on commit a5b51b8

Please sign in to comment.