Skip to content

Commit 047c800

Browse files
committed
wip
1 parent 53085e4 commit 047c800

8 files changed

Lines changed: 271 additions & 15 deletions

File tree

src/aws_api/arn.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::fmt::{Display, Formatter};
12
use crate::aws_api::error::Error;
23
use std::str::FromStr;
34

@@ -13,11 +14,11 @@ pub(crate) struct AwsArn {
1314
impl FromStr for AwsArn {
1415
type Err = Error;
1516

16-
// Note, this will only handle ARNs were the resource type is included and split with ':'
1717
fn from_str(s: &str) -> Result<Self, Self::Err> {
1818
// Split one more than we need to verify it's valid
19-
let mut parts: Vec<String> = s.splitn(8, ":").map(|s| s.to_string()).collect();
20-
if parts.len() != 7 {
19+
let mut parts: Vec<String> = s.splitn(9, ":").map(|s| s.to_string()).collect();
20+
let num_parts = parts.len();
21+
if num_parts < 6 || num_parts >= 8 {
2122
return Err(Error::ArnParseError(s.to_string()));
2223
}
2324

@@ -37,7 +38,9 @@ impl FromStr for AwsArn {
3738
};
3839

3940
arn.resource_id = parts.pop().unwrap();
40-
arn.resource_type = parts.pop().unwrap();
41+
if num_parts == 7 {
42+
arn.resource_type = parts.pop().unwrap();
43+
}
4144
arn.account_id = parts.pop().unwrap();
4245
arn.region = parts.pop().unwrap();
4346
arn.service = parts.pop().unwrap();
@@ -51,6 +54,23 @@ impl FromStr for AwsArn {
5154
}
5255
}
5356

57+
impl Display for AwsArn {
58+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59+
let mut parts = Vec::with_capacity(7);
60+
parts.push("arn");
61+
parts.push(self.partition.as_str());
62+
parts.push(self.service.as_str());
63+
parts.push(self.region.as_str());
64+
parts.push(self.account_id.as_str());
65+
if self.resource_type != "" {
66+
parts.push(self.resource_type.as_str());
67+
}
68+
parts.push(self.resource_id.as_str());
69+
70+
write!(f, "{}", parts.join(":"))
71+
}
72+
}
73+
5474
impl AwsArn {
5575
pub fn get_endpoint(&self) -> String {
5676
let domain = if self.region.starts_with("cn-") {
@@ -68,7 +88,7 @@ mod tests {
6888
use super::*;
6989

7090
#[test]
71-
fn test_parse_arn_valid() {
91+
fn test_parse_secrets_arn_valid() {
7292
let input = "arn:aws:secretsmanager:us-east-2:891477334659:secret:test-ohio-secret-L86lpn";
7393

7494
let arn = input.parse::<AwsArn>().unwrap();
@@ -81,6 +101,20 @@ mod tests {
81101
assert_eq!("test-ohio-secret-L86lpn", arn.resource_id);
82102
}
83103

104+
#[test]
105+
fn test_parse_ssm_arn_valid() {
106+
let input = "arn:aws:ssm:us-east-1:123377354456:parameter/ci-test-value";
107+
108+
let arn = input.parse::<AwsArn>().unwrap();
109+
110+
assert_eq!("aws", arn.partition);
111+
assert_eq!("ssm", arn.service);
112+
assert_eq!("us-east-1", arn.region);
113+
assert_eq!("123377354456", arn.account_id);
114+
assert_eq!("", arn.resource_type);
115+
assert_eq!("parameter/ci-test-value", arn.resource_id);
116+
}
117+
84118
#[test]
85119
fn test_parse_arn_invalid() {
86120
assert!(

src/aws_api/auth.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub trait Clock {
1313
fn now(&self) -> DateTime<Utc>;
1414
}
1515

16+
#[derive(Default)]
1617
pub struct SystemClock;
1718

1819
impl Clock for SystemClock {

src/aws_api/client.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use hyper_util::rt::{TokioExecutor, TokioTimer};
1313
use rustls::ClientConfig;
1414
use std::time::Duration;
1515
use tower::BoxError;
16+
use crate::aws_api::paramstore::ParameterStore;
1617

1718
/// Main client for AWS services
1819
pub struct AwsClient {
@@ -32,6 +33,9 @@ impl AwsClient {
3233
pub fn secrets_manager(&self) -> SecretsManager {
3334
SecretsManager::new(self)
3435
}
36+
37+
/// Get an instance of the ParameterStore service
38+
pub fn parameter_store(&self) -> ParameterStore { ParameterStore::new(self) }
3539

3640
pub async fn perform(&self, req: Request<Full<Bytes>>) -> Result<Bytes, Error> {
3741
let resp = self.client.request(req).await?;

src/aws_api/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum Error {
1313
SignatureError(String),
1414
SerdeError(serde_json::Error),
1515
AwsError { code: String, message: String },
16+
InvalidParameters(Vec<String>),
1617
}
1718

1819
impl fmt::Display for Error {
@@ -27,6 +28,7 @@ impl fmt::Display for Error {
2728
Error::HttpResponseError(e) => write!(f, "Failed to parse HTTP response: {}", e),
2829
Error::HttpResponseErrorParse(e) => write!(f, "Failed to parse HTTP response: {}", e),
2930
Error::UriParseError(e) => write!(f, "Unable to parse endpoint url: {}", e),
31+
Error::InvalidParameters(params) => write!(f, "Unable to lookup parameters: {:?}", params)
3032
}
3133
}
3234
}

src/aws_api/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ pub mod client;
44
pub mod config;
55
mod error;
66
mod secretsmanager;
7+
mod paramstore;
8+
mod test_util;

src/aws_api/paramstore.rs

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
use crate::aws_api::arn::AwsArn;
2+
use crate::aws_api::auth::{AwsRequestSigner, SystemClock};
3+
use crate::aws_api::client::AwsClient;
4+
use crate::aws_api::error::Error;
5+
use http::header::CONTENT_TYPE;
6+
use http::{HeaderMap, HeaderValue, Method, Uri};
7+
use serde::Deserialize;
8+
use serde_json::json;
9+
use std::collections::HashMap;
10+
use tracing::error;
11+
12+
pub struct ParameterStore<'a> {
13+
client: &'a AwsClient,
14+
service_name: &'static str,
15+
}
16+
17+
#[derive(Debug, Deserialize)]
18+
pub struct GetParametersResponse {
19+
/// The parameter object.
20+
#[serde(rename = "Parameters")]
21+
pub parameters: Vec<Parameter>,
22+
23+
#[serde(rename = "InvalidParameters")]
24+
pub invalid_parameters: Vec<InvalidParameters>,
25+
}
26+
27+
#[derive(Debug, Deserialize)]
28+
pub struct InvalidParameters {
29+
#[serde(rename = "Name")]
30+
pub name: String,
31+
}
32+
33+
#[derive(Debug, Deserialize)]
34+
pub struct Parameter {
35+
/// The Amazon Resource Name (ARN) of the parameter.
36+
#[serde(rename = "ARN")]
37+
pub arn: Option<String>,
38+
39+
/// The data type of the parameter, such as text, aws:ec2:image, or aws:tag-specification.
40+
#[serde(rename = "DataType")]
41+
pub data_type: Option<String>,
42+
43+
/// The last modification date of the parameter.
44+
#[serde(rename = "LastModifiedDate")]
45+
pub last_modified_date: Option<f64>,
46+
47+
/// The name of the parameter.
48+
#[serde(rename = "Name")]
49+
pub name: String,
50+
51+
/// The unique identifier for the parameter version.
52+
#[serde(rename = "Selector")]
53+
pub selector: Option<String>,
54+
55+
/// The parameter source.
56+
#[serde(rename = "SourceResult")]
57+
pub source_result: Option<String>,
58+
59+
/// The parameter type.
60+
#[serde(rename = "Type")]
61+
pub type_: String,
62+
63+
/// The parameter value.
64+
#[serde(rename = "Value")]
65+
pub value: String,
66+
67+
/// The parameter version.
68+
#[serde(rename = "Version")]
69+
pub version: Option<i64>,
70+
71+
/// Tags associated with the parameter.
72+
#[serde(rename = "Tags")]
73+
pub tags: Option<HashMap<String, String>>,
74+
}
75+
76+
impl<'a> ParameterStore<'a> {
77+
pub(crate) fn new(client: &'a AwsClient) -> Self {
78+
Self {
79+
client,
80+
service_name: "ssm",
81+
}
82+
}
83+
84+
pub async fn get_parameters(
85+
&self,
86+
param_arns: &[String],
87+
) -> Result<HashMap<String, Parameter>, Error> {
88+
let mut arns_by_endpoint = HashMap::new();
89+
for arn_str in param_arns {
90+
let arn = arn_str.parse::<AwsArn>()?;
91+
if arn.service != self.service_name {
92+
return Err(Error::ArnParseError(arn_str.clone()));
93+
}
94+
95+
arns_by_endpoint
96+
.entry(arn.get_endpoint())
97+
.or_insert_with(|| Vec::new())
98+
.push(arn);
99+
}
100+
101+
let mut res = HashMap::new();
102+
for (endpoint, arns) in &arns_by_endpoint {
103+
let endpoint = endpoint.parse::<Uri>()?;
104+
105+
let payload = json!({
106+
"Names": arns.iter().map(|arn| arn.to_string()).collect::<Vec<String>>(),
107+
"WithDecryption": true,
108+
});
109+
110+
let payload_bytes = serde_json::to_vec(&payload)?;
111+
112+
let mut hdrs = HeaderMap::new();
113+
hdrs.insert(
114+
"X-Amz-Target",
115+
HeaderValue::from_static("AmazonSSM.GetParameters"),
116+
);
117+
hdrs.insert(
118+
CONTENT_TYPE,
119+
HeaderValue::from_static("application/x-amz-json-1.1"),
120+
);
121+
122+
// Sign the request
123+
let signer = AwsRequestSigner::new(
124+
self.service_name,
125+
&arns[0].region,
126+
&self.client.config.aws_access_key_id,
127+
&self.client.config.aws_secret_access_key,
128+
self.client.config.aws_session_token.as_deref(),
129+
SystemClock,
130+
);
131+
let signed_request = signer.sign(endpoint, Method::POST, hdrs, payload_bytes)?;
132+
133+
// Send the request
134+
let response = self.client.perform(signed_request).await?;
135+
136+
let result: GetParametersResponse = serde_json::from_slice(response.as_ref())?;
137+
138+
if !result.invalid_parameters.is_empty() {
139+
return Err(Error::InvalidParameters(
140+
result
141+
.invalid_parameters
142+
.into_iter()
143+
.map(|i| i.name)
144+
.collect(),
145+
));
146+
}
147+
148+
for param in result.parameters {
149+
if param.arn.is_none() {
150+
error!(parameter = param.name, "Parameter was missing ARN");
151+
return Err(Error::InvalidParameters(vec![arns.into_iter().map(|arn|arn.to_string()).collect()]))
152+
}
153+
154+
let arn = param.arn.clone().unwrap();
155+
res.insert(arn, param);
156+
}
157+
}
158+
159+
Ok(res)
160+
}
161+
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use super::*;
166+
use crate::aws_api::config::AwsConfig;
167+
use crate::aws_api::test_util::init_crypto;
168+
169+
#[tokio::test]
170+
async fn test_basic_paramstore_retrieval() {
171+
// TEST_PARAMSTORE_ARNS should be set to a comma-separated list of k=v pairs,
172+
// where k is an ARN of a secret and v is the secret value to test against.
173+
let test_paramstore_arns = std::env::var("TEST_PARAMSTORE_ARNS");
174+
if !test_paramstore_arns.is_ok() {
175+
println!("Skipping test_basic_paramstore_retrieval due to unset envvar");
176+
return;
177+
}
178+
179+
let test_arns: Vec<(String, String)> = test_paramstore_arns
180+
.unwrap()
181+
.split(",")
182+
.filter(|s| !s.is_empty())
183+
.filter_map(|pair| {
184+
let parts: Vec<&str> = pair.splitn(2, '=').collect();
185+
if parts.len() == 2 {
186+
Some((parts[0].trim().to_string(), parts[1].trim().to_string()))
187+
} else {
188+
None // Skip malformed pairs that don't have an equals sign
189+
}
190+
})
191+
.collect();
192+
193+
init_crypto();
194+
195+
let client = AwsClient::new(AwsConfig::from_env()).unwrap();
196+
197+
let ps = client.parameter_store();
198+
199+
let arn_values = test_arns.iter().map(|arn| arn.0.clone()).collect::<Vec<String>>();
200+
let res = ps.get_parameters(&arn_values).await.unwrap();
201+
202+
for test_arn in &test_arns {
203+
let entry = res.get(&test_arn.0).unwrap();
204+
205+
assert_eq!(test_arn.1, entry.value);
206+
}
207+
}
208+
}

src/aws_api/secretsmanager.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ impl<'a> SecretsManager<'a> {
4444
) -> Result<GetSecretValueResponse, Error> {
4545
let arn = secret_arn.parse::<AwsArn>()?;
4646

47-
if arn.service != self.service_name {
47+
if arn.service != self.service_name || arn.resource_type != "secret" {
4848
return Err(Error::ArnParseError(secret_arn.to_string()));
4949
}
5050

@@ -90,7 +90,7 @@ impl<'a> SecretsManager<'a> {
9090
mod tests {
9191
use super::*;
9292
use crate::aws_api::config::AwsConfig;
93-
use std::sync::Once;
93+
use crate::aws_api::test_util::init_crypto;
9494

9595
#[tokio::test]
9696
async fn test_basic_secret_retrieval() {
@@ -129,12 +129,4 @@ mod tests {
129129
}
130130
}
131131

132-
static INIT_CRYPTO: Once = Once::new();
133-
pub fn init_crypto() {
134-
INIT_CRYPTO.call_once(|| {
135-
rustls::crypto::aws_lc_rs::default_provider()
136-
.install_default()
137-
.unwrap()
138-
});
139-
}
140132
}

src/aws_api/test_util.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use std::sync::Once;
2+
3+
// used for testing
4+
#[allow(dead_code)]
5+
static INIT_CRYPTO: Once = Once::new();
6+
#[allow(dead_code)]
7+
pub fn init_crypto() {
8+
INIT_CRYPTO.call_once(|| {
9+
rustls::crypto::aws_lc_rs::default_provider()
10+
.install_default()
11+
.unwrap()
12+
});
13+
}

0 commit comments

Comments
 (0)