summaryrefslogtreecommitdiff
blob: 5830036da62ecee9ed636bb0d965b035eb91535f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
https://github.com/imanel/websocket-ruby/commit/736a7515aff808a5c268b87066e928b59aed769e

From 736a7515aff808a5c268b87066e928b59aed769e Mon Sep 17 00:00:00 2001
From: Bernard Potocki <bernard.potocki@imanel.org>
Date: Thu, 17 Feb 2022 20:02:21 +0100
Subject: [PATCH] Ensure correct port is always specified (#48)

--- a/lib/websocket/handshake/base.rb
+++ b/lib/websocket/handshake/base.rb
@@ -7,7 +7,7 @@ class Base
       include ExceptionHandler
       include NiceInspect
 
-      attr_reader :host, :port, :path, :query,
+      attr_reader :host, :path, :query,
                   :state, :version, :secure,
                   :headers, :protocols
 
@@ -66,6 +66,20 @@ def leftovers
         (@leftovers.to_s.split("\n", reserved_leftover_lines + 1)[reserved_leftover_lines] || '').strip
       end
 
+      # Return default port for protocol (80 for ws, 443 for wss)
+      def default_port
+        secure ? 443 : 80
+      end
+
+      # Check if provided port is a default one
+      def default_port?
+        port == default_port
+      end
+
+      def port
+        @port || default_port
+      end
+
       # URI of request.
       # @return [String] Full URI with protocol
       # @example
@@ -73,7 +87,7 @@ def leftovers
       def uri
         uri =  String.new(secure ? 'wss://' : 'ws://')
         uri << host
-        uri << ":#{port}" if port
+        uri << ":#{port}" unless default_port?
         uri << path
         uri << "?#{query}" if query
         uri
--- a/lib/websocket/handshake/client.rb
+++ b/lib/websocket/handshake/client.rb
@@ -61,7 +61,7 @@ def initialize(args = {})
           uri = URI.parse(@url || @uri)
           @secure ||= (uri.scheme == 'wss')
           @host ||= uri.host
-          @port ||= uri.port
+          @port ||= uri.port || default_port
           @path ||= uri.path
           @query ||= uri.query
         end
--- a/lib/websocket/handshake/handler/client04.rb
+++ b/lib/websocket/handshake/handler/client04.rb
@@ -21,7 +21,7 @@ def handshake_keys
             %w[Connection Upgrade]
           ]
           host = @handshake.host
-          host += ":#{@handshake.port}" if @handshake.port
+          host += ":#{@handshake.port}" unless @handshake.default_port?
           keys << ['Host', host]
           keys += super
           keys << ['Sec-WebSocket-Origin', @handshake.origin] if @handshake.origin
--- a/lib/websocket/handshake/handler/client75.rb
+++ b/lib/websocket/handshake/handler/client75.rb
@@ -18,7 +18,7 @@ def handshake_keys
             %w[Connection Upgrade]
           ]
           host = @handshake.host
-          host += ":#{@handshake.port}" if @handshake.port
+          host += ":#{@handshake.port}" unless @handshake.default_port?
           keys << ['Host', host]
           keys << ['Origin', @handshake.origin] if @handshake.origin
           keys << ['WebSocket-Protocol', @handshake.protocols.first] if @handshake.protocols.any?
--- a/lib/websocket/handshake/server.rb
+++ b/lib/websocket/handshake/server.rb
@@ -129,13 +129,13 @@ def should_respond?
       # Host of server according to client header
       # @return [String] host
       def host
-        @headers['host'].to_s.split(':')[0].to_s
+        @host || @headers['host'].to_s.split(':')[0].to_s
       end
 
       # Port of server according to client header
-      # @return [String] port
+      # @return [Integer] port
       def port
-        @headers['host'].to_s.split(':')[1]
+        (@port || @headers['host'].to_s.split(':')[1] || default_port).to_i
       end
 
       private
--- a/spec/support/all_client_drafts.rb
+++ b/spec/support/all_client_drafts.rb
@@ -38,6 +38,10 @@ def validate_request
     expect(handshake.query).to eql('aaa=bbb')
   end
 
+  it 'returns default port' do
+    expect(handshake.port).to be(80)
+  end
+
   it 'returns valid port' do
     @request_params = { port: 123 }
     expect(handshake.port).to be(123)
--- a/spec/support/all_server_drafts.rb
+++ b/spec/support/all_server_drafts.rb
@@ -47,11 +47,17 @@ def validate_request
     expect(handshake.query).to eql('aaa=bbb')
   end
 
+  it 'returns default port' do
+    handshake << client_request
+
+    expect(handshake.port).to be(80)
+  end
+
   it 'returns valid port' do
     @request_params = { port: 123 }
     handshake << client_request
 
-    expect(handshake.port).to eql('123')
+    expect(handshake.port).to be(123)
   end
 
   it 'returns valid response' do