-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathconfig.go
More file actions
184 lines (158 loc) · 5.18 KB
/
config.go
File metadata and controls
184 lines (158 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
package postgresql
import (
"fmt"
"net"
"net/url"
"regexp"
"sort"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/influxdata/telegraf/config"
)
var (
	// socketRegexp matches the trailing socket file name (e.g.
	// "/.s.PGSQL.5432") of a unix-socket path so that it can be stripped,
	// leaving only the directory.
	socketRegexp = regexp.MustCompile(`/\.s\.PGSQL\.\d+$`)

	// sanitizer matches the key/value settings that must not show up in the
	// output address (password and the ssl* settings), together with the
	// whitespace in front of them. The value may be single-quoted (with
	// backslash escapes) or a bare run of non-whitespace characters.
	//
	// Any amount of whitespace is accepted around "=", because key/value
	// connection strings allow it; matching only a single optional space
	// would let "password  =  secret" pass through unsanitized.
	sanitizer = regexp.MustCompile(
		`(\s|^)((?:password|sslcert|sslkey|sslmode|sslpassword|sslrootcert)\s*=\s*(?:(?:'(?:[^'\\]|\\.)*')|(?:\S+)))`,
	)
)
// Config holds the settings from which CreateService builds a Service.
type Config struct {
// Address is the connection string, stored as a secret. CreateService
// replaces an empty value or the bare "localhost" with
// "host=localhost sslmode=disable".
Address config.Secret `toml:"address"`
// OutputAddress, when non-empty, is used verbatim as the sanitized address
// instead of deriving one from Address.
OutputAddress string `toml:"outputaddress"`
// MaxIdle, MaxOpen and MaxLifetime are copied unchanged into the created
// Service; presumably they tune the connection pool (confirm against the
// Service implementation).
MaxIdle int `toml:"max_idle"`
MaxOpen int `toml:"max_open"`
MaxLifetime config.Duration `toml:"max_lifetime"`
// IsPgBouncer cannot be set from TOML (tag "-"). When true, CreateService
// sets PreferSimpleProtocol on the pgx connection configuration.
IsPgBouncer bool `toml:"-"`
}
// CreateService builds a Service from the configuration.
//
// The address secret is resolved and parsed with pgx; an unset address or the
// bare "localhost" is replaced by a local, non-SSL connection string, which is
// also written back into c.Address. The returned Service carries the
// sanitized address, the database derived from it, the pool settings and the
// DSN under which the pgx configuration was registered with the stdlib
// driver.
//
// An error is returned if the secret cannot be read or updated, if the
// address cannot be parsed, or if sanitizing the address fails.
func (c *Config) CreateService() (*Service, error) {
	const defaultAddr = "host=localhost sslmode=disable"

	secret, err := c.Address.Get()
	if err != nil {
		return nil, fmt.Errorf("getting address failed: %w", err)
	}
	connStr := secret.String()
	defer secret.Destroy()

	// Fall back to the default address and persist it in the secret so that
	// the sanitization below sees the same value.
	if c.Address.Empty() || connStr == "localhost" {
		connStr = defaultAddr
		if err := c.Address.Set([]byte(connStr)); err != nil {
			return nil, err
		}
	}

	parsed, err := pgx.ParseConfig(connStr)
	if err != nil {
		return nil, err
	}

	// Strip the socket file name so that only the socket directory remains
	parsed.Host = socketRegexp.ReplaceAllLiteralString(parsed.Host, "")

	// PgBouncer support: force the simple query protocol on the parsed
	// configuration.
	// See https://github.com/influxdata/telegraf/issues/3253#issuecomment-357505343
	// and https://github.com/influxdata/telegraf/issues/9134
	if c.IsPgBouncer {
		parsed.PreferSimpleProtocol = true
	}

	// Connection string without sensitive information, for use as tag or
	// other output properties
	outputAddr, err := c.sanitizedAddress()
	if err != nil {
		return nil, err
	}

	svc := &Service{
		SanitizedAddress:   outputAddr,
		ConnectionDatabase: connectionDatabase(outputAddr),
		maxIdle:            c.MaxIdle,
		maxOpen:            c.MaxOpen,
		maxLifetime:        time.Duration(c.MaxLifetime),
		dsn:                stdlib.RegisterConnConfig(parsed),
	}
	return svc, nil
}
// connectionDatabase returns the database named in the given (sanitized)
// connection string. If the string cannot be parsed or names no database,
// "postgres" is returned.
func connectionDatabase(sanitizedAddr string) string {
	const fallback = "postgres"

	parsed, err := pgx.ParseConfig(sanitizedAddr)
	if err != nil {
		return fallback
	}
	if name := parsed.Database; name != "" {
		return name
	}
	return fallback
}
// sanitizedAddress returns the connection string with sensitive settings
// removed. A user-provided OutputAddress takes precedence and is returned
// as-is, without looking at the address secret at all.
func (c *Config) sanitizedAddress() (string, error) {
	if override := c.OutputAddress; override != "" {
		return override, nil
	}

	// Resolve the address secret
	secret, err := c.Address.Get()
	if err != nil {
		return "", fmt.Errorf("getting address for sanitization failed: %w", err)
	}
	defer secret.Destroy()

	// URI-formatted addresses are converted to key/value form first, because
	// the sanitizer pattern only understands key=value settings
	connStr := secret.TemporaryString()
	isURI := strings.HasPrefix(connStr, "postgres://") || strings.HasPrefix(connStr, "postgresql://")
	if isURI {
		converted, err := toKeyValue(connStr)
		if err != nil {
			return "", err
		}
		connStr = converted
	}

	// Drop every setting matched by the sanitizer and trim what is left
	return strings.TrimSpace(sanitizer.ReplaceAllString(connStr, "")), nil
}
// toKeyValue converts a URI-style connection string ("postgres://..." or
// "postgresql://...") into the equivalent space-separated key=value form.
// Based on parseURLSettings() at https://github.com/jackc/pgx/blob/master/pgconn/config.go
//
// User, password, host(s), port(s), database and all query parameters are
// emitted as individual settings, sorted alphabetically so that the result is
// reproducible. Values of user, password, dbname and query parameters are
// single-quoted (with "\" and "'" escaped) whenever they are empty or contain
// whitespace, "=", "'" or "\"; an unquoted empty value would make a key/value
// parser swallow the following setting as the value.
//
// An error is returned if the URI cannot be parsed, uses another scheme, or
// contains a malformed host. Parse errors deliberately do not include the raw
// URI, as it may contain credentials.
func toKeyValue(uri string) (string, error) {
	u, err := url.Parse(uri)
	if err != nil {
		// url.Parse reports failures as *url.Error, whose message embeds the
		// complete raw URI including any password. Only pass on the cause.
		if urlErr, ok := err.(*url.Error); ok {
			err = urlErr.Err
		}
		return "", fmt.Errorf("parsing URI failed: %w", err)
	}

	// Check the protocol
	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
	}

	quoteIfNecessary := func(v string) string {
		if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f='\\") {
			return v
		}
		r := strings.ReplaceAll(v, `\`, `\\`)
		r = strings.ReplaceAll(r, `'`, `\'`)
		return "'" + r + "'"
	}

	// Extract the parameters
	query := u.Query()
	parts := make([]string, 0, len(query)+5)
	if u.User != nil {
		parts = append(parts, "user="+quoteIfNecessary(u.User.Username()))
		if password, found := u.User.Password(); found {
			parts = append(parts, "password="+quoteIfNecessary(password))
		}
	}

	// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
	hostParts := strings.Split(u.Host, ",")
	hosts := make([]string, 0, len(hostParts))
	ports := make([]string, 0, len(hostParts))
	var anyPortSet bool
	for _, host := range hostParts {
		if host == "" {
			continue
		}

		h, p, err := net.SplitHostPort(host)
		if err != nil {
			// A host without a port is fine, everything else is an error
			if !strings.Contains(err.Error(), "missing port") {
				return "", fmt.Errorf("failed to process host %q: %w", host, err)
			}
			h = host
		}
		anyPortSet = anyPortSet || err == nil
		hosts = append(hosts, h)
		// Hosts without a port contribute an empty entry to keep the port
		// list aligned with the host list
		ports = append(ports, p)
	}
	if len(hosts) > 0 {
		parts = append(parts, "host="+strings.Join(hosts, ","))
	}
	if anyPortSet {
		parts = append(parts, "port="+strings.Join(ports, ","))
	}

	database := strings.TrimLeft(u.Path, "/")
	if database != "" {
		parts = append(parts, "dbname="+quoteIfNecessary(database))
	}

	for k, v := range query {
		parts = append(parts, k+"="+quoteIfNecessary(strings.Join(v, ",")))
	}

	// Required to produce a repeatable output e.g. for tags or testing
	sort.Strings(parts)

	return strings.Join(parts, " "), nil
}