Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
using Azure.Mcp.Core.Areas.Group;
using Azure.Mcp.Core.Areas.Subscription;
using Azure.Mcp.Core.Commands;
using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Core.Services.Caching;
using Azure.Mcp.Core.Services.ProcessExecution;
using Azure.Mcp.Core.Services.Time;
using Azure.Mcp.Tools.Acr;
using Azure.Mcp.Tools.Advisor;
using Azure.Mcp.Tools.Aks;
Expand Down Expand Up @@ -171,7 +174,11 @@ public static IServiceCollection SetupCommonServices()
.AddSingleton(Substitute.For<ISubscriptionService>())
.AddSingleton(Substitute.For<ITenantService>())
.AddSingleton(Substitute.For<IHttpClientFactory>())
.AddSingleton(Substitute.For<ICacheService>());
.AddSingleton(Substitute.For<ICacheService>())
.AddSingleton(Substitute.For<IDateTimeProvider>())
.AddSingleton(Substitute.For<IExternalProcessService>())
.AddSingleton(Substitute.For<IAzureTokenCredentialProvider>())
.AddSingleton(Substitute.For<IAzureCloudConfiguration>());

foreach (var area in areaSetups)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pr: 1988
changes:
- section: "Other Changes"
description: "Refactored `Azure.Mcp.Tools.Extension` commands to use constructor dependency injection instead of resolving services via `context.GetService<T>()` in `ExecuteAsync`."
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

using Azure.Mcp.Core.Commands;
using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.ProcessExecution;
using Azure.Mcp.Core.Services.Time;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand Down Expand Up @@ -42,7 +45,11 @@ private static Task<List<string>> GetAllModeToolNamesAsync()
.AddSingleton<ITelemetryService, NoOpTelemetryService>()
.AddSingleton(Substitute.For<Azure.Mcp.Core.Services.Azure.Subscription.ISubscriptionService>())
.AddSingleton(Substitute.For<Azure.Mcp.Core.Services.Azure.Tenant.ITenantService>())
.AddSingleton(Substitute.For<IHttpClientFactory>());
.AddSingleton(Substitute.For<IHttpClientFactory>())
.AddSingleton(Substitute.For<IDateTimeProvider>())
.AddSingleton(Substitute.For<IExternalProcessService>())
.AddSingleton(Substitute.For<IAzureTokenCredentialProvider>())
.AddSingleton(Substitute.For<IAzureCloudConfiguration>());

foreach (var area in areaSetups)
{
Expand Down
10 changes: 5 additions & 5 deletions tools/Azure.Mcp.Tools.Extension/src/Commands/AzCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

Comment thread
conniey marked this conversation as resolved.
namespace Azure.Mcp.Tools.Extension.Commands;

public sealed class AzCommand(ILogger<AzCommand> logger, int processTimeoutSeconds = 300) : GlobalCommand<AzOptions>()
public sealed class AzCommand(ILogger<AzCommand> logger, IExternalProcessService processService, int processTimeoutSeconds = 300) : GlobalCommand<AzOptions>()
Comment thread
conniey marked this conversation as resolved.
{
private const string CommandTitle = "Azure CLI Command";
private readonly ILogger<AzCommand> _logger = logger;
private readonly IExternalProcessService _processService = processService;
private readonly int _processTimeoutSeconds = processTimeoutSeconds;
private static string? _cachedAzPath;
private volatile bool _isAuthenticated = false;
Expand Down Expand Up @@ -178,13 +179,12 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{
ArgumentNullException.ThrowIfNull(options.Command);
var command = options.Command;
var processService = context.GetService<IExternalProcessService>();

// Try to authenticate, but continue even if it fails
await AuthenticateWithAzureCredentialsAsync(processService, _logger, cancellationToken);
await AuthenticateWithAzureCredentialsAsync(_processService, _logger, cancellationToken);

var azPath = FindAzCliPath() ?? throw new FileNotFoundException("Azure CLI executable not found in PATH or common installation locations. Please ensure Azure CLI is installed.");
var result = await processService.ExecuteAsync(azPath, command,
var result = await _processService.ExecuteAsync(azPath, command,
operationTimeoutSeconds: _processTimeoutSeconds,
cancellationToken: cancellationToken);

Expand All @@ -194,7 +194,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
context.Response.Message = result.Error;
}

var jElem = processService.ParseJsonOutput(result);
var jElem = _processService.ParseJsonOutput(result);
context.Response.Results = ResponseResult.Create(jElem, ExtensionJsonContext.Default.JsonElement);
}
catch (Exception ex)
Expand Down
14 changes: 7 additions & 7 deletions tools/Azure.Mcp.Tools.Extension/src/Commands/AzqrCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

namespace Azure.Mcp.Tools.Extension.Commands;

public sealed class AzqrCommand(ILogger<AzqrCommand> logger, int processTimeoutSeconds = 300) : SubscriptionCommand<AzqrOptions>()
public sealed class AzqrCommand(ILogger<AzqrCommand> logger, ISubscriptionService subscriptionService, IDateTimeProvider dateTimeProvider, IExternalProcessService processService, int processTimeoutSeconds = 300) : SubscriptionCommand<AzqrOptions>()
{
private const string CommandTitle = "Azure Quick Review CLI Command";
private readonly ILogger<AzqrCommand> _logger = logger;
private readonly ISubscriptionService _subscriptionService = subscriptionService;
private readonly IDateTimeProvider _dateTimeProvider = dateTimeProvider;
private readonly IExternalProcessService _processService = processService;
private readonly int _processTimeoutSeconds = processTimeoutSeconds;
private static string? _cachedAzqrPath;
public override string Id => "e7ef18a3-2730-4300-bad3-dc766f47dd2a";
Expand Down Expand Up @@ -76,9 +79,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,

var azqrPath = FindAzqrCliPath() ?? throw new FileNotFoundException("Azure Quick Review CLI (azqr) executable not found in PATH. Please ensure azqr is installed. Go to https://aka.ms/azqr to learn more about how to install Azure Quick Review CLI.");

var subscriptionService = context.GetService<ISubscriptionService>();
var dateTimeProvider = context.GetService<IDateTimeProvider>();
var subscription = await subscriptionService.GetSubscription(options.Subscription, options.Tenant, cancellationToken: cancellationToken);
var subscription = await _subscriptionService.GetSubscription(options.Subscription, options.Tenant, cancellationToken: cancellationToken);

// Compose azqr command
var command = $"scan --subscription-id {subscription.Id}";
Expand All @@ -88,7 +89,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
}

var tempDir = Path.GetTempPath();
var dateString = dateTimeProvider.UtcNow.ToString("yyyyMMdd-HHmmss");
var dateString = _dateTimeProvider.UtcNow.ToString("yyyyMMdd-HHmmss");
var reportFileName = Path.Combine(tempDir, $"azqr-report-{options.Subscription}-{dateString}");

// Azure Quick Review always appends the file extension to the report file's name, we need to create a new path with the file extension to check for the existence of the report file.
Expand All @@ -99,8 +100,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
// Also generate a JSON report for users who don't have access to Excel.
command += " --json";

var processService = context.GetService<IExternalProcessService>();
var result = await processService.ExecuteAsync(azqrPath, command,
var result = await _processService.ExecuteAsync(azqrPath, command,
operationTimeoutSeconds: _processTimeoutSeconds,
cancellationToken: cancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

namespace Azure.Mcp.Tools.Extension.Commands;

public sealed class CliGenerateCommand(ILogger<CliGenerateCommand> logger) : GlobalCommand<CliGenerateOptions>
public sealed class CliGenerateCommand(ILogger<CliGenerateCommand> logger, ICliGenerateService cliGenerateService) : GlobalCommand<CliGenerateOptions>
{
private const string CommandTitle = "Generate CLI Command";
private readonly ILogger<CliGenerateCommand> _logger = logger;
private readonly ICliGenerateService _cliGenerateService = cliGenerateService;
private readonly string[] _allowedCliTypeValues = ["az"];

public override string Id => "3de4ef37-90bf-41f1-8385-5e870c3ae911";
Expand Down Expand Up @@ -84,14 +85,13 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{
throw new ArgumentException($"Invalid CLI type: {options.CliType}. Supported values are: {string.Join(", ", _allowedCliTypeValues)}");
}
ICliGenerateService cliGenerateService = context.GetService<ICliGenerateService>();

// Only log the cli type when we know for sure it doesn't have private data.
context.Activity?.AddTag("cliType", cliType);

if (cliType == Constants.AzureCliType)
{
using HttpResponseMessage responseMessage = await cliGenerateService.GenerateAzureCLICommandAsync(intent, cancellationToken);
using HttpResponseMessage responseMessage = await _cliGenerateService.GenerateAzureCLICommandAsync(intent, cancellationToken);
responseMessage.EnsureSuccessStatusCode();

var responseBody = await responseMessage.Content.ReadAsStringAsync(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

namespace Azure.Mcp.Tools.Extension.Commands;

public sealed class CliInstallCommand(ILogger<CliInstallCommand> logger) : GlobalCommand<CliInstallOptions>
public sealed class CliInstallCommand(ILogger<CliInstallCommand> logger, ICliInstallService cliInstallService) : GlobalCommand<CliInstallOptions>
{
private const string CommandTitle = "Get CLI installation instructions";
private readonly ILogger<CliInstallCommand> _logger = logger;
private readonly ICliInstallService _cliInstallService = cliInstallService;
private readonly string[] _allowedCliTypeValues = ["az", "azd", "func"];

public override string Id => "464626d0-b9be-4a3b-9f29-858637ab8c10";
Expand Down Expand Up @@ -79,12 +80,11 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
{
throw new ArgumentException($"Invalid CLI type: {options.CliType}. Supported values are: {string.Join(", ", _allowedCliTypeValues)}");
}
ICliInstallService cliInstallService = context.GetService<ICliInstallService>();

// Only log the cli type when we know for sure it doesn't have private data.
context.Activity?.AddTag("cliType", cliType);

using HttpResponseMessage responseMessage = await cliInstallService.GetCliInstallInstructions(cliType, cancellationToken);
using HttpResponseMessage responseMessage = await _cliInstallService.GetCliInstallInstructions(cliType, cancellationToken);
responseMessage.EnsureSuccessStatusCode();

var responseBody = await responseMessage.Content.ReadAsStringAsync(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public async Task ExecuteAsync_ReturnsSuccessResult_WhenScanSucceeds()
var fixedDateTime = new DateTime(2024, 1, 15, 10, 30, 45, DateTimeKind.Utc);
_dateTimeProvider.UtcNow.Returns(fixedDateTime);

var command = new AzqrCommand(_logger);
var command = new AzqrCommand(_logger, _subscriptionService, _dateTimeProvider, _processService);

var mockSubscriptionId = "12345678-1234-1234-1234-123456789012";
var args = command.GetCommand().Parse($"--subscription {mockSubscriptionId}");
Expand Down Expand Up @@ -116,7 +116,7 @@ await _processService.Received().ExecuteAsync(
public async Task ExecuteAsync_ReturnsBadRequest_WhenMissingSubscriptionArgument()
{
// Arrange
var command = new AzqrCommand(_logger);
var command = new AzqrCommand(_logger, _subscriptionService, _dateTimeProvider, _processService);

var args = command.GetCommand().Parse(""); // No subscription specified
var context = new CommandContext(_serviceProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public CliGenerateCommandTests()
collection.AddSingleton(_httpClientFactory);
collection.AddSingleton(_cliGenerateService);
_serviceProvider = collection.BuildServiceProvider();
_command = new(_logger);
_command = new(_logger, _cliGenerateService);
_context = new(_serviceProvider);
_commandDefinition = _command.GetCommand();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public CliInstallCommandTests()
collection.AddSingleton(_httpClientFactory);
collection.AddSingleton(_cliInstallService);
_serviceProvider = collection.BuildServiceProvider();
_command = new(_logger);
_command = new(_logger, _cliInstallService);
_context = new(_serviceProvider);
_commandDefinition = _command.GetCommand();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using Azure.Mcp.Core.Services.Azure.Authentication;
using Azure.Mcp.Core.Services.Azure.Subscription;
using Azure.Mcp.Core.Services.ProcessExecution;
using Azure.Mcp.Core.Services.Time;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Mcp.Core.Areas.Server.Options;
using NSubstitute;
using Xunit;

namespace Azure.Mcp.Tools.Extension.UnitTests;
Expand All @@ -18,6 +23,12 @@ private static IServiceProvider BuildServiceProvider(ServiceStartOptions? startO
var setup = new ExtensionSetup();
setup.ConfigureServices(services);

services.AddSingleton(Substitute.For<IExternalProcessService>());
services.AddSingleton(Substitute.For<ISubscriptionService>());
services.AddSingleton(Substitute.For<IDateTimeProvider>());
services.AddSingleton(Substitute.For<IAzureTokenCredentialProvider>());
services.AddSingleton(Substitute.For<IAzureCloudConfiguration>());

if (startOptions is not null)
{
services.AddSingleton(startOptions);
Expand Down