diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/CommandFactoryHelpers.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/CommandFactoryHelpers.cs index 9f825d872b..3821f1ac10 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/CommandFactoryHelpers.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Core.UnitTests/Areas/Server/CommandFactoryHelpers.cs @@ -5,9 +5,12 @@ using Azure.Mcp.Core.Areas.Group; using Azure.Mcp.Core.Areas.Subscription; using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Services.Azure.Authentication; using Azure.Mcp.Core.Services.Azure.Subscription; using Azure.Mcp.Core.Services.Azure.Tenant; using Azure.Mcp.Core.Services.Caching; +using Azure.Mcp.Core.Services.ProcessExecution; +using Azure.Mcp.Core.Services.Time; using Azure.Mcp.Tools.Acr; using Azure.Mcp.Tools.Advisor; using Azure.Mcp.Tools.Aks; @@ -171,7 +174,11 @@ public static IServiceCollection SetupCommonServices() .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) - .AddSingleton(Substitute.For()); + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()); foreach (var area in areaSetups) { diff --git a/servers/Azure.Mcp.Server/changelog-entries/copilot-di-refactor-extension.yaml b/servers/Azure.Mcp.Server/changelog-entries/copilot-di-refactor-extension.yaml new file mode 100644 index 0000000000..3d06937b6a --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/copilot-di-refactor-extension.yaml @@ -0,0 +1,4 @@ +pr: 1988 +changes: + - section: "Other Changes" + description: "Refactored `Azure.Mcp.Tools.Extension` commands to use constructor dependency injection instead of resolving services via `context.GetService()` in `ExecuteAsync`." diff --git a/servers/Azure.Mcp.Server/tests/Azure.Mcp.Server.UnitTests/Infrastructure/VisualStudioToolNameTests.cs b/servers/Azure.Mcp.Server/tests/Azure.Mcp.Server.UnitTests/Infrastructure/VisualStudioToolNameTests.cs index 72163c288e..14585a8d36 100644 --- a/servers/Azure.Mcp.Server/tests/Azure.Mcp.Server.UnitTests/Infrastructure/VisualStudioToolNameTests.cs +++ b/servers/Azure.Mcp.Server/tests/Azure.Mcp.Server.UnitTests/Infrastructure/VisualStudioToolNameTests.cs @@ -2,6 +2,9 @@ // Licensed under the MIT License. using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Services.Azure.Authentication; +using Azure.Mcp.Core.Services.ProcessExecution; +using Azure.Mcp.Core.Services.Time; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -42,7 +45,11 @@ private static Task> GetAllModeToolNamesAsync() .AddSingleton() .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) - .AddSingleton(Substitute.For()); + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()); foreach (var area in areaSetups) { diff --git a/tools/Azure.Mcp.Tools.Extension/src/Commands/AzCommand.cs b/tools/Azure.Mcp.Tools.Extension/src/Commands/AzCommand.cs index 38eb191735..0ce6675776 100644 --- a/tools/Azure.Mcp.Tools.Extension/src/Commands/AzCommand.cs +++ b/tools/Azure.Mcp.Tools.Extension/src/Commands/AzCommand.cs @@ -14,10 +14,11 @@ namespace Azure.Mcp.Tools.Extension.Commands; -public sealed class AzCommand(ILogger logger, int processTimeoutSeconds = 300) : GlobalCommand() +public sealed class AzCommand(ILogger logger, IExternalProcessService processService, int processTimeoutSeconds = 300) : GlobalCommand() { private const string CommandTitle = "Azure CLI Command"; private readonly ILogger _logger = logger; + private readonly IExternalProcessService _processService = processService; private readonly int _processTimeoutSeconds = processTimeoutSeconds; private static string? _cachedAzPath; private volatile bool _isAuthenticated = false; @@ -178,13 +179,12 @@ public override async Task ExecuteAsync(CommandContext context, { ArgumentNullException.ThrowIfNull(options.Command); var command = options.Command; - var processService = context.GetService(); // Try to authenticate, but continue even if it fails - await AuthenticateWithAzureCredentialsAsync(processService, _logger, cancellationToken); + await AuthenticateWithAzureCredentialsAsync(_processService, _logger, cancellationToken); var azPath = FindAzCliPath() ?? throw new FileNotFoundException("Azure CLI executable not found in PATH or common installation locations. Please ensure Azure CLI is installed."); - var result = await processService.ExecuteAsync(azPath, command, + var result = await _processService.ExecuteAsync(azPath, command, operationTimeoutSeconds: _processTimeoutSeconds, cancellationToken: cancellationToken); @@ -194,7 +194,7 @@ public override async Task ExecuteAsync(CommandContext context, context.Response.Message = result.Error; } - var jElem = processService.ParseJsonOutput(result); + var jElem = _processService.ParseJsonOutput(result); context.Response.Results = ResponseResult.Create(jElem, ExtensionJsonContext.Default.JsonElement); } catch (Exception ex) diff --git a/tools/Azure.Mcp.Tools.Extension/src/Commands/AzqrCommand.cs b/tools/Azure.Mcp.Tools.Extension/src/Commands/AzqrCommand.cs index 18c7418492..4b95e8586a 100644 --- a/tools/Azure.Mcp.Tools.Extension/src/Commands/AzqrCommand.cs +++ b/tools/Azure.Mcp.Tools.Extension/src/Commands/AzqrCommand.cs @@ -17,10 +17,13 @@ namespace Azure.Mcp.Tools.Extension.Commands; -public sealed class AzqrCommand(ILogger logger, int processTimeoutSeconds = 300) : SubscriptionCommand() +public sealed class AzqrCommand(ILogger logger, ISubscriptionService subscriptionService, IDateTimeProvider dateTimeProvider, IExternalProcessService processService, int processTimeoutSeconds = 300) : SubscriptionCommand() { private const string CommandTitle = "Azure Quick Review CLI Command"; private readonly ILogger _logger = logger; + private readonly ISubscriptionService _subscriptionService = subscriptionService; + private readonly IDateTimeProvider _dateTimeProvider = dateTimeProvider; + private readonly IExternalProcessService _processService = processService; private readonly int _processTimeoutSeconds = processTimeoutSeconds; private static string? _cachedAzqrPath; public override string Id => "e7ef18a3-2730-4300-bad3-dc766f47dd2a"; @@ -76,9 +79,7 @@ public override async Task ExecuteAsync(CommandContext context, var azqrPath = FindAzqrCliPath() ?? throw new FileNotFoundException("Azure Quick Review CLI (azqr) executable not found in PATH. Please ensure azqr is installed. Go to https://aka.ms/azqr to learn more about how to install Azure Quick Review CLI."); - var subscriptionService = context.GetService(); - var dateTimeProvider = context.GetService(); - var subscription = await subscriptionService.GetSubscription(options.Subscription, options.Tenant, cancellationToken: cancellationToken); + var subscription = await _subscriptionService.GetSubscription(options.Subscription, options.Tenant, cancellationToken: cancellationToken); // Compose azqr command var command = $"scan --subscription-id {subscription.Id}"; @@ -88,7 +89,7 @@ public override async Task ExecuteAsync(CommandContext context, } var tempDir = Path.GetTempPath(); - var dateString = dateTimeProvider.UtcNow.ToString("yyyyMMdd-HHmmss"); + var dateString = _dateTimeProvider.UtcNow.ToString("yyyyMMdd-HHmmss"); var reportFileName = Path.Combine(tempDir, $"azqr-report-{options.Subscription}-{dateString}"); // Azure Quick Review always appends the file extension to the report file's name, we need to create a new path with the file extension to check for the existence of the report file. @@ -99,8 +100,7 @@ public override async Task ExecuteAsync(CommandContext context, // Also generate a JSON report for users who don't have access to Excel. command += " --json"; - var processService = context.GetService(); - var result = await processService.ExecuteAsync(azqrPath, command, + var result = await _processService.ExecuteAsync(azqrPath, command, operationTimeoutSeconds: _processTimeoutSeconds, cancellationToken: cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Extension/src/Commands/CliGenerateCommand.cs b/tools/Azure.Mcp.Tools.Extension/src/Commands/CliGenerateCommand.cs index 221387d5e7..fc60273184 100644 --- a/tools/Azure.Mcp.Tools.Extension/src/Commands/CliGenerateCommand.cs +++ b/tools/Azure.Mcp.Tools.Extension/src/Commands/CliGenerateCommand.cs @@ -12,10 +12,11 @@ namespace Azure.Mcp.Tools.Extension.Commands; -public sealed class CliGenerateCommand(ILogger logger) : GlobalCommand +public sealed class CliGenerateCommand(ILogger logger, ICliGenerateService cliGenerateService) : GlobalCommand { private const string CommandTitle = "Generate CLI Command"; private readonly ILogger _logger = logger; + private readonly ICliGenerateService _cliGenerateService = cliGenerateService; private readonly string[] _allowedCliTypeValues = ["az"]; public override string Id => "3de4ef37-90bf-41f1-8385-5e870c3ae911"; @@ -84,14 +85,13 @@ public override async Task ExecuteAsync(CommandContext context, { throw new ArgumentException($"Invalid CLI type: {options.CliType}. Supported values are: {string.Join(", ", _allowedCliTypeValues)}"); } - ICliGenerateService cliGenerateService = context.GetService(); // Only log the cli type when we know for sure it doesn't have private data. context.Activity?.AddTag("cliType", cliType); if (cliType == Constants.AzureCliType) { - using HttpResponseMessage responseMessage = await cliGenerateService.GenerateAzureCLICommandAsync(intent, cancellationToken); + using HttpResponseMessage responseMessage = await _cliGenerateService.GenerateAzureCLICommandAsync(intent, cancellationToken); responseMessage.EnsureSuccessStatusCode(); var responseBody = await responseMessage.Content.ReadAsStringAsync(cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Extension/src/Commands/CliInstallCommand.cs b/tools/Azure.Mcp.Tools.Extension/src/Commands/CliInstallCommand.cs index f05f8fe24f..e93f1bb62a 100644 --- a/tools/Azure.Mcp.Tools.Extension/src/Commands/CliInstallCommand.cs +++ b/tools/Azure.Mcp.Tools.Extension/src/Commands/CliInstallCommand.cs @@ -12,10 +12,11 @@ namespace Azure.Mcp.Tools.Extension.Commands; -public sealed class CliInstallCommand(ILogger logger) : GlobalCommand +public sealed class CliInstallCommand(ILogger logger, ICliInstallService cliInstallService) : GlobalCommand { private const string CommandTitle = "Get CLI installation instructions"; private readonly ILogger _logger = logger; + private readonly ICliInstallService _cliInstallService = cliInstallService; private readonly string[] _allowedCliTypeValues = ["az", "azd", "func"]; public override string Id => "464626d0-b9be-4a3b-9f29-858637ab8c10"; @@ -79,12 +80,11 @@ public override async Task ExecuteAsync(CommandContext context, { throw new ArgumentException($"Invalid CLI type: {options.CliType}. Supported values are: {string.Join(", ", _allowedCliTypeValues)}"); } - ICliInstallService cliInstallService = context.GetService(); // Only log the cli type when we know for sure it doesn't have private data. context.Activity?.AddTag("cliType", cliType); - using HttpResponseMessage responseMessage = await cliInstallService.GetCliInstallInstructions(cliType, cancellationToken); + using HttpResponseMessage responseMessage = await _cliInstallService.GetCliInstallInstructions(cliType, cancellationToken); responseMessage.EnsureSuccessStatusCode(); var responseBody = await responseMessage.Content.ReadAsStringAsync(cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/AzqrCommandTests.cs b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/AzqrCommandTests.cs index 8851ff4c6b..3d9950bed9 100644 --- a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/AzqrCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/AzqrCommandTests.cs @@ -45,7 +45,7 @@ public async Task ExecuteAsync_ReturnsSuccessResult_WhenScanSucceeds() var fixedDateTime = new DateTime(2024, 1, 15, 10, 30, 45, DateTimeKind.Utc); _dateTimeProvider.UtcNow.Returns(fixedDateTime); - var command = new AzqrCommand(_logger); + var command = new AzqrCommand(_logger, _subscriptionService, _dateTimeProvider, _processService); var mockSubscriptionId = "12345678-1234-1234-1234-123456789012"; var args = command.GetCommand().Parse($"--subscription {mockSubscriptionId}"); @@ -116,7 +116,7 @@ await _processService.Received().ExecuteAsync( public async Task ExecuteAsync_ReturnsBadRequest_WhenMissingSubscriptionArgument() { // Arrange - var command = new AzqrCommand(_logger); + var command = new AzqrCommand(_logger, _subscriptionService, _dateTimeProvider, _processService); var args = command.GetCommand().Parse(""); // No subscription specified var context = new CommandContext(_serviceProvider); diff --git a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliGenerateCommandTests.cs b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliGenerateCommandTests.cs index 1261de8add..58896e5ec0 100644 --- a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliGenerateCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliGenerateCommandTests.cs @@ -35,7 +35,7 @@ public CliGenerateCommandTests() collection.AddSingleton(_httpClientFactory); collection.AddSingleton(_cliGenerateService); _serviceProvider = collection.BuildServiceProvider(); - _command = new(_logger); + _command = new(_logger, _cliGenerateService); _context = new(_serviceProvider); _commandDefinition = _command.GetCommand(); } diff --git a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliInstallCommandTests.cs b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliInstallCommandTests.cs index 1cdead31b0..95d837afd1 100644 --- a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliInstallCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/CliInstallCommandTests.cs @@ -35,7 +35,7 @@ public CliInstallCommandTests() collection.AddSingleton(_httpClientFactory); collection.AddSingleton(_cliInstallService); _serviceProvider = collection.BuildServiceProvider(); - _command = new(_logger); + _command = new(_logger, _cliInstallService); _context = new(_serviceProvider); _commandDefinition = _command.GetCommand(); } diff --git a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/ExtensionSetupTests.cs b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/ExtensionSetupTests.cs index b8747ef062..10a606db25 100644 --- a/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/ExtensionSetupTests.cs +++ b/tools/Azure.Mcp.Tools.Extension/tests/Azure.Mcp.Tools.Extension.UnitTests/ExtensionSetupTests.cs @@ -1,9 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using Azure.Mcp.Core.Services.Azure.Authentication; +using Azure.Mcp.Core.Services.Azure.Subscription; +using Azure.Mcp.Core.Services.ProcessExecution; +using Azure.Mcp.Core.Services.Time; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Mcp.Core.Areas.Server.Options; +using NSubstitute; using Xunit; namespace Azure.Mcp.Tools.Extension.UnitTests; @@ -18,6 +23,12 @@ private static IServiceProvider BuildServiceProvider(ServiceStartOptions? startO var setup = new ExtensionSetup(); setup.ConfigureServices(services); + services.AddSingleton(Substitute.For()); + services.AddSingleton(Substitute.For()); + services.AddSingleton(Substitute.For()); + services.AddSingleton(Substitute.For()); + services.AddSingleton(Substitute.For()); + if (startOptions is not null) { services.AddSingleton(startOptions);