diff options
Diffstat (limited to 'golib/httpclient.go')
-rw-r--r-- | golib/httpclient.go | 257 |
1 files changed, 257 insertions, 0 deletions
diff --git a/golib/httpclient.go b/golib/httpclient.go new file mode 100644 index 0000000..72132bf --- /dev/null +++ b/golib/httpclient.go @@ -0,0 +1,257 @@ +package common + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "strings" + + "github.com/Sirupsen/logrus" +) + +type HTTPClient struct { + httpClient http.Client + endpoint string + apikey string + username string + password string + id string + csrf string + conf HTTPClientConfig + logger *logrus.Logger +} + +type HTTPClientConfig struct { + URLPrefix string + HeaderAPIKeyName string + Apikey string + HeaderClientKeyName string + CsrfDisable bool +} + +const ( + logError = 1 + logWarning = 2 + logInfo = 3 + logDebug = 4 +) + +// Inspired by syncthing/cmd/cli + +const insecure = false + +// HTTPNewClient creates a new HTTP client to deal with Syncthing +func HTTPNewClient(baseURL string, cfg HTTPClientConfig) (*HTTPClient, error) { + + // Create w new Http client + httpClient := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecure, + }, + }, + } + client := HTTPClient{ + httpClient: httpClient, + endpoint: baseURL, + apikey: cfg.Apikey, + conf: cfg, + /* TODO - add user + pwd support + username: c.GlobalString("username"), + password: c.GlobalString("password"), + */ + } + + if client.apikey == "" { + if err := client.getCidAndCsrf(); err != nil { + return nil, err + } + } + return &client, nil +} + +// SetLogger Define the logger to use +func (c *HTTPClient) SetLogger(log *logrus.Logger) { + c.logger = log +} + +func (c *HTTPClient) log(level int, format string, args ...interface{}) { + if c.logger != nil { + switch level { + case logError: + c.logger.Errorf(format, args...) + break + case logWarning: + c.logger.Warningf(format, args...) + break + case logInfo: + c.logger.Infof(format, args...) + break + default: + c.logger.Debugf(format, args...) + break + } + } +} + +// Send request to retrieve Client id and/or CSRF token +func (c *HTTPClient) getCidAndCsrf() error { + request, err := http.NewRequest("GET", c.endpoint, nil) + if err != nil { + return err + } + if _, err := c.handleRequest(request); err != nil { + return err + } + if c.id == "" { + return errors.New("Failed to get device ID") + } + if !c.conf.CsrfDisable && c.csrf == "" { + return errors.New("Failed to get CSRF token") + } + return nil +} + +// GetClientID returns the id +func (c *HTTPClient) GetClientID() string { + return c.id +} + +// formatURL Build full url by concatenating all parts +func (c *HTTPClient) formatURL(endURL string) string { + url := c.endpoint + if !strings.HasSuffix(url, "/") { + url += "/" + } + url += strings.TrimLeft(c.conf.URLPrefix, "/") + if !strings.HasSuffix(url, "/") { + url += "/" + } + return url + strings.TrimLeft(endURL, "/") +} + +// HTTPGet Send a Get request to client and return an error object +func (c *HTTPClient) HTTPGet(url string, data *[]byte) error { + _, err := c.HTTPGetWithRes(url, data) + return err +} + +// HTTPGetWithRes Send a Get request to client and return both response and error +func (c *HTTPClient) HTTPGetWithRes(url string, data *[]byte) (*http.Response, error) { + request, err := http.NewRequest("GET", c.formatURL(url), nil) + if err != nil { + return nil, err + } + res, err := c.handleRequest(request) + if err != nil { + return res, err + } + if res.StatusCode != 200 { + return res, errors.New(res.Status) + } + + *data = c.responseToBArray(res) + + return res, nil +} + +// HTTPPost Send a POST request to client and return an error object +func (c *HTTPClient) HTTPPost(url string, body string) error { + _, err := c.HTTPPostWithRes(url, body) + return err +} + +// HTTPPostWithRes Send a POST request to client and return both response and error +func (c *HTTPClient) HTTPPostWithRes(url string, body string) (*http.Response, error) { + request, err := http.NewRequest("POST", c.formatURL(url), bytes.NewBufferString(body)) + if err != nil { + return nil, err + } + res, err := c.handleRequest(request) + if err != nil { + return res, err + } + if res.StatusCode != 200 { + return res, errors.New(res.Status) + } + return res, nil +} + +func (c *HTTPClient) responseToBArray(response *http.Response) []byte { + defer response.Body.Close() + bytes, err := ioutil.ReadAll(response.Body) + if err != nil { + // TODO improved error reporting + fmt.Println("ERROR: " + err.Error()) + } + return bytes +} + +func (c *HTTPClient) handleRequest(request *http.Request) (*http.Response, error) { + if c.conf.HeaderAPIKeyName != "" && c.apikey != "" { + request.Header.Set(c.conf.HeaderAPIKeyName, c.apikey) + } + if c.conf.HeaderClientKeyName != "" && c.id != "" { + request.Header.Set(c.conf.HeaderClientKeyName, c.id) + } + if c.username != "" || c.password != "" { + request.SetBasicAuth(c.username, c.password) + } + if c.csrf != "" { + request.Header.Set("X-CSRF-Token-"+c.id[:5], c.csrf) + } + + c.log(logDebug, "HTTP %s %v", request.Method, request.URL) + + response, err := c.httpClient.Do(request) + if err != nil { + return nil, err + } + + // Detect client ID change + cid := response.Header.Get(c.conf.HeaderClientKeyName) + if cid != "" && c.id != cid { + c.id = cid + } + + // Detect CSR token change + for _, item := range response.Cookies() { + if item.Name == "CSRF-Token-"+c.id[:5] { + c.csrf = item.Value + goto csrffound + } + } + // OK CSRF found +csrffound: + + if response.StatusCode == 404 { + return nil, errors.New("Invalid endpoint or API call") + } else if response.StatusCode == 401 { + return nil, errors.New("Invalid username or password") + } else if response.StatusCode == 403 { + if c.apikey == "" { + // Request a new Csrf for next requests + c.getCidAndCsrf() + return nil, errors.New("Invalid CSRF token") + } + return nil, errors.New("Invalid API key") + } else if response.StatusCode != 200 { + data := make(map[string]interface{}) + // Try to decode error field of APIError struct + json.Unmarshal(c.responseToBArray(response), &data) + if err, found := data["error"]; found { + return nil, fmt.Errorf(err.(string)) + } else { + body := strings.TrimSpace(string(c.responseToBArray(response))) + if body != "" { + return nil, fmt.Errorf(body) + } + } + return nil, errors.New("Unknown HTTP status returned: " + response.Status) + } + return response, nil +} |