tunnel.go
7.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
package server
import (
"encoding/base64"
"fmt"
"math/rand"
"net"
"ngrok/conn"
"ngrok/log"
"ngrok/msg"
"ngrok/util"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
)
// defaultPortMap maps each virtually-hosted protocol to its well-known port.
// registerVhost strips a trailing ":<port>" from the vhost when it equals the
// protocol's entry here, and rejects protocols that have no entry.
var defaultPortMap = map[string]int{
	"smtp":  25,
	"http":  80,
	"https": 443,
}
// Tunnel ties a public endpoint (a URL and, for "tcp" tunnels, a public TCP
// listener) to the control connection of the client that requested it.
// Public traffic arriving at the endpoint is proxied through that client's
// proxy connections to its firewalled endpoint.
//
// A Tunnel is created by NewTunnel and torn down by Shutdown.
type Tunnel struct {
	// request that opened the tunnel; its Protocol selects the tcp or
	// http/https setup path in NewTunnel
	req *msg.ReqTunnel

	// time when the tunnel was opened
	start time.Time

	// public url; Id returns it, and it is the key the tunnel is registered
	// (and deleted) under in tunnelRegistry
	url string

	// public TCP listener; only bound for "tcp" tunnels, nil otherwise
	listener *net.TCPListener

	// control connection of the client that requested this tunnel
	ctl *Control

	// logger; embedded so the tunnel itself exposes Info/Warn/Error and
	// AddLogPrefix
	log.Logger

	// closing is set to 1 (atomically) by Shutdown before the listener is
	// closed, so listenTcp can tell a shutdown apart from an Accept failure
	closing int32
}
// registerVhost picks and registers a virtual-host URL for tunnel t.
//
// The base host is $VHOST, or "<opts.domain>:<servingPort>" when that is
// unset. It is canonicalized by dropping the protocol's default port suffix
// (looked up in defaultPortMap; an unknown protocol is an error) and by
// lower-casing. The URL is then, in order of preference:
//
//   - "<protocol>://<hostname>" if the request names a hostname,
//   - "<protocol>://<subdomain>.<host>" if it names a subdomain,
//   - "<protocol>://<random hex>.<host>", produced by the generator passed
//     to tunnelRegistry.RegisterRepeat.
//
// Requested hostnames and subdomains are trimmed and lower-cased first.
// t.url is set to the chosen URL and any registration error is returned.
func registerVhost(t *Tunnel, protocol string, servingPort int) (err error) {
	host := os.Getenv("VHOST")
	if len(host) == 0 {
		host = fmt.Sprintf("%s:%d", opts.domain, servingPort)
	}

	port, known := defaultPortMap[protocol]
	if !known {
		return fmt.Errorf("Couldn't find default port for protocol %s", protocol)
	}
	// e.g. "Example.com:80" -> "example.com" for http
	host = strings.ToLower(strings.TrimSuffix(host, fmt.Sprintf(":%d", port)))

	hostname := strings.ToLower(strings.TrimSpace(t.req.Hostname))
	subdomain := strings.ToLower(strings.TrimSpace(t.req.Subdomain))

	switch {
	case hostname != "":
		t.url = fmt.Sprintf("%s://%s", protocol, hostname)
		return tunnelRegistry.Register(t.url, t)
	case subdomain != "":
		t.url = fmt.Sprintf("%s://%s.%s", protocol, subdomain, host)
		return tunnelRegistry.Register(t.url, t)
	}

	// Neither was requested: let the registry retry random subdomains.
	randomURL := func() string {
		return fmt.Sprintf("%s://%x.%s", protocol, rand.Int31(), host)
	}
	t.url, err = tunnelRegistry.RegisterRepeat(randomURL, t)
	return err
}
// NewTunnel creates a tunnel from the ReqTunnel registration message m,
// received on the control connection ctl, and registers its public URL.
//
// Supported protocols:
//
//   - "tcp": a public TCP listener is bound on m.RemotePort when the client
//     asked for one (no fallback). Otherwise it first tries the port in the
//     URL returned by tunnelRegistry.GetCachedRegistration (presumably this
//     client's previous tcp URL), then falls back to an OS-assigned port.
//   - "http", "https": the URL is registered by registerVhost using the port
//     of the shared listeners[proto] listener.
//
// Any other protocol is rejected.
//
// On success every tunnel, tcp included, gets its Id as a log prefix, is
// logged, and is reported to metrics.OpenTunnel, matching the
// metrics.CloseTunnel call in Shutdown. For tcp tunnels the accept loop is
// started last, once the tunnel is fully initialized. On failure err is
// non-nil and t must not be used.
func NewTunnel(m *msg.ReqTunnel, ctl *Control) (t *Tunnel, err error) {
	t = &Tunnel{
		req:    m,
		start:  time.Now(),
		ctl:    ctl,
		Logger: log.NewPrefixLogger(),
	}

	proto := t.req.Protocol
	switch proto {
	case "tcp":
		// bindTcp binds a public listener on port (0 = OS-assigned) and
		// registers its URL. It assigns the named result err, so the outcome
		// of the last attempt is what NewTunnel reports.
		bindTcp := func(port int) error {
			if t.listener, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: port}); err != nil {
				err = t.ctl.conn.Error("Error binding TCP listener: %v", err)
				return err
			}

			// create the url
			addr := t.listener.Addr().(*net.TCPAddr)
			t.url = fmt.Sprintf("tcp://%s:%d", opts.domain, addr.Port)

			// register it
			if err = tunnelRegistry.RegisterAndCache(t.url, t); err != nil {
				// This should never be possible because the OS will
				// only assign available ports to us.
				t.listener.Close()
				t.listener = nil // don't keep a dead listener around
				err = fmt.Errorf("TCP listener bound, but failed to register %s", t.url)
				return err
			}
			return nil
		}

		if t.req.RemotePort != 0 {
			// use the custom remote port you asked for
			bindTcp(int(t.req.RemotePort))
		} else {
			// try to return to you the same port you had before
			bound := false
			if cachedUrl := tunnelRegistry.GetCachedRegistration(t); cachedUrl != "" {
				parts := strings.Split(cachedUrl, ":")
				portPart := parts[len(parts)-1]
				if port, perr := strconv.Atoi(portPart); perr != nil {
					t.ctl.conn.Error("Failed to parse cached url port as integer: %s", portPart)
				} else if bindTcp(port) != nil {
					t.ctl.conn.Warn("Failed to get custom port %d: %v, trying a random one", port, err)
				} else {
					bound = true
				}
			}

			// Bind for TCP connections on any free port
			if !bound {
				bindTcp(0)
			}
		}

		if err != nil {
			return
		}
		// Fall through to the common registration tail below. The original
		// code returned here, so tcp tunnels skipped logging and metrics.

	case "http", "https":
		l, ok := listeners[proto]
		if !ok {
			err = fmt.Errorf("Not listening for %s connections", proto)
			return
		}

		if err = registerVhost(t, proto, l.Addr.(*net.TCPAddr).Port); err != nil {
			return
		}

	default:
		err = fmt.Errorf("Protocol %s is not supported", proto)
		return
	}

	// pre-encode the http basic auth for fast comparisons later
	if m.HttpAuth != "" {
		m.HttpAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(m.HttpAuth))
	}

	t.AddLogPrefix(t.Id())
	t.Info("Registered new tunnel on: %s", t.ctl.conn.Id())

	metrics.OpenTunnel(t)

	// Start accepting public tcp connections only now that the tunnel is
	// fully set up, so the accept goroutine never races with the log-prefix
	// mutation above. Pending connections wait in the listen backlog.
	if t.listener != nil {
		go t.listenTcp(t.listener)
	}
	return
}
// Shutdown tears the tunnel down. It marks the tunnel as closing, closes the
// public TCP listener if there is one, removes the tunnel's URL from the
// registry and reports the closure to metrics.
//
// The closing flag is set before the listener is closed so that listenTcp
// treats the resulting Accept error as a normal shutdown, not a failure.
// The control connection is deliberately not notified: currently only the
// control connection itself shuts tunnels down, so it already knows.
func (t *Tunnel) Shutdown() {
	t.Info("Shutting down")

	atomic.StoreInt32(&t.closing, 1)

	// only raw TCP tunnels own a public listener
	if l := t.listener; l != nil {
		l.Close()
	}

	tunnelRegistry.Del(t.url)
	metrics.CloseTunnel(t)
}
// Id returns the tunnel's public URL. It serves as the tunnel's identifier:
// it is the tunnelRegistry key and the tunnel's log prefix.
func (t *Tunnel) Id() (id string) {
	id = t.url
	return id
}
// listenTcp runs the accept loop for a "tcp" tunnel's public listener.
// Each accepted connection is wrapped, given the tunnel's Id as a log prefix
// and handed to HandlePublicConnection on its own goroutine.
//
// The loop ends when Accept fails after Shutdown has set t.closing. Other
// Accept errors are logged and retried after a capped, doubling backoff, so
// a persistent error (e.g. running out of file descriptors) does not spin
// the CPU. A panic in the loop is recovered and logged, which also ends it.
func (t *Tunnel) listenTcp(listener *net.TCPListener) {
	// Registered once, outside the loop. A defer inside the for loop only
	// runs when the function returns, so it would accumulate one closure per
	// accepted connection for the lifetime of the tunnel.
	defer func() {
		if r := recover(); r != nil {
			log.Warn("listenTcp failed with error %v", r)
		}
	}()

	const maxAcceptBackoff = time.Second
	var backoff time.Duration

	for {
		// accept public connections
		tcpConn, err := listener.AcceptTCP()
		if err != nil {
			// not an error, we're shutting down this tunnel
			if atomic.LoadInt32(&t.closing) == 1 {
				return
			}

			t.Error("Failed to accept new TCP connection: %v", err)

			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		// named publicConn so it does not shadow the conn package
		publicConn := conn.Wrap(tcpConn, "pub")
		publicConn.AddLogPrefix(t.Id())
		publicConn.Info("New connection from %v", publicConn.RemoteAddr())

		go t.HandlePublicConnection(publicConn)
	}
}
// HandlePublicConnection proxies one public connection through the client.
//
// Steps:
//
//  1. Take proxy connections from the control connection via t.ctl.GetProxy,
//     making up to 2*proxyMaxPoolSize attempts, until a StartProxy message is
//     written successfully.
//  2. Ask the client for a replacement proxy connection.
//  3. Join the public and proxy connections until the join returns.
//  4. Report the byte counts to metrics and add them to the user's transfer
//     counters.
//
// Proxy connections that fail step 1 are closed immediately; the one
// actually used, and publicConn, are closed on return. Panics are recovered
// and logged.
func (t *Tunnel) HandlePublicConnection(publicConn conn.Conn) {
	defer publicConn.Close()
	defer func() {
		if r := recover(); r != nil {
			publicConn.Warn("HandlePublicConnection failed with error %v", r)
		}
	}()

	startTime := time.Now()
	metrics.OpenConnection(t, publicConn)

	var proxyConn conn.Conn
	var err error
	for i := 0; i < (2 * proxyMaxPoolSize); i++ {
		// get a proxy connection
		if proxyConn, err = t.ctl.GetProxy(); err != nil {
			t.Warn("Failed to get proxy connection: %v", err)
			return
		}
		t.Info("Got proxy connection %s", proxyConn.Id())
		proxyConn.AddLogPrefix(t.Id())

		// tell the client we're going to start using this proxy connection
		startPxyMsg := &msg.StartProxy{
			Url:        t.url,
			ClientAddr: publicConn.RemoteAddr().String(),
		}
		if err = msg.WriteMsg(proxyConn, startPxyMsg); err == nil {
			break // success
		}

		// Unusable: close it exactly once, now, and try the next one.
		proxyConn.Warn("Failed to write StartProxyMessage: %v, attempt %d", err, i)
		proxyConn.Close()
		proxyConn = nil
	}

	// proxyConn is nil if every attempt failed or no attempt was made
	if err != nil || proxyConn == nil {
		// give up
		publicConn.Error("Too many failures starting proxy connection")
		return
	}
	defer proxyConn.Close()

	// To reduce latency handling tunnel connections, we employ the following crude heuristic:
	// Whenever we take a proxy connection from the pool, replace it with a new one
	util.PanicToError(func() { t.ctl.out <- &msg.ReqProxy{} })

	// no timeouts while connections are joined
	proxyConn.SetDeadline(time.Time{})

	// join the public and proxy connections
	bytesIn, bytesOut := conn.Join(publicConn, proxyConn)
	metrics.CloseConnection(t, publicConn, startTime, bytesIn, bytesOut)

	// The per-user transfer counters are int32. Converting a total above
	// 2^31-1 straight to int32 wraps it (possibly to a negative value, which
	// would *decrease* the user's recorded usage), so saturate instead.
	// NOTE(review): the counters themselves should be widened to int64.
	const maxInt32 = 1<<31 - 1
	transferred := bytesIn + bytesOut
	if transferred > maxInt32 {
		transferred = maxInt32
	}
	atomic.AddInt32(&t.ctl.userInfo.TransPerDay, int32(transferred))
	atomic.AddInt32(&t.ctl.userInfo.TransAll, int32(transferred))
}