summaryrefslogtreecommitdiff
blob: ca1f0fdf41c7c79de7f3f7cb47d5158efa995054 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
https://todo.sr.ht/~emersion/soju/207

diff -u b/user.go b/user.go
--- b/user.go
+++ b/user.go
@@ -218,6 +218,7 @@
 	net.user.srv.metrics.upstreams.Add(1)
 	defer net.user.srv.metrics.upstreams.Add(-1)
 
+	done := ctx.Done()
 	ctx, cancel := context.WithTimeout(ctx, time.Minute)
 	defer cancel()
 
@@ -227,6 +228,12 @@
 	}
 	defer uc.Close()
 
+	// The context is cancelled by the caller when the network is stopped.
+	go func() {
+		<-done
+		uc.Close()
+	}()
+
 	if net.user.srv.Identd != nil {
 		net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User))
 		defer net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String())
@@ -239,9 +246,6 @@
 		return fmt.Errorf("failed to register: %w", err)
 	}
 
-	// TODO: this is racy with net.stopped. If the network is stopped
-	// before the user goroutine receives eventUpstreamConnected, the
-	// connection won't be closed.
 	net.user.events <- eventUpstreamConnected{uc}
 	defer func() {
 		net.user.events <- eventUpstreamDisconnected{uc}
@@ -259,6 +263,12 @@
 		return
 	}
 
+	ctx, cancel := context.WithCancel(context.TODO())
+	go func() {
+		<-net.stopped
+		cancel()
+	}()
+
 	var lastTry time.Time
 	backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter)
 	for {
@@ -273,7 +283,7 @@
 		}
 		lastTry = time.Now()
 
-		if err := net.runConn(context.TODO()); err != nil {
+		if err := net.runConn(ctx); err != nil {
 			text := err.Error()
 			temp := true
 			var regErr registrationError
@@ -299,10 +309,6 @@
 	if !net.isStopped() {
 		close(net.stopped)
 	}
-
-	if net.conn != nil {
-		net.conn.Close()
-	}
 }
 
 func (net *network) detach(ch *database.Channel) {