Skip to content

Commit 4fadaee

Browse files
committed
Refactor to use Endpoint.Equal
Compare IP first by default and compare DNS name first when we know the Endpoint was resolved.
1 parent babace5 commit 4fadaee

3 files changed

Lines changed: 63 additions & 46 deletions

File tree

pkg/mesh/mesh.go

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -679,22 +679,11 @@ func nodesAreEqual(a, b *Node) bool {
679679
if a == b {
680680
return true
681681
}
682-
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
682+
// Check the DNS name first since this package
683+
// is doing the DNS resolution.
684+
if !a.Endpoint.Equal(b.Endpoint, true) {
683685
return false
684686
}
685-
if a.Endpoint != nil {
686-
if a.Endpoint.Port != b.Endpoint.Port {
687-
return false
688-
}
689-
// Check the DNS name first since this package
690-
// is doing the DNS resolution.
691-
if a.Endpoint.DNS != b.Endpoint.DNS {
692-
return false
693-
}
694-
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
695-
return false
696-
}
697-
}
698687
// Ignore LastSeen when comparing equality we want to check if the nodes are
699688
// equivalent. However, we do want to check if LastSeen has transitioned
700689
// between valid and invalid.
@@ -708,22 +697,11 @@ func peersAreEqual(a, b *Peer) bool {
708697
if a == b {
709698
return true
710699
}
711-
if !(a.Endpoint != nil) == (b.Endpoint != nil) {
700+
// Check the DNS name first since this package
701+
// is doing the DNS resolution.
702+
if !a.Endpoint.Equal(b.Endpoint, true) {
712703
return false
713704
}
714-
if a.Endpoint != nil {
715-
if a.Endpoint.Port != b.Endpoint.Port {
716-
return false
717-
}
718-
// Check the DNS name first since this package
719-
// is doing the DNS resolution.
720-
if a.Endpoint.DNS != b.Endpoint.DNS {
721-
return false
722-
}
723-
if a.Endpoint.DNS == "" && !a.Endpoint.IP.Equal(b.Endpoint.IP) {
724-
return false
725-
}
726-
}
727705
if len(a.AllowedIPs) != len(b.AllowedIPs) {
728706
return false
729707
}
@@ -778,7 +756,7 @@ func discoveredEndpointsAreEqual(a, b map[string]*wireguard.Endpoint) bool {
778756
return false
779757
}
780758
for k := range a {
781-
if !a[k].Equal(b[k]) {
759+
if !a[k].Equal(b[k], false) {
782760
return false
783761
}
784762
}
@@ -802,17 +780,17 @@ func discoverNATEndpoints(nodes map[string]*Node, peers map[string]*Peer, conf *
802780
}
803781
for _, n := range nodes {
804782
if peer, ok := keys[string(n.Key)]; ok && n.PersistentKeepalive > 0 {
805-
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint))
783+
level.Debug(logger).Log("msg", "WireGuard Update NAT Endpoint", "node", n.Name, "endpoint", peer.Endpoint, "former-endpoint", n.Endpoint, "same", n.Endpoint.Equal(peer.Endpoint, false))
806784
// Should check location leader but only available in topology ... or have topology handle that list
807785
// Better check wg latest-handshake
808-
if !n.Endpoint.Equal(peer.Endpoint) {
786+
if !n.Endpoint.Equal(peer.Endpoint, false) {
809787
natEndpoints[string(n.Key)] = peer.Endpoint
810788
}
811789
}
812790
}
813791
for _, p := range peers {
814792
if peer, ok := keys[string(p.PublicKey)]; ok && p.PersistentKeepalive > 0 {
815-
if !p.Endpoint.Equal(peer.Endpoint) {
793+
if !p.Endpoint.Equal(peer.Endpoint, false) {
816794
natEndpoints[string(p.PublicKey)] = peer.Endpoint
817795
}
818796
}

pkg/wireguard/conf.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,31 @@ func (e *Endpoint) String() string {
9696
}
9797

9898
// Equal compares two endpoints.
99-
func (e *Endpoint) Equal(b *Endpoint) bool {
99+
func (e *Endpoint) Equal(b *Endpoint, DNSFirst bool) bool {
100100
if (e == nil) != (b == nil) {
101101
return false
102102
}
103103
if e != nil {
104104
if e.Port != b.Port {
105105
return false
106106
}
107-
// IPs take priority, so check them first.
108-
if !e.IP.Equal(b.IP) {
109-
return false
110-
}
111-
// Only check the DNS name if the IP is empty.
112-
if e.IP == nil && e.DNS != b.DNS {
113-
return false
107+
if DNSFirst {
108+
// Check the DNS name first if it was resolved.
109+
if e.DNS != b.DNS {
110+
return false
111+
}
112+
if e.DNS == "" && !e.IP.Equal(b.IP) {
113+
return false
114+
}
115+
} else {
116+
// IPs take priority, so check them first.
117+
if !e.IP.Equal(b.IP) {
118+
return false
119+
}
120+
// Only check the DNS name if the IP is empty.
121+
if e.IP == nil && e.DNS != b.DNS {
122+
return false
123+
}
114124
}
115125
}
116126

@@ -331,7 +341,7 @@ func (c *Conf) Equal(b *Conf) bool {
331341
return false
332342
}
333343
}
334-
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint) {
344+
if !c.Peers[i].Endpoint.Equal(b.Peers[i].Endpoint, false) {
335345
return false
336346
}
337347
if c.Peers[i].PersistentKeepalive != b.Peers[i].PersistentKeepalive || !bytes.Equal(c.Peers[i].PresharedKey, b.Peers[i].PresharedKey) || !bytes.Equal(c.Peers[i].PublicKey, b.Peers[i].PublicKey) {

pkg/wireguard/conf_test.go

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,11 @@ func TestCompareConf(t *testing.T) {
207207

208208
func TestCompareEndpoint(t *testing.T) {
209209
for _, tc := range []struct {
210-
name string
211-
a *Endpoint
212-
b *Endpoint
213-
out bool
210+
name string
211+
a *Endpoint
212+
b *Endpoint
213+
dnsFirst bool
214+
out bool
214215
}{
215216
{
216217
name: "both nil",
@@ -272,8 +273,36 @@ func TestCompareEndpoint(t *testing.T) {
272273
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
273274
out: true,
274275
},
276+
{
277+
name: "DNS first, ignore IP",
278+
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: "a"}},
279+
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: "a"}},
280+
dnsFirst: true,
281+
out: true,
282+
},
283+
{
284+
name: "DNS first",
285+
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "a"}},
286+
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{DNS: "b"}},
287+
dnsFirst: true,
288+
out: false,
289+
},
290+
{
291+
name: "DNS first, no DNS compare IP",
292+
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
293+
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.2"), DNS: ""}},
294+
dnsFirst: true,
295+
out: false,
296+
},
297+
{
298+
name: "DNS first, no DNS compare IP (same)",
299+
a: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
300+
b: &Endpoint{Port: 1234, DNSOrIP: DNSOrIP{IP: net.ParseIP("192.168.0.1"), DNS: ""}},
301+
dnsFirst: true,
302+
out: true,
303+
},
275304
} {
276-
equal := tc.a.Equal(tc.b)
305+
equal := tc.a.Equal(tc.b, tc.dnsFirst)
277306
if equal != tc.out {
278307
t.Errorf("test case %q: expected %t, got %t", tc.name, tc.out, equal)
279308
}

0 commit comments

Comments
 (0)