Skip to content

Commit

Permalink
rewriting tls and http host parser
Browse files Browse the repository at this point in the history
  • Loading branch information
uoosef committed Sep 3, 2023
1 parent 9f2378c commit 49d519c
Show file tree
Hide file tree
Showing 3 changed files with 400 additions and 212 deletions.
28 changes: 28 additions & 0 deletions server/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package server

import (
"bufio"
"io"
"net/http"
)

// ParseHTTPHost parses the head of the first HTTP request on conn and returns
// a new, unread connection with metadata for virtual host muxing
func ParseHTTPHost(rd io.Reader) (string, error) {
var request *http.Request
var err error
if request, err = http.ReadRequest(bufio.NewReader(rd)); err != nil {
return "", err
}

// You probably don't need access to the request body and this makes the API
// simpler by allowing you to call Free() optionally
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {

}
}(request.Body)

return request.Host, nil
}
227 changes: 15 additions & 212 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"math/rand"
"net"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -57,218 +56,17 @@ type Server struct {
Transport *transport.Transport
}

var sniRegex = regexp.MustCompile(`^(?:[a-z0-9-]+\.)+[a-z]+$`)

// getHostname returns the Server Name Indication (SNI) from a TLS Client Hello message.
func (s *Server) getHostnameRegex(data []byte) ([]byte, error) {
const (
sniTypeByte = 0x00
sniLengthOffset = 2
)

if data[0] != 0x16 {
return nil, fmt.Errorf("not a tls packet")
}

// Find the SNI type byte
sniTypeIndex := bytes.IndexByte(data, sniTypeByte)
if sniTypeIndex == -1 {
return nil, fmt.Errorf("could not find SNI type byte in Server Hello message")
}

// Ensure sufficient data to read the SNI length and value
if len(data) < sniTypeIndex+sniLengthOffset+1 {
return nil, fmt.Errorf("insufficient data to read SNI length")
}

var sni string
var prev byte
for i := 0; i < len(data)-1; i++ {
if prev == 0 && data[i] == 0 {
start := i + 2
end := start + int(data[i+1])
if start < end && end < len(data) {
str := string(data[start:end])
if sniRegex.MatchString(str) {
sni = str
break
}
}
}
prev = data[i]
}
return []byte(sni), nil
}

// getHostname This function is basically all most folks want to invoke out of this
// getHostname This function extracts the tls sni or http
func (s *Server) getHostname(data []byte) ([]byte, error) {
extensions, err := s.getExtensionBlock(data)
shouldUseNewRegexMethod := !s.WorkerConfig.WorkerEnabled || (s.WorkerConfig.WorkerEnabled && s.WorkerConfig.WorkerDNSOnly)
if err != nil {
if shouldUseNewRegexMethod {
return s.getHostnameRegex(data)
}
return nil, err
}
sn, err := s.getSNBlock(extensions)
hello, err := ReadClientHello(bytes.NewReader(data))
if err != nil {
if shouldUseNewRegexMethod {
return s.getHostnameRegex(data)
}
return nil, err
}
sni, err := s.getSNIBlock(sn)
if err != nil {
return s.getHostnameRegex(data)
}
return sni, nil
}

/* Return the length computed from the two octets starting at index */
func (s *Server) lengthFromData(data []byte, index int) int {
if index < 0 || index+1 >= len(data) {
return 0
}

b1 := int(data[index])
b2 := int(data[index+1])

return (b1 << 8) + b2
}

// getSNIBlock /* Given a Server Name TLS Extension block, parse out and return the SNI
func (s *Server) getSNIBlock(data []byte) ([]byte, error) {
index := 0

for {
if index >= len(data) {
break
}
length := s.lengthFromData(data, index)
endIndex := index + 2 + length
if data[index+2] == 0x00 { /* SNI */
sni := data[index+3:]
sniLength := s.lengthFromData(sni, 0)
return sni[2 : sniLength+2], nil
}
index = endIndex
}
return []byte{}, fmt.Errorf(
"finished parsing the SN block without finding an SNI",
)
}

// getSNBlock finds the SN block given a TLS Extensions data block.
func (s *Server) getSNBlock(data []byte) ([]byte, error) {
if len(data) < 4 {
return nil, fmt.Errorf("not enough bytes to be an SN block")
}

extensionLength := s.lengthFromData(data, 0)
if extensionLength+4 > len(data) {
return nil, fmt.Errorf("extension size is invalid")
}
data = data[2 : extensionLength+2]

for index := 0; index+4 < len(data); {
blockLength := s.lengthFromData(data, index+2)
endIndex := index + 4 + blockLength
if data[index] == 0x00 && data[index+1] == 0x00 {
return data[index+4 : endIndex], nil
host, err := ParseHTTPHost(bytes.NewReader(data))
if err != nil {
return nil, err
}

index = endIndex
return []byte(host), nil
}

return nil, fmt.Errorf("SN block not found within the Extension block")
}

// getExtensionBlock finds the extension block given a raw TLS Client Hello.
func (s *Server) getExtensionBlock(data []byte) ([]byte, error) {
dataLen := len(data)
index := s.ChunkConfig.TLSHeaderLength + 38

if dataLen <= index+1 {
return nil, fmt.Errorf("not enough bits to be a Client Hello")
}

_, newIndex, err := s.getSessionIDLength(data, index)
if err != nil {
return nil, err
}
index = newIndex

_, newIndex, err = s.getCipherListLength(data, index)
if err != nil {
return nil, err
}
index = newIndex

_, newIndex, err = s.getCompressionLength(data, index)
if err != nil {
return nil, err
}
index = newIndex

if len(data[index:]) == 0 {
return nil, fmt.Errorf("no extensions")
}

return data[index:], nil
}

// getSessionIDLength retrieves the session ID length from the TLS Client Hello data.
func (s *Server) getSessionIDLength(data []byte, index int) (int, int, error) {
dataLen := len(data)

if index+1 >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the SessionID")
}

sessionIDLength := int(data[index])
newIndex := index + 1 + sessionIDLength

if newIndex+2 >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the SessionID")
}

return sessionIDLength, newIndex, nil
}

// getCipherListLength retrieves the cipher list length from the TLS Client Hello data.
func (s *Server) getCipherListLength(data []byte, index int) (int, int, error) {
dataLen := len(data)

if index+2 >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the Cipher List")
}

cipherListLength := s.lengthFromData(data, index)
newIndex := index + 2 + cipherListLength

if newIndex+1 >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the Cipher List")
}

return cipherListLength, newIndex, nil
}

// getCompressionLength retrieves the compression length from the TLS Client Hello data.
func (s *Server) getCompressionLength(data []byte, index int) (int, int, error) {
dataLen := len(data)

if index+1 >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the compression length")
}

compressionLength := int(data[index])
newIndex := index + 1 + compressionLength

if newIndex >= dataLen {
return 0, 0, fmt.Errorf("not enough bytes for the compression length")
}

return compressionLength, newIndex, nil
return []byte(hello.ServerName), nil
}

func (s *Server) getChunkedPackets(data []byte) map[int][]byte {
Expand Down Expand Up @@ -316,7 +114,12 @@ func (s *Server) sendSplitChunks(dst io.Writer, chunks map[int][]byte) {
chunkLength = len(chunk) - position
}

delay := rand.Intn(s.ChunkConfig.Delay[1]-s.ChunkConfig.Delay[0]) + s.ChunkConfig.Delay[0]
var delay int
if s.ChunkConfig.Delay[1]-s.ChunkConfig.Delay[0] > 0 {
delay = rand.Intn(s.ChunkConfig.Delay[1]-s.ChunkConfig.Delay[0]) + s.ChunkConfig.Delay[0]
} else {
delay = s.ChunkConfig.Delay[0]
}

_, errWrite := dst.Write(chunk[position : position+chunkLength])
if errWrite != nil {
Expand Down Expand Up @@ -403,7 +206,7 @@ func (s *Server) sendChunks(dst io.Writer, src io.Reader, shouldSplit bool, wg *
bytesRead, err := src.Read(dataBuffer)
if bytesRead > 0 {
// check if it's the first packet and its tls packet
if index == 0 && dataBuffer[0] == 0x16 && shouldSplit {
if index == 0 && shouldSplit {
chunks := s.getChunkedPackets(dataBuffer[:bytesRead])
s.sendSplitChunks(dst, chunks)
} else {
Expand Down Expand Up @@ -480,7 +283,7 @@ func (s *Server) Resolve(fqdn string) (string, error) {
}
// Parse answer and store in cache
answer := exchange.Answer[0]
logger.Infof("resolved %s to %s", fqdn, answer.String())
logger.Infof("resolved %s to %s", fqdn, strings.Replace(answer.String(), "\t", " ", -1))
record := strings.Fields(answer.String())
if record[3] == "CNAME" {
ip, err := s.Resolve(record[4])
Expand Down
Loading

0 comments on commit 49d519c

Please sign in to comment.