// Package cloudflare provides a minimal Cloudflare API client for removing // geo-restriction WAF rules created by felhom-controller. package cloudflare import ( "encoding/json" "fmt" "io" "log" "net/http" "strings" "time" ) const apiBase = "https://api.cloudflare.com/client/v4" // geoRulePrefix is the description prefix used by felhom-controller for geo-blocking rules. const geoRulePrefix = "[felhom-geo]" // RemoveGeoRules deletes all WAF custom rules with a [felhom-geo] description prefix // from the Cloudflare zone associated with the given domain. func RemoveGeoRules(apiToken, domain string, logger *log.Logger) error { if apiToken == "" { return fmt.Errorf("no Cloudflare API token provided") } if domain == "" { return fmt.Errorf("no domain provided") } // 1. Resolve zone ID zoneID, err := resolveZone(apiToken, domain) if err != nil { return fmt.Errorf("resolve zone for %s: %w", domain, err) } if logger != nil { logger.Printf("[INFO] cloudflare.RemoveGeoRules: zone=%s for domain=%s", zoneID, domain) } // 2. Find the http_request_firewall_custom phase ruleset rulesetID, err := findFirewallRuleset(apiToken, zoneID) if err != nil { return fmt.Errorf("find firewall ruleset: %w", err) } if rulesetID == "" { if logger != nil { logger.Printf("[INFO] cloudflare.RemoveGeoRules: no firewall ruleset found — nothing to remove") } return nil } // 3. List rules and filter by [felhom-geo] prefix rules, err := listRules(apiToken, zoneID, rulesetID) if err != nil { return fmt.Errorf("list rules: %w", err) } var geoRuleIDs []string for _, r := range rules { if strings.HasPrefix(r.Description, geoRulePrefix) { geoRuleIDs = append(geoRuleIDs, r.ID) } } if len(geoRuleIDs) == 0 { if logger != nil { logger.Printf("[INFO] cloudflare.RemoveGeoRules: no [felhom-geo] rules found") } return nil } // 4. Delete each matching rule var errors []string for _, ruleID := range geoRuleIDs { if err := deleteRule(apiToken, zoneID, rulesetID, ruleID); err != nil { errors = append(errors, fmt.Sprintf("delete rule %s: %v", ruleID, err)) } } if len(errors) > 0 { return fmt.Errorf("deleted %d/%d rules; errors: %s", len(geoRuleIDs)-len(errors), len(geoRuleIDs), strings.Join(errors, "; ")) } if logger != nil { logger.Printf("[INFO] cloudflare.RemoveGeoRules: deleted %d [felhom-geo] rule(s)", len(geoRuleIDs)) } return nil } // --- Cloudflare API helpers --- type cfResponse struct { Success bool `json:"success"` Result json.RawMessage `json:"result"` Errors []struct { Message string `json:"message"` } `json:"errors"` } type zone struct { ID string `json:"id"` Name string `json:"name"` } type ruleset struct { ID string `json:"id"` Phase string `json:"phase"` } type rule struct { ID string `json:"id"` Description string `json:"description"` } func resolveZone(apiToken, domain string) (string, error) { // Try exact domain first, then parent domain for _, name := range []string{domain, parentDomain(domain)} { if name == "" { continue } resp, err := cfDo(apiToken, "GET", fmt.Sprintf("/zones?name=%s&status=active", name), nil) if err != nil { return "", err } var zones []zone if err := json.Unmarshal(resp.Result, &zones); err != nil { return "", fmt.Errorf("parse zones: %w", err) } if len(zones) > 0 { return zones[0].ID, nil } } return "", fmt.Errorf("no active zone found for %s", domain) } func parentDomain(domain string) string { parts := strings.SplitN(domain, ".", 2) if len(parts) < 2 { return "" } return parts[1] } func findFirewallRuleset(apiToken, zoneID string) (string, error) { resp, err := cfDo(apiToken, "GET", fmt.Sprintf("/zones/%s/rulesets", zoneID), nil) if err != nil { return "", err } var rulesets []ruleset if err := json.Unmarshal(resp.Result, &rulesets); err != nil { return "", fmt.Errorf("parse rulesets: %w", err) } for _, rs := range rulesets { if rs.Phase == "http_request_firewall_custom" { return rs.ID, nil } } return "", nil } func listRules(apiToken, zoneID, rulesetID string) ([]rule, error) { resp, err := cfDo(apiToken, "GET", fmt.Sprintf("/zones/%s/rulesets/%s", zoneID, rulesetID), nil) if err != nil { return nil, err } var rs struct { Rules []rule `json:"rules"` } if err := json.Unmarshal(resp.Result, &rs); err != nil { return nil, fmt.Errorf("parse ruleset detail: %w", err) } return rs.Rules, nil } func deleteRule(apiToken, zoneID, rulesetID, ruleID string) error { _, err := cfDo(apiToken, "DELETE", fmt.Sprintf("/zones/%s/rulesets/%s/rules/%s", zoneID, rulesetID, ruleID), nil) return err } func cfDo(apiToken, method, path string, body io.Reader) (*cfResponse, error) { url := apiBase + path req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+apiToken) req.Header.Set("Content-Type", "application/json") client := &http.Client{Timeout: 15 * time.Second} resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("request %s %s: %w", method, path, err) } defer resp.Body.Close() data, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("read response: %w", err) } var cfResp cfResponse if err := json.Unmarshal(data, &cfResp); err != nil { return nil, fmt.Errorf("parse CF response (status %d): %s", resp.StatusCode, string(data)) } if !cfResp.Success { msgs := make([]string, len(cfResp.Errors)) for i, e := range cfResp.Errors { msgs[i] = e.Message } return nil, fmt.Errorf("CF API error: %s", strings.Join(msgs, "; ")) } return &cfResp, nil }