Skip to content

Commit

Permalink
Odds and ends
Browse files Browse the repository at this point in the history
  • Loading branch information
zx2c4 committed May 13, 2018
1 parent e941856 commit 2326d6a
Show file tree
Hide file tree
Showing 16 changed files with 139 additions and 164 deletions.
3 changes: 0 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go)
clean:
rm -f wireguard-go

cloc:
cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go))

.PHONY: clean cloc
110 changes: 75 additions & 35 deletions trie.go → allowedips.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,12 @@ package main
import (
"errors"
"net"
"sync"
)

/* Binary trie
*
* The net.IPs used here are not formatted the
* same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 byte (not always 16)
*
* Synchronization done separately
* See: routing.go
*/

type Trie struct {
type trieEntry struct {
cidr uint
child [2]*Trie
child [2]*trieEntry
bits []byte
peer *Peer

Expand Down Expand Up @@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint {
return i * 8
}

func (node *Trie) RemovePeer(p *Peer) *Trie {
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}

// walk recursively

node.child[0] = node.child[0].RemovePeer(p)
node.child[1] = node.child[1].RemovePeer(p)
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)

if node.peer != p {
return node
Expand All @@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node.child[0]
}

func (node *Trie) choose(ip net.IP) byte {
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}

func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {

// at leaf

if node == nil {
return &Trie{
return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
Expand All @@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}

// split node

newNode := &Trie{
newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
Expand All @@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {

// create new parent for node & newNode

parent := &Trie{
parent := &trieEntry{
bits: ip,
peer: nil,
cidr: cidr,
Expand All @@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return parent
}

func (node *Trie) Lookup(ip net.IP) *Peer {
func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
Expand All @@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer {
return found
}

func (node *Trie) Count() uint {
if node == nil {
return 0
}
l := node.child[0].Count()
r := node.child[1].Count()
return l + r
}

func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
Expand All @@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
panic(errors.New("bug: unexpected address length"))
panic(errors.New("unexpected address length"))
}
results = append(results, mask)
}
results = node.child[0].AllowedIPs(p, results)
results = node.child[1].AllowedIPs(p, results)
results = node.child[0].entriesForPeer(p, results)
results = node.child[1].entriesForPeer(p, results)
return results
}

type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}

func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()

allowed := make([]net.IPNet, 0, 10)
allowed = table.IPv4.entriesForPeer(peer, allowed)
allowed = table.IPv6.entriesForPeer(peer, allowed)
return allowed
}

func (table *AllowedIPs) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()

table.IPv4 = nil
table.IPv6 = nil
}

func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()

table.IPv4 = table.IPv4.removeByPeer(peer)
table.IPv6 = table.IPv6.removeByPeer(peer)
}

func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()

switch len(ip) {
case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
default:
panic(errors.New("inserting unknown address type"))
}
}

func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv4.lookup(address)
}

func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv6.lookup(address)
}
16 changes: 8 additions & 8 deletions trie_rand_test.go → allowedips_rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
}

func TestTrieRandomIPv4(t *testing.T) {
var trie *Trie
var trie *trieEntry
var slow SlowRouter
var peers []*Peer

Expand All @@ -82,23 +82,23 @@ func TestTrieRandomIPv4(t *testing.T) {
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.Insert(addr[:], cidr, peers[index])
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}

for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
t.Error("trieEntry did not match naive implementation, for:", addr)
}
}
}

func TestTrieRandomIPv6(t *testing.T) {
var trie *Trie
var trie *trieEntry
var slow SlowRouter
var peers []*Peer

Expand All @@ -115,17 +115,17 @@ func TestTrieRandomIPv6(t *testing.T) {
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.Insert(addr[:], cidr, peers[index])
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}

for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
t.Error("trieEntry did not match naive implementation, for:", addr)
}
}
}
26 changes: 13 additions & 13 deletions trie_test.go → allowedips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type testPairTrieLookup struct {
peer *Peer
}

func printTrie(t *testing.T, p *Trie) {
func printTrie(t *testing.T, p *trieEntry) {
if p == nil {
return
}
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestCommonBits(t *testing.T) {
}

func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *Trie
var trie *trieEntry
var peers []*Peer

rand.Seed(1)
Expand All @@ -79,13 +79,13 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
trie = trie.Insert(addr[:], cidr, peers[index])
trie = trie.insert(addr[:], cidr, peers[index])
}

for n := 0; n < b.N; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.Lookup(addr[:])
trie.lookup(addr[:])
}
}

Expand Down Expand Up @@ -117,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}

var trie *Trie
var trie *trieEntry

insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
}

assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.Lookup([]byte{a, b, c, d})
p := trie.lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}

assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.Lookup([]byte{a, b, c, d})
p := trie.lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
Expand Down Expand Up @@ -173,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)

trie = trie.RemovePeer(a)
trie = trie.removeByPeer(a)

assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
Expand All @@ -186,7 +186,7 @@ func TestTrieIPv4(t *testing.T) {
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)

trie = trie.RemovePeer(a)
trie = trie.removeByPeer(a)

assertNEQ(a, 192, 168, 0, 1)
}
Expand All @@ -204,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}

var trie *Trie
var trie *trieEntry

expand := func(a uint32) []byte {
var out [4]byte
Expand All @@ -221,7 +221,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
trie = trie.Insert(addr, cidr, peer)
trie = trie.insert(addr, cidr, peer)
}

assertEQ := func(peer *Peer, a, b, c, d uint32) {
Expand All @@ -230,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := trie.Lookup(addr)
p := trie.lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
Expand Down
4 changes: 2 additions & 2 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type Device struct {

routing struct {
mutex sync.RWMutex
table RoutingTable
table AllowedIPs
}

peers struct {
Expand Down Expand Up @@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {

// stop routing and processing of packets

device.routing.table.RemovePeer(peer)
device.routing.table.RemoveByPeer(peer)
peer.Stop()

// remove from peer map
Expand Down
2 changes: 1 addition & 1 deletion keypair.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Keypairs struct {
mutex sync.RWMutex
current *Keypair
previous *Keypair
next *Keypair // not yet "confirmed by transport"
next *Keypair
}

func (kp *Keypairs) Current() *Keypair {
Expand Down
2 changes: 1 addition & 1 deletion logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger {

logger.Debug = log.New(logDebug,
"DEBUG: "+prepend,
log.Ldate|log.Ltime|log.Lshortfile,
log.Ldate|log.Ltime,
)

logger.Info = log.New(logInfo,
Expand Down
Loading

0 comments on commit 2326d6a

Please sign in to comment.