summaryrefslogtreecommitdiff
blob: 62c0307100e4c0eac2381dcaa46d9ac16499f9e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
From 28b504a3fe7b9723a87841a4e2da9bf86828df83 Mon Sep 17 00:00:00 2001
From: "Chen, Chih-Chia" <pigfoot@gmail.com>
Date: Thu, 17 Mar 2022 14:32:15 +0800
Subject: [PATCH] support feature workgroup

Signed-off-by: Chen, Chih-Chia <pigfoot@gmail.com>
---
 cmd/query.go | 1 +
 lib/query.go | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/cmd/query.go b/cmd/query.go
index aa9018e..7a219b7 100644
--- a/cmd/query.go
+++ b/cmd/query.go
@@ -55,6 +55,7 @@ func init() {
 	QueryCmd.PersistentFlags().StringVarP(&q.SQL, "sql", "s", "", "SQL query to execute. Can be a file or raw query")
 	QueryCmd.PersistentFlags().StringVarP(&q.Format, "format", "f", "csv", "format the output as either json, csv, or table")
 	QueryCmd.PersistentFlags().StringVarP(&q.OutputFile, "output", "o", "", "(optional) file name to write this content to (defaults to standard output)")
+	QueryCmd.PersistentFlags().StringVarP(&q.WorkGroup, "workgroup", "w", "", "(optional) WorkGroup (defaults to primary)")
 	QueryCmd.PersistentFlags().BoolVar(&q.Statistics, "statistics", false, "print query statistics to stderr")
 	// QueryCmd.PersistentFlags().StringVar(&q.JMESPath, "jmespath", "", "optional JMESPath to further filter or format results. See jmespath.org for more.")
 }
diff --git a/lib/query.go b/lib/query.go
index ef93f2f..004d5e5 100644
--- a/lib/query.go
+++ b/lib/query.go
@@ -30,6 +30,7 @@ type Query struct {
 	Format             string
 	JMESPath           string
 	Statistics         bool
+	WorkGroup          string
 }
 
 // Format is an enumeration of available query output formats
@@ -50,6 +51,13 @@ func (q *Query) Execute() (*os.File, error) {
 		q.SQL = string(queryFromFile)
 	}
 
+	var workgroup *string
+	if q.WorkGroup != "" {
+		workgroup = aws.String(q.WorkGroup)
+	} else {
+		workgroup = nil
+	}
+
 	result, err := svc.StartQueryExecution(&athena.StartQueryExecutionInput{
 		QueryString: &q.SQL,
 		QueryExecutionContext: &athena.QueryExecutionContext{
@@ -58,6 +66,7 @@ func (q *Query) Execute() (*os.File, error) {
 		ResultConfiguration: &athena.ResultConfiguration{
 			OutputLocation: aws.String("s3://" + path.Join(q.QueryResultsBucket, q.QueryResultsPrefix)),
 		},
+		WorkGroup: workgroup,
 	})
 
 	if err != nil {
-- 
2.34.1