001package io.prometheus.cloudwatch;
002
003import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
004import com.amazonaws.regions.Region;
005import com.amazonaws.regions.RegionUtils;
006import com.amazonaws.services.cloudwatch.AmazonCloudWatchClient;
007import com.amazonaws.services.cloudwatch.model.Datapoint;
008import com.amazonaws.services.cloudwatch.model.Dimension;
009import com.amazonaws.services.cloudwatch.model.DimensionFilter;
010import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsRequest;
011import com.amazonaws.services.cloudwatch.model.GetMetricStatisticsResult;
012import com.amazonaws.services.cloudwatch.model.ListMetricsRequest;
013import com.amazonaws.services.cloudwatch.model.ListMetricsResult;
014import com.amazonaws.services.cloudwatch.model.Metric;
015import io.prometheus.client.Collector;
016import io.prometheus.client.Counter;
017import java.io.Reader;
018import java.io.IOException;
019import java.util.ArrayList;
020import java.util.Arrays;
021import java.util.Date;
022import java.util.List;
023import java.util.Set;
024import java.util.Map;
025import java.util.HashMap;
026import java.util.logging.Level;
027import java.util.logging.Logger;
028import java.util.regex.Pattern;
029
030import org.yaml.snakeyaml.Yaml;
031
032public class CloudWatchCollector extends Collector {
033    private static final Logger LOGGER = Logger.getLogger(CloudWatchCollector.class.getName());
034
035    AmazonCloudWatchClient client;
036
037    Region region;
038
039    static class MetricRule {
040      String awsNamespace;
041      String awsMetricName;
042      int periodSeconds;
043      int rangeSeconds;
044      int delaySeconds;
045      List<String> awsStatistics;
046      List<String> awsExtendedStatistics;
047      List<String> awsDimensions;
048      Map<String,List<String>> awsDimensionSelect;
049      Map<String,List<String>> awsDimensionSelectRegex;
050      String help;
051    }
052
053    private static final Counter cloudwatchRequests = Counter.build()
054      .name("cloudwatch_requests_total").help("API requests made to CloudWatch").register();
055
056    private static final List<String> brokenDynamoMetrics = Arrays.asList(
057            "ConsumedReadCapacityUnits", "ConsumedWriteCapacityUnits",
058            "ProvisionedReadCapacityUnits", "ProvisionedWriteCapacityUnits",
059            "ReadThrottleEvents", "WriteThrottleEvents");
060
061    ArrayList<MetricRule> rules = new ArrayList<MetricRule>();
062
063    public CloudWatchCollector(Reader in) throws IOException {
064        this((Map<String, Object>)new Yaml().load(in),null);
065    }
066    public CloudWatchCollector(String yamlConfig) {
067        this((Map<String, Object>)new Yaml().load(yamlConfig),null);
068    }
069
070    /* For unittests. */
071    protected CloudWatchCollector(String jsonConfig, AmazonCloudWatchClient client) {
072        this((Map<String, Object>)new Yaml().load(jsonConfig), client);
073    }
074
075    private CloudWatchCollector(Map<String, Object> config, AmazonCloudWatchClient client) {
076        if(config == null) {  // Yaml config empty, set config to empty map.
077            config = new HashMap<String, Object>(); 
078        }
079        if (!config.containsKey("region")) {
080          throw new IllegalArgumentException("Must provide region");
081        }
082        region = RegionUtils.getRegion((String) config.get("region"));
083
084        int defaultPeriod = 60;
085        if (config.containsKey("period_seconds")) {
086          defaultPeriod = ((Number)config.get("period_seconds")).intValue();
087        }
088        int defaultRange = 600;
089        if (config.containsKey("range_seconds")) {
090          defaultRange = ((Number)config.get("range_seconds")).intValue();
091        }
092        int defaultDelay = 600;
093        if (config.containsKey("delay_seconds")) {
094          defaultDelay = ((Number)config.get("delay_seconds")).intValue();
095        }
096
097        if (client == null) {
098          if (config.containsKey("role_arn")) {
099            STSAssumeRoleSessionCredentialsProvider credentialsProvider = new STSAssumeRoleSessionCredentialsProvider(
100              (String) config.get("role_arn"),
101              "cloudwatch_exporter"
102            );
103            this.client = new AmazonCloudWatchClient(credentialsProvider);
104          } else {
105            this.client = new AmazonCloudWatchClient();
106          }
107          this.client.setEndpoint(getMonitoringEndpoint());
108        } else {
109          this.client = client;
110        }
111
112        if (!config.containsKey("metrics")) {
113          throw new IllegalArgumentException("Must provide metrics");
114        }
115        for (Object ruleObject : (List<Map<String,Object>>) config.get("metrics")) {
116          Map<String, Object> yamlMetricRule = (Map<String, Object>)ruleObject;
117          MetricRule rule = new MetricRule();
118          rules.add(rule);
119          if (!yamlMetricRule.containsKey("aws_namespace") || !yamlMetricRule.containsKey("aws_metric_name")) {
120            throw new IllegalArgumentException("Must provide aws_namespace and aws_metric_name");
121          }
122          rule.awsNamespace = (String)yamlMetricRule.get("aws_namespace");
123          rule.awsMetricName = (String)yamlMetricRule.get("aws_metric_name");
124          if (yamlMetricRule.containsKey("help")) {
125            rule.help = (String)yamlMetricRule.get("help");
126          }
127          if (yamlMetricRule.containsKey("aws_dimensions")) {
128            rule.awsDimensions = (List<String>)yamlMetricRule.get("aws_dimensions");
129          }
130          if (yamlMetricRule.containsKey("aws_dimension_select") && yamlMetricRule.containsKey("aws_dimension_select_regex")) {
131            throw new IllegalArgumentException("Must not provide aws_dimension_select and aws_dimension_select_regex at the same time");
132          }
133          if (yamlMetricRule.containsKey("aws_dimension_select")) {
134            rule.awsDimensionSelect = (Map<String, List<String>>)yamlMetricRule.get("aws_dimension_select");
135          }
136          if (yamlMetricRule.containsKey("aws_dimension_select_regex")) {
137            rule.awsDimensionSelectRegex = (Map<String,List<String>>)yamlMetricRule.get("aws_dimension_select_regex");
138          }
139          if (yamlMetricRule.containsKey("aws_statistics")) {
140            rule.awsStatistics = (List<String>)yamlMetricRule.get("aws_statistics");
141          } else if (!yamlMetricRule.containsKey("aws_extended_statistics")) {
142            rule.awsStatistics = new ArrayList(Arrays.asList("Sum", "SampleCount", "Minimum", "Maximum", "Average"));
143          }
144          if (yamlMetricRule.containsKey("aws_extended_statistics")) {
145            rule.awsExtendedStatistics = (List<String>)yamlMetricRule.get("aws_extended_statistics");
146          }
147          if (yamlMetricRule.containsKey("period_seconds")) {
148            rule.periodSeconds = ((Number)yamlMetricRule.get("period_seconds")).intValue();
149          } else {
150            rule.periodSeconds = defaultPeriod;
151          }
152          if (yamlMetricRule.containsKey("range_seconds")) {
153            rule.rangeSeconds = ((Number)yamlMetricRule.get("range_seconds")).intValue();
154          } else {
155            rule.rangeSeconds = defaultRange;
156          }
157          if (yamlMetricRule.containsKey("delay_seconds")) {
158            rule.delaySeconds = ((Number)yamlMetricRule.get("delay_seconds")).intValue();
159          } else {
160            rule.delaySeconds = defaultDelay;
161          }
162        }
163    }
164
165    public String getMonitoringEndpoint() {
166      return "https://" + region.getServiceEndpoint("monitoring");
167    }
168
169    private List<List<Dimension>> getDimensions(MetricRule rule) {
170      List<List<Dimension>> dimensions = new ArrayList<List<Dimension>>();
171      if (rule.awsDimensions == null) {
172        dimensions.add(new ArrayList<Dimension>());
173        return dimensions;
174      }
175
176      ListMetricsRequest request = new ListMetricsRequest();
177      request.setNamespace(rule.awsNamespace);
178      request.setMetricName(rule.awsMetricName);
179      List<DimensionFilter> dimensionFilters = new ArrayList<DimensionFilter>();
180      for (String dimension: rule.awsDimensions) {
181        dimensionFilters.add(new DimensionFilter().withName(dimension));
182      }
183      request.setDimensions(dimensionFilters);
184
185      String nextToken = null;
186      do {
187        request.setNextToken(nextToken);
188        ListMetricsResult result = client.listMetrics(request);
189        cloudwatchRequests.inc();
190        for (Metric metric: result.getMetrics()) {
191          if (metric.getDimensions().size() != dimensionFilters.size()) {
192            // AWS returns all the metrics with dimensions beyond the ones we ask for,
193            // so filter them out.
194            continue;
195          }
196          if (useMetric(rule, metric)) {
197            dimensions.add(metric.getDimensions());
198          }
199        }
200        nextToken = result.getNextToken();
201      } while (nextToken != null);
202
203      return dimensions;
204    }
205
206    /**
207     * Check if a metric should be used according to `aws_dimension_select` or `aws_dimension_select_regex`
208     */
209    private boolean useMetric(MetricRule rule, Metric metric) {
210      if (rule.awsDimensionSelect == null && rule.awsDimensionSelectRegex == null) {
211        return true;
212      }
213      if (rule.awsDimensionSelect != null  && metricsIsInAwsDimensionSelect(rule, metric)) {
214        return true;
215      }
216      if (rule.awsDimensionSelectRegex != null  && metricIsInAwsDimensionSelectRegex(rule, metric)) {
217        return true;
218      }
219      return false;
220    }
221
222    /**
223     * Check if a metric is matched in `aws_dimension_select`
224     */
225    private boolean metricsIsInAwsDimensionSelect(MetricRule rule, Metric metric) {
226      Set<String> dimensionSelectKeys = rule.awsDimensionSelect.keySet();
227      for (Dimension dimension : metric.getDimensions()) {
228        String dimensionName = dimension.getName();
229        String dimensionValue = dimension.getValue();
230        if (dimensionSelectKeys.contains(dimensionName)) {
231          List<String> allowedDimensionValues = rule.awsDimensionSelect.get(dimensionName);
232          if (!allowedDimensionValues.contains(dimensionValue)) {
233            return false;
234          }
235        }
236      }
237      return true;
238    }
239
240    /**
241     * Check if a metric is matched in `aws_dimension_select_regex`
242     */
243    private boolean metricIsInAwsDimensionSelectRegex(MetricRule rule, Metric metric) {
244      Set<String> dimensionSelectRegexKeys = rule.awsDimensionSelectRegex.keySet();
245      for (Dimension dimension : metric.getDimensions()) {
246        String dimensionName = dimension.getName();
247        String dimensionValue = dimension.getValue();
248        if (dimensionSelectRegexKeys.contains(dimensionName)) {
249          List<String> allowedDimensionValues = rule.awsDimensionSelectRegex.get(dimensionName);
250          if (!regexListMatch(allowedDimensionValues, dimensionValue)) {
251            return false;
252          }
253        }
254      }
255      return true;
256    }
257
258    /**
259     * Check if any regex string in a list matches a given input value
260     */
261    protected static boolean regexListMatch(List<String> regexList, String input) {
262      for (String regex: regexList) {
263        if (Pattern.matches(regex, input)) {
264          return true;
265        }
266      }
267      return false;
268    }
269
270    private Datapoint getNewestDatapoint(java.util.List<Datapoint> datapoints) {
271      Datapoint newest = null;
272      for (Datapoint d: datapoints) {
273        if (newest == null || newest.getTimestamp().before(d.getTimestamp())) {
274          newest = d;
275        }
276      }
277      return newest;
278    }
279
280    private String toSnakeCase(String str) {
281      return str.replaceAll("([a-z0-9])([A-Z])", "$1_$2").toLowerCase();
282    }
283
284    private String safeName(String s) {
285      // Change invalid chars to underscore, and merge underscores.
286      return s.replaceAll("[^a-zA-Z0-9:_]", "_").replaceAll("__+", "_");
287    }
288
289    private String help(MetricRule rule, String unit, String statistic) {
290      if (rule.help != null) {
291          return rule.help;
292      }
293      return "CloudWatch metric " + rule.awsNamespace + " " + rule.awsMetricName
294          + " Dimensions: " + rule.awsDimensions + " Statistic: " + statistic
295          + " Unit: " + unit;
296    }
297
298    private void scrape(List<MetricFamilySamples> mfs) {
299      long start = System.currentTimeMillis();
300      for (MetricRule rule: rules) {
301        Date startDate = new Date(start - 1000 * rule.delaySeconds);
302        Date endDate = new Date(start - 1000 * (rule.delaySeconds + rule.rangeSeconds));
303        GetMetricStatisticsRequest request = new GetMetricStatisticsRequest();
304        request.setNamespace(rule.awsNamespace);
305        request.setMetricName(rule.awsMetricName);
306        request.setStatistics(rule.awsStatistics);
307        request.setExtendedStatistics(rule.awsExtendedStatistics);
308        request.setEndTime(startDate);
309        request.setStartTime(endDate);
310        request.setPeriod(rule.periodSeconds);
311
312        String baseName = safeName(rule.awsNamespace.toLowerCase() + "_" + toSnakeCase(rule.awsMetricName));
313        String jobName = safeName(rule.awsNamespace.toLowerCase());
314        List<MetricFamilySamples.Sample> sumSamples = new ArrayList<MetricFamilySamples.Sample>();
315        List<MetricFamilySamples.Sample> sampleCountSamples = new ArrayList<MetricFamilySamples.Sample>();
316        List<MetricFamilySamples.Sample> minimumSamples = new ArrayList<MetricFamilySamples.Sample>();
317        List<MetricFamilySamples.Sample> maximumSamples = new ArrayList<MetricFamilySamples.Sample>();
318        List<MetricFamilySamples.Sample> averageSamples = new ArrayList<MetricFamilySamples.Sample>();
319        HashMap<String, ArrayList<MetricFamilySamples.Sample>> extendedSamples = new HashMap<String, ArrayList<MetricFamilySamples.Sample>>();
320
321        String unit = null;
322
323        if (rule.awsNamespace.equals("AWS/DynamoDB")
324                && rule.awsDimensions.contains("GlobalSecondaryIndexName")
325                && brokenDynamoMetrics.contains(rule.awsMetricName)) {
326            baseName += "_index";
327        }
328
329        for (List<Dimension> dimensions: getDimensions(rule)) {
330          request.setDimensions(dimensions);
331
332          GetMetricStatisticsResult result = client.getMetricStatistics(request);
333          cloudwatchRequests.inc();
334          Datapoint dp = getNewestDatapoint(result.getDatapoints());
335          if (dp == null) {
336            continue;
337          }
338          unit = dp.getUnit();
339
340          List<String> labelNames = new ArrayList<String>();
341          List<String> labelValues = new ArrayList<String>();
342          labelNames.add("job");
343          labelValues.add(jobName);
344          labelNames.add("instance");
345          labelValues.add("");
346          for (Dimension d: dimensions) {
347            labelNames.add(safeName(toSnakeCase(d.getName())));
348            labelValues.add(d.getValue());
349          }
350
351          if (dp.getSum() != null) {
352            sumSamples.add(new MetricFamilySamples.Sample(
353                baseName + "_sum", labelNames, labelValues, dp.getSum()));
354          }
355          if (dp.getSampleCount() != null) {
356            sampleCountSamples.add(new MetricFamilySamples.Sample(
357                baseName + "_sample_count", labelNames, labelValues, dp.getSampleCount()));
358          }
359          if (dp.getMinimum() != null) {
360            minimumSamples.add(new MetricFamilySamples.Sample(
361                baseName + "_minimum", labelNames, labelValues, dp.getMinimum()));
362          }
363          if (dp.getMaximum() != null) {
364            maximumSamples.add(new MetricFamilySamples.Sample(
365                baseName + "_maximum",labelNames, labelValues, dp.getMaximum()));
366          }
367          if (dp.getAverage() != null) {
368            averageSamples.add(new MetricFamilySamples.Sample(
369                baseName + "_average", labelNames, labelValues, dp.getAverage()));
370          }
371          if (dp.getExtendedStatistics() != null) {
372            for (Map.Entry<String, Double> entry : dp.getExtendedStatistics().entrySet()) {
373              ArrayList<MetricFamilySamples.Sample> samples = extendedSamples.get(entry.getKey());
374              if (samples == null) {
375                samples = new ArrayList<MetricFamilySamples.Sample>();
376                extendedSamples.put(entry.getKey(), samples);
377              }
378              samples.add(new MetricFamilySamples.Sample(
379                  baseName + "_" + safeName(toSnakeCase(entry.getKey())), labelNames, labelValues, entry.getValue()));
380            }
381          }
382        }
383
384        if (!sumSamples.isEmpty()) {
385          mfs.add(new MetricFamilySamples(baseName + "_sum", Type.GAUGE, help(rule, unit, "Sum"), sumSamples));
386        }
387        if (!sampleCountSamples.isEmpty()) {
388          mfs.add(new MetricFamilySamples(baseName + "_sample_count", Type.GAUGE, help(rule, unit, "SampleCount"), sampleCountSamples));
389        }
390        if (!minimumSamples.isEmpty()) {
391          mfs.add(new MetricFamilySamples(baseName + "_minimum", Type.GAUGE, help(rule, unit, "Minimum"), minimumSamples));
392        }
393        if (!maximumSamples.isEmpty()) {
394          mfs.add(new MetricFamilySamples(baseName + "_maximum", Type.GAUGE, help(rule, unit, "Maximum"), maximumSamples));
395        }
396        if (!averageSamples.isEmpty()) {
397          mfs.add(new MetricFamilySamples(baseName + "_average", Type.GAUGE, help(rule, unit, "Average"), averageSamples));
398        }
399        for (Map.Entry<String, ArrayList<MetricFamilySamples.Sample>> entry : extendedSamples.entrySet()) {
400          mfs.add(new MetricFamilySamples(baseName + "_" + safeName(toSnakeCase(entry.getKey())), Type.GAUGE, help(rule, unit, entry.getKey()), entry.getValue()));
401        }
402      }
403    }
404
405    public List<MetricFamilySamples> collect() {
406      long start = System.nanoTime();
407      double error = 0;
408      List<MetricFamilySamples> mfs = new ArrayList<MetricFamilySamples>(); 
409      try {
410        scrape(mfs);
411      } catch (Exception e) {
412        error = 1;
413        LOGGER.log(Level.WARNING, "CloudWatch scrape failed", e);
414      }
415      List<MetricFamilySamples.Sample> samples = new ArrayList<MetricFamilySamples.Sample>();
416      samples.add(new MetricFamilySamples.Sample(
417          "cloudwatch_exporter_scrape_duration_seconds", new ArrayList<String>(), new ArrayList<String>(), (System.nanoTime() - start) / 1.0E9));
418      mfs.add(new MetricFamilySamples("cloudwatch_exporter_scrape_duration_seconds", Type.GAUGE, "Time this CloudWatch scrape took, in seconds.", samples));
419
420      samples = new ArrayList<MetricFamilySamples.Sample>();
421      samples.add(new MetricFamilySamples.Sample(
422          "cloudwatch_exporter_scrape_error", new ArrayList<String>(), new ArrayList<String>(), error));
423      mfs.add(new MetricFamilySamples("cloudwatch_exporter_scrape_error", Type.GAUGE, "Non-zero if this scrape failed.", samples));
424      return mfs;
425    }
426
427    /**
428     * Convenience function to run standalone.
429     */
430    public static void main(String[] args) throws Exception {
431      String region = "eu-west-1";
432      if (args.length > 0) {
433        region = args[0];
434      }
435      CloudWatchCollector jc = new CloudWatchCollector(("{"
436      + "`region`: `" + region + "`,"
437      + "`metrics`: [{`aws_namespace`: `AWS/ELB`, `aws_metric_name`: `RequestCount`, `aws_dimensions`: [`AvailabilityZone`, `LoadBalancerName`]}] ,"
438      + "}").replace('`', '"'));
439      for(MetricFamilySamples mfs : jc.collect()) {
440        System.out.println(mfs);
441      }
442    }
443}
444