Skip to content

Commit

Permalink
feat: addin embedded cache
Browse files Browse the repository at this point in the history
  • Loading branch information
kaveh-ahangar committed Aug 1, 2023
1 parent 9dd5098 commit fa5c33c
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 95 deletions.
32 changes: 32 additions & 0 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package cache

import (
"sync"
"time"
)

type Cache struct {
data map[string]interface{}
duration time.Duration
mutex sync.RWMutex
}

func NewCache(duration time.Duration) *Cache {
return &Cache{
data: make(map[string]interface{}),
duration: duration,
}
}

func (c *Cache) Set(key string, value interface{}) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.data[key] = value
}

func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
val, found := c.data[key]
return val, found
}
110 changes: 52 additions & 58 deletions cmd/bepass/main.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package main

import (
"bepass/doh"
"bepass/logger"
"bepass/server"
"bepass/socks5"
"fmt"
"log"
"os"
"strings"
"time"

"github.com/jellydator/ttlcache/v3"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

var (
cache *ttlcache.Cache[string, string]
"bepass/cache"
"bepass/doh"
"bepass/logger"
"bepass/server"
"bepass/socks5"
)

type Config struct {
Expand All @@ -36,24 +33,63 @@ type Config struct {
func loadConfig() (*Config, error) {
viper.SetConfigName("config")
viper.AddConfigPath(".")
err := viper.ReadInConfig()
if err != nil {
if err := viper.ReadInConfig(); err != nil {
return nil, err
}

var config Config
err = viper.Unmarshal(&config)
if err != nil {
if err := viper.Unmarshal(&config); err != nil {
return nil, err
}

return &config, nil
}

func createCache(ttl int) *ttlcache.Cache[string, string] {
return ttlcache.New(
ttlcache.WithTTL[string, string](time.Duration(int64(ttl) * int64(time.Minute))),
func runServer(cmd *cobra.Command, args []string) error {
config, err := loadConfig()
if err != nil {
return err
}

cache := cache.NewCache(time.Duration(config.DnsCacheTTL) * time.Second)

var resolveSystem string
var dohClient *doh.Client

if strings.HasPrefix(config.RemoteDNSAddr, "https://") {
resolveSystem = "doh"
dohClient = doh.NewClient(doh.WithTimeout(10 * time.Second))
} else {
resolveSystem = "DNSCrypt"
}

stdLogger := log.New(os.Stderr, "", log.Ldate|log.Ltime)
appLogger := logger.NewLogger(stdLogger)
chunkConfig := server.ChunkConfig{
BeforeSniLength: config.SniChunksLength,
AfterSniLength: config.ChunksLengthAfterSni,
Delay: config.DelayBetweenChunks,
}

serverHandler := &server.Server{
RemoteDNSAddr: config.RemoteDNSAddr,
Cache: cache,
ResolveSystem: resolveSystem,
DoHClient: dohClient,
Logger: appLogger,
ChunkConfig: chunkConfig,
}

s5 := socks5.NewServer(
socks5.WithConnectHandle(serverHandler.Handle),
)

fmt.Println("Starting socks server:", config.BindAddress)
if err := s5.ListenAndServe("tcp", config.BindAddress); err != nil {
return err
}

return nil
}

func main() {
Expand All @@ -62,49 +98,7 @@ func main() {
rootCmd := &cobra.Command{
Use: "bepass",
Short: "bepass is a socks5 proxy server",
RunE: func(cmd *cobra.Command, args []string) error {
config, err := loadConfig()
if err != nil {
return err
}

cache = createCache(config.DnsCacheTTL)
go cache.Start()

if strings.HasPrefix(config.RemoteDNSAddr, "https://") {
config.ResolveSystem = "doh"
config.DoHClient = doh.NewClient(doh.WithTimeout(10 * time.Second))
} else {
config.ResolveSystem = "DNSCrypt"
}
stdLogger := log.New(os.Stderr, "", log.Ldate|log.Ltime)
appLogger := logger.NewLogger(stdLogger)
chunkConfig := server.ChunkConfig{
BeforeSniLength: config.SniChunksLength,
AfterSniLength: config.ChunksLengthAfterSni,
Delay: config.DelayBetweenChunks,
}

serverHandler := &server.Server{
RemoteDNSAddr: config.RemoteDNSAddr,
Cache: cache,
ResolveSystem: config.ResolveSystem,
DoHClient: config.DoHClient,
Logger: appLogger,
ChunkConfig: chunkConfig,
}

s5 := socks5.NewServer(
socks5.WithConnectHandle(serverHandler.Handle),
)
fmt.Println("starting socks server: " + config.BindAddress)
err = s5.ListenAndServe("tcp", config.BindAddress)
if err != nil {
return err
}

return nil
},
RunE: runServer,
}

rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "./config.json", "Path to configuration file")
Expand Down
96 changes: 59 additions & 37 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"bepass/cache"
"bepass/doh"
"bepass/logger"
"bepass/socks5"
Expand All @@ -11,14 +12,11 @@ import (
"io"
"math/rand"
"net"
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/ameshkov/dnscrypt/v2"
"github.com/jellydator/ttlcache/v3"
"github.com/miekg/dns"
)

Expand All @@ -31,7 +29,7 @@ type ChunkConfig struct {

type Server struct {
RemoteDNSAddr string
Cache *ttlcache.Cache[string, string]
Cache *cache.Cache
ResolveSystem string
DoHClient *doh.Client
ChunkConfig ChunkConfig
Expand All @@ -40,34 +38,66 @@ type Server struct {

// getHostname returns the Server Name Indication (SNI) from a TLS Client Hello message.
func (s *Server) getHostname(data []byte) ([]byte, error) {
re := regexp.MustCompile(`\x00\x00\x00\x00\x00(?P<Length>.)(?P<SNI>.{0,255})`)
matches := re.FindSubmatch(data)
if len(matches) == 0 {
return nil, fmt.Errorf("could not find SNI in Server Name TLS Extension block")
const (
sniTypeByte = 0x00
sniLengthOffset = 2
)

// Find the SNI type byte
sniTypeIndex := bytes.IndexByte(data, sniTypeByte)
if sniTypeIndex == -1 {
return nil, fmt.Errorf("could not find SNI type byte in Server Hello message")
}
return matches[2], nil

// Ensure sufficient data to read the SNI length and value
if len(data) < sniTypeIndex+sniLengthOffset+1 {
return nil, fmt.Errorf("insufficient data to read SNI length")
}

// Read the SNI length
sniLength := int(data[sniTypeIndex+sniLengthOffset])

// Calculate the index of the SNI value
sniValueIndex := sniTypeIndex + sniLengthOffset + 1

// Ensure sufficient data to read the SNI value
if len(data) < sniValueIndex+sniLength {
return nil, fmt.Errorf("insufficient data to read SNI value")
}

// Extract and return the SNI value
sniValue := data[sniValueIndex : sniValueIndex+sniLength]
return sniValue, nil
}

// getChunkedPackets splits the data into chunks based on SNI and chunk lengths.
func (s *Server) getChunkedPackets(data []byte) map[int][]byte {
const (
chunkIndexHostname = 0
chunkIndexSNIValue = 1
chunkIndexRemainder = 2
)

chunks := make(map[int][]byte)
hostname, err := s.getHostname(data)
if err != nil {
s.Logger.Errorf("get hostname error: %v", err)
chunks[0] = data
chunks[chunkIndexRemainder] = data
return chunks
}

s.Logger.Printf("Hostname %s", string(hostname))
index := bytes.Index(data, hostname)
if index == -1 {
return nil
}
chunks[0] = make([]byte, index)
copy(chunks[0], data[:index])
chunks[1] = make([]byte, len(hostname))
copy(chunks[1], data[index:index+len(hostname)])
chunks[2] = make([]byte, len(data)-index-len(hostname))
copy(chunks[2], data[index+len(hostname):])

chunks[chunkIndexHostname] = make([]byte, index)
copy(chunks[chunkIndexHostname], data[:index])
chunks[chunkIndexSNIValue] = make([]byte, len(hostname))
copy(chunks[chunkIndexSNIValue], data[index:index+len(hostname)])
chunks[chunkIndexRemainder] = make([]byte, len(data)-index-len(hostname))
copy(chunks[chunkIndexRemainder], data[index+len(hostname):])
return chunks
}

Expand All @@ -80,21 +110,21 @@ func (s *Server) sendSplitChunks(dst io.Writer, chunks map[int][]byte, config Ch

for _, chunk := range chunks {
position := 0
for {

for position < len(chunk) {
chunkLength := rand.Intn(chunkLengthMax-chunkLengthMin) + chunkLengthMin
delay := rand.Intn(config.Delay[1]-config.Delay[0]) + config.Delay[0]
endPosition := position + chunkLength
if endPosition > len(chunk) {
endPosition = len(chunk)
if chunkLength > len(chunk)-position {
chunkLength = len(chunk) - position
}
_, errWrite := dst.Write(chunk[position:endPosition])

delay := rand.Intn(config.Delay[1]-config.Delay[0]) + config.Delay[0]

_, errWrite := dst.Write(chunk[position : position+chunkLength])
if errWrite != nil {
return
}
position = endPosition
if position == len(chunk) {
break
}

position += chunkLength
time.Sleep(time.Duration(delay) * time.Millisecond)
}
}
Expand Down Expand Up @@ -199,9 +229,9 @@ func (s *Server) Resolve(fqdn string, dohClient *doh.Client) (string, error) {
}

// Check the cache for fqdn
if cachedValue := s.Cache.Get(fqdn); cachedValue != nil {
if cachedValue, _ := s.Cache.Get(fqdn); cachedValue != nil {
s.Logger.Printf("using cached value for %s", fqdn)
return cachedValue.Value(), nil
return cachedValue.(string), nil
}

// Build request message
Expand All @@ -223,26 +253,18 @@ func (s *Server) Resolve(fqdn string, dohClient *doh.Client) (string, error) {
default:
exchange, err = s.resolveDNSWithDNSCrypt(&req)
}

if err != nil {
return "", err
}

// Parse answer and store in cache
answer := exchange.Answer[0]
s.Logger.Printf("resolved %s to %s", fqdn, answer.String())
record := strings.Fields(answer.String())
if record[3] == "CNAME" {
return s.Resolve(record[4], dohClient)
}

ttl, err := strconv.Atoi(record[1])
if err != nil {
return "", fmt.Errorf("invalid TTL value: %v", err)
}

ip := record[4]
s.Cache.Set(fqdn, ip, time.Duration(ttl)*time.Second)
s.Cache.Set(fqdn, ip)
return ip, nil
}

Expand Down

0 comments on commit fa5c33c

Please sign in to comment.