Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,10 @@ impl AppBuilder {
let has_tokens =
is_authenticated(&server, &secrets, "default").await;

let client = if has_tokens || server.requires_auth() {
let client = if has_tokens
|| server.requires_auth()
|| server.has_custom_headers()
{
McpClient::new_authenticated(
server, mcp_sm, secrets, "default",
)
Expand Down
60 changes: 52 additions & 8 deletions src/cli/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ pub enum McpCommand {
#[arg(long)]
scopes: Option<String>,

/// Custom HTTP header (repeatable, format: "Name:Value")
///
/// For MCP servers that use header-based auth instead of OAuth.
/// Example: --header "X-API-Key:sk-abc123"
#[arg(long = "header", value_name = "NAME:VALUE")]
headers: Vec<String>,

/// Server description
#[arg(long)]
description: Option<String>,
Expand Down Expand Up @@ -107,6 +114,7 @@ pub async fn run_mcp_command(cmd: McpCommand) -> anyhow::Result<()> {
auth_url,
token_url,
scopes,
headers,
description,
} => {
add_server(
Expand All @@ -116,6 +124,7 @@ pub async fn run_mcp_command(cmd: McpCommand) -> anyhow::Result<()> {
auth_url,
token_url,
scopes,
headers,
description,
)
.await
Expand All @@ -133,13 +142,15 @@ pub async fn run_mcp_command(cmd: McpCommand) -> anyhow::Result<()> {
}

/// Add a new MCP server.
#[allow(clippy::too_many_arguments)]
async fn add_server(
name: String,
url: String,
client_id: Option<String>,
auth_url: Option<String>,
token_url: Option<String>,
scopes: Option<String>,
headers: Vec<String>,
description: Option<String>,
) -> anyhow::Result<()> {
let mut config = McpServerConfig::new(&name, &url);
Expand All @@ -148,8 +159,8 @@ async fn add_server(
config = config.with_description(desc);
}

// Track if auth is required
let requires_auth = client_id.is_some();
// Track if OAuth auth is required
let requires_oauth = client_id.is_some();

// Set up OAuth if client_id is provided
if let Some(client_id) = client_id {
Expand All @@ -170,7 +181,24 @@ async fn add_server(
config = config.with_oauth(oauth);
}

// Validate
// Parse custom headers (format: "Name:Value")
if !headers.is_empty() {
let mut header_map = std::collections::HashMap::new();
for raw in &headers {
let (key, value) = raw.split_once(':').ok_or_else(|| {
anyhow::anyhow!("Invalid header format '{}'. Expected 'Name:Value'.", raw)
})?;
let key = key.trim().to_string();
let value = value.trim().to_string();
if key.is_empty() {
anyhow::bail!("Header name cannot be empty in '{}'", raw);
}
header_map.insert(key, value);
}
config = config.with_headers(header_map);
}

// Validate (includes header name/value safety checks)
config.validate()?;

// Save (DB if available, else disk)
Expand All @@ -183,9 +211,11 @@ async fn add_server(
println!(" ✓ Added MCP server '{}'", name);
println!(" URL: {}", url);

if requires_auth {
if requires_oauth {
println!();
println!(" Run 'ironclaw mcp auth {}' to authenticate.", name);
} else if !headers.is_empty() {
println!(" Auth: custom headers ({} configured)", headers.len());
}

println!();
Expand Down Expand Up @@ -230,7 +260,9 @@ async fn list_servers(verbose: bool) -> anyhow::Result<()> {

for server in &servers.servers {
let status = if server.enabled { "●" } else { "○" };
let auth_status = if server.requires_auth() {
let auth_status = if server.has_custom_headers() {
" (header auth)"
} else if server.requires_auth() {
" (auth required)"
} else {
""
Expand All @@ -248,6 +280,17 @@ async fn list_servers(verbose: bool) -> anyhow::Result<()> {
println!(" Scopes: {}", oauth.scopes.join(", "));
}
}
if let Some(ref headers) = server.headers {
for (name, value) in headers {
// Mask header values to avoid leaking secrets in logs/terminal
let masked = if value.len() <= 4 {
"****".to_string()
} else {
format!("{}...{}", &value[..2], &value[value.len() - 2..])
};
println!(" Header: {}: {}", name, masked);
}
}
println!();
} else {
println!(
Expand Down Expand Up @@ -360,8 +403,9 @@ async fn test_server(name: String, user_id: String) -> anyhow::Result<()> {
let secrets = get_secrets_store().await?;
let has_tokens = is_authenticated(&server, &secrets, &user_id).await;

let client = if has_tokens {
// We have stored tokens, use authenticated client
let client = if has_tokens || server.has_custom_headers() {
// We have stored tokens or custom headers — use authenticated client
// (custom headers are injected via server_config in send_request)
McpClient::new_authenticated(server.clone(), session_manager, secrets, user_id)
} else if server.requires_auth() {
// OAuth configured but no tokens - need to authenticate
Expand All @@ -373,7 +417,7 @@ async fn test_server(name: String, user_id: String) -> anyhow::Result<()> {
println!();
return Ok(());
} else {
// No OAuth and no tokens - try unauthenticated
// No OAuth, no tokens, no custom headers - try unauthenticated
McpClient::new_with_name(&server.name, &server.url)
};

Expand Down
2 changes: 1 addition & 1 deletion src/extensions/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2500,7 +2500,7 @@ impl ExtensionManager {

let has_tokens = is_authenticated(&server, &self.secrets, &self.user_id).await;

let client = if has_tokens || server.requires_auth() {
let client = if has_tokens || server.requires_auth() || server.has_custom_headers() {
McpClient::new_authenticated(
server.clone(),
Arc::clone(&self.mcp_session_manager),
Expand Down
190 changes: 188 additions & 2 deletions src/tools/mcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,26 @@ impl McpClient {
.header("Content-Type", "application/json")
.json(&request);

// Add Authorization header if we have a token
if let Some(token) = self.get_access_token().await? {
// Add custom headers from config (if any)
let has_custom_auth = if let Some(ref config) = self.server_config
&& config.has_custom_headers()
{
let has_auth = config.has_custom_auth_header();
if let Some(ref headers) = config.headers {
for (name, value) in headers {
// Headers were validated at config time (CRLF-safe)
req_builder = req_builder.header(name.as_str(), value.as_str());
}
}
has_auth
} else {
false
};
Comment on lines +182 to +195
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block for adding custom headers and checking for a custom Authorization header is a bit complex, as it mixes a side effect (modifying req_builder) inside an expression. This can be refactored for better readability and maintainability by separating the action of adding headers from the check for the authorization header.

            let has_custom_auth = if let Some(config) = self.server_config.as_ref() {
                if let Some(headers) = &config.headers {
                    for (name, value) in headers {
                        // Headers were validated at config time (CRLF-safe)
                        req_builder = req_builder.header(name.as_str(), value.as_str());
                    }
                    config.has_custom_auth_header()
                } else {
                    false
                }
            } else {
                false
            };


// Add OAuth Bearer token if we have one and the user hasn't
// supplied a custom Authorization header (case-insensitive check
// already done by has_custom_auth_header).
if !has_custom_auth && let Some(token) = self.get_access_token().await? {
req_builder = req_builder.header("Authorization", format!("Bearer {}", token));
}

Expand Down Expand Up @@ -740,4 +758,172 @@ mod tests {
};
assert!(!tool.requires_approval());
}

#[test]
fn test_custom_auth_header_skips_bearer() {
// When config has a custom Authorization header,
// has_custom_auth_header() should return true so send_request
// skips auto Bearer injection.
let mut headers = std::collections::HashMap::new();
headers.insert("Authorization".to_string(), "Bearer static-tok".to_string());
let config = McpServerConfig::new("test", "https://mcp.example.com").with_headers(headers);
assert!(config.has_custom_auth_header());
}

#[test]
fn test_non_auth_headers_allow_bearer() {
// Custom headers that aren't Authorization should NOT block Bearer
let mut headers = std::collections::HashMap::new();
headers.insert("X-API-Key".to_string(), "sk-abc".to_string());
let config = McpServerConfig::new("test", "https://mcp.example.com").with_headers(headers);
assert!(!config.has_custom_auth_header());
}

/// Spin up a lightweight axum server that echoes back all received request
/// headers as a JSON-RPC result, so we can assert on the actual wire-level
/// headers sent by `send_request`.
async fn mock_mcp_echo_server() -> (String, tokio::task::JoinHandle<()>) {
use axum::http::StatusCode;
use axum::{Router, extract::Request, routing::post};

let app = Router::new().route(
"/",
post(|req: Request| async move {
// Collect headers into a map for easy assertion.
let headers: std::collections::HashMap<String, String> = req
.headers()
.iter()
.map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let body = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": headers
});
(StatusCode::OK, axum::Json(body))
}),
);

let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = format!("http://127.0.0.1:{}", addr.port());
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(url, handle)
}

/// Helper: call send_request on a client and return the echoed headers.
async fn send_and_get_headers(client: &McpClient) -> std::collections::HashMap<String, String> {
let req = McpRequest::initialize(1);
let resp = client.send_request(req).await.unwrap();
serde_json::from_value(resp.result.unwrap()).unwrap()
}

#[tokio::test]
async fn test_wire_custom_headers_injected() {
// Custom headers should appear in the outbound request.
let (url, handle) = mock_mcp_echo_server().await;
let mut hdrs = std::collections::HashMap::new();
hdrs.insert("X-Api-Key".to_string(), "sk-test-123".to_string());
let config = McpServerConfig::new("test", &url).with_headers(hdrs);
let client = McpClient::new_with_name("test", &url);
// Inject config so send_request reads custom headers
let client = McpClient {
server_config: Some(config),
..client
};
let headers = send_and_get_headers(&client).await;
assert_eq!(headers.get("x-api-key").unwrap(), "sk-test-123");
handle.abort();
}

#[tokio::test]
async fn test_wire_custom_auth_suppresses_bearer() {
// When config has a custom Authorization header, auto Bearer must NOT appear.
let (url, handle) = mock_mcp_echo_server().await;
let mut hdrs = std::collections::HashMap::new();
hdrs.insert(
"Authorization".to_string(),
"Token my-static-token".to_string(),
);
let config = McpServerConfig::new("test", &url).with_headers(hdrs);
let client = McpClient::new_with_name("test", &url);
let client = McpClient {
server_config: Some(config),
..client
};
let headers = send_and_get_headers(&client).await;
assert_eq!(
headers.get("authorization").unwrap(),
"Token my-static-token"
);
handle.abort();
}

#[tokio::test]
async fn test_wire_non_auth_headers_allow_bearer() {
// Custom non-Authorization headers must NOT suppress auto Bearer.
// Use a real InMemorySecretsStore with a stored token to prove Bearer
// IS injected alongside the custom header.
use crate::secrets::{CreateSecretParams, InMemorySecretsStore, SecretsCrypto};
use secrecy::SecretString;

let (url, handle) = mock_mcp_echo_server().await;

// Set up secrets store with a token
let crypto = Arc::new(
SecretsCrypto::new(SecretString::from("test-master-key-32chars!!!!!!!!xx")).unwrap(),
);
let secrets: Arc<dyn SecretsStore + Send + Sync> =
Arc::new(InMemorySecretsStore::new(crypto));
let mut hdrs = std::collections::HashMap::new();
hdrs.insert("X-Custom".to_string(), "value".to_string());
let config = McpServerConfig::new("test", &url).with_headers(hdrs);

// Store a token under the expected secret name
let token_name = config.token_secret_name(); // "mcp_test_access_token"
secrets
.create(
"default",
CreateSecretParams::new(&token_name, "oauth-bearer-tok"),
)
.await
.unwrap();

let session_manager = Arc::new(McpSessionManager::new());
let client = McpClient::new_authenticated(config, session_manager, secrets, "default");

let headers = send_and_get_headers(&client).await;
// Custom header present
assert_eq!(headers.get("x-custom").unwrap(), "value");
// Bearer token also present (not suppressed by non-auth custom header)
let auth = headers.get("authorization").unwrap();
assert!(
auth.starts_with("Bearer "),
"expected Bearer token, got: {}",
auth
);
handle.abort();
}

#[test]
fn test_client_creation_branches() {
// No headers, no OAuth, localhost → new_with_name path
let config = McpServerConfig::new("local", "http://localhost:8080");
assert!(!config.requires_auth());
assert!(!config.has_custom_headers());

// Headers only, localhost → has_custom_headers triggers authenticated path
let mut headers = std::collections::HashMap::new();
headers.insert("X-Key".to_string(), "val".to_string());
let config = McpServerConfig::new("local", "http://localhost:8080").with_headers(headers);
assert!(!config.requires_auth());
assert!(config.has_custom_headers());

// Remote HTTPS, no headers → requires_auth triggers authenticated path
let config = McpServerConfig::new("remote", "https://mcp.example.com");
assert!(config.requires_auth());
assert!(!config.has_custom_headers());
}
}
Loading