Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ Changelog

### Fixes

- [FIR, Anvil Interop] Fix rank-based binding replacements getting dropped for multi-contribution classes in root graphs.
- [FIR, Anvil Interop] Fix rank-based binding replacements getting dropped for multi-contribution classes in root graphs when contributions are from external modules.
- [IR, Anvil Interop] Fix rank-based binding replacements getting dropped for multi-contribution classes in graph extensions when contributions are from external modules.
- [FIR] Named annotation arguments in different order from declared parameters getting silently skipped.

0.9.2
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Similar to https://github.com/ZacSweers/metro/issues/1549 but for 'rank' processing
// The important conditions in this test case:
// - There is a class LowRankImpl with one or more bindings
// - There is another class HighRankImpl with the same binding as LowRankImpl, at least one other
// binding, and one of the bindings replaces LowRankImpl
// - The bindings are contributed to a graph extension

// WITH_ANVIL
// MODULE: lib
import com.squareup.anvil.annotations.ContributesBinding

interface ContributedInterface

interface OtherInterface

interface LoggedInScope

@ContributesBinding(LoggedInScope::class, boundType = ContributedInterface::class)
object LowRankImpl : ContributedInterface, OtherInterface

@ContributesBinding(LoggedInScope::class, boundType = OtherInterface::class)
@ContributesBinding(LoggedInScope::class, boundType = ContributedInterface::class, rank = 100)
object HighRankImpl : ContributedInterface, OtherInterface

// MODULE: main(lib)
@GraphExtension(LoggedInScope::class)
interface LoggedInGraph {
val contributedInterface: ContributedInterface

@GraphExtension.Factory
@ContributesTo(AppScope::class)
interface Factory {
fun createLoggedInGraph(): LoggedInGraph
}
}

@DependencyGraph(AppScope::class) interface AppGraph

fun box(): String {
val graph = createGraph<AppGraph>().createLoggedInGraph()
assertTrue(graph.contributedInterface == HighRankImpl)
return "OK"
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2025 Zac Sweers
// SPDX-License-Identifier: Apache-2.0
package dev.zacsweers.metro.compiler

import org.jetbrains.kotlin.name.ClassId

/**
* Computes the set of class IDs that have been "outranked" by higher-ranked bindings.
*
* This groups bindings by their type key, finds groups with multiple bindings, and for each group
* keeps only the highest-ranked bindings. Returns the class IDs of all outranked bindings.
*
* @param BindingType the binding type
* @param TypeKeyType the type key type (must support equality for grouping)
* @param bindings the list of bindings to process
* @param typeKeySelector extracts the type key from a binding (used for grouping)
* @param rankSelector extracts the rank from a binding (higher rank wins)
* @param classId extracts the class ID from a binding (for the result set)
* @return the set of class IDs that were outranked
*/
internal inline fun <BindingType, TypeKeyType> computeOutrankedBindings(
bindings: List<BindingType>,
typeKeySelector: (BindingType) -> TypeKeyType,
rankSelector: (BindingType) -> Long,
classId: (BindingType) -> ClassId,
): Set<ClassId> {
if (bindings.isEmpty()) return emptySet()

val result = HashSet<ClassId>(bindings.size)

val bindingsByTypeKey = bindings.groupBy(typeKeySelector).filter { (_, group) -> group.size > 1 }

for ((_, bindingGroup) in bindingsByTypeKey) {
val bindingsByRank = bindingGroup.groupBy(rankSelector)

val maxKey =
bindingsByRank.keys.maxOrNull()
// Map was empty, nothing to do here
?: continue

val topBindings = bindingsByRank.getValue(maxKey)

// These are the bindings that were outranked and should not be processed further
for (binding in (bindingGroup - topBindings.toSet())) {
result += classId(binding)
}
}

return result
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,9 @@ internal sealed interface MetroFirTypeResolver {
.type
}
}

companion object {
// For cases where we use this in IR, just use the external resolver
fun forIrUse(): MetroFirTypeResolver = ExternalMetroFirTypeResolver
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package dev.zacsweers.metro.compiler.fir.generators
import dev.zacsweers.metro.compiler.MetroOptions
import dev.zacsweers.metro.compiler.api.fir.MetroContributionExtension
import dev.zacsweers.metro.compiler.compat.CompatContext
import dev.zacsweers.metro.compiler.computeOutrankedBindings
import dev.zacsweers.metro.compiler.expectAsOrNull
import dev.zacsweers.metro.compiler.fir.FirTypeKey
import dev.zacsweers.metro.compiler.fir.MetroFirTypeResolver
Expand All @@ -28,6 +29,7 @@ import dev.zacsweers.metro.compiler.fir.resolvedReplacedClassIds
import dev.zacsweers.metro.compiler.fir.resolvedScopeClassId
import dev.zacsweers.metro.compiler.fir.scopeArgument
import dev.zacsweers.metro.compiler.getAndAdd
import dev.zacsweers.metro.compiler.ir.IrRankedBindingProcessing
import dev.zacsweers.metro.compiler.singleOrError
import dev.zacsweers.metro.compiler.symbols.Symbols
import java.util.Optional
Expand Down Expand Up @@ -506,8 +508,6 @@ internal class ContributedInterfaceSupertypeGenerator(
contributions: Map<ClassId, ConeKotlinType>,
typeResolver: TypeResolveService,
): Set<ClassId> {
val pendingRankReplacements = mutableSetOf<ClassId>()

val rankedBindings =
contributions.values
.filterIsInstance<ConeClassLikeType>()
Expand Down Expand Up @@ -539,7 +539,7 @@ internal class ContributedInterfaceSupertypeGenerator(
.coneType
} ?: contributingType.implicitBoundType(typeResolver)

ContributedBinding(
IrRankedBindingProcessing.ContributedBinding(
contributingType = contributingType,
typeKey =
FirTypeKey(
Expand All @@ -552,25 +552,12 @@ internal class ContributedInterfaceSupertypeGenerator(
}
}

val bindingGroups =
rankedBindings
.groupBy { binding -> binding.typeKey }
.filter { bindingGroup -> bindingGroup.value.size > 1 }

for (bindingGroup in bindingGroups.values) {
val topBindings =
bindingGroup
.groupBy { binding -> binding.rank }
.toSortedMap()
.let { it.getValue(it.lastKey()) }

// These are the bindings that were outranked and should not be processed further
bindingGroup.minus(topBindings).forEach {
pendingRankReplacements += it.contributingType.classId
}
}

return pendingRankReplacements
return computeOutrankedBindings(
rankedBindings,
typeKeySelector = { it.typeKey },
rankSelector = { it.rank },
classId = { it.contributingType.classId },
)
}

@OptIn(ResolveStateAccess::class, SymbolInternals::class)
Expand All @@ -596,10 +583,4 @@ internal class ContributedInterfaceSupertypeGenerator(
"${classId.asSingleFqName()} has a ranked binding with no explicit bound type and $size supertypes ($superTypeFqNames). There must be exactly one supertype or an explicit bound type."
}
}

private data class ContributedBinding(
val contributingType: FirClassLikeSymbol<*>,
val typeKey: FirTypeKey,
val rank: Long,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
package dev.zacsweers.metro.compiler.ir

import dev.zacsweers.metro.compiler.expectAsOrNull
import dev.zacsweers.metro.compiler.fir.annotationsIn
import dev.zacsweers.metro.compiler.fir.coneTypeIfResolved
import dev.zacsweers.metro.compiler.fir.replacesArgument
import dev.zacsweers.metro.compiler.getAndAdd
import dev.zacsweers.metro.compiler.symbols.Symbols
import java.util.SortedMap
import java.util.SortedSet
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
Expand All @@ -18,7 +16,6 @@ import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.classIdOrFail
import org.jetbrains.kotlin.ir.util.getValueArgument
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.name.ClassId

Expand Down Expand Up @@ -173,7 +170,7 @@ internal class IrContributionMerger(
.flatMap { annotation -> annotation.replacedClasses() }
.mapNotNull { replacedClass -> replacedClass.classType.rawType().classId }
},
firBody = { annotations ->
firBody = { _, annotations ->
annotations
.flatMap { it.replacesArgument()?.argumentList?.arguments.orEmpty() }
.mapNotNull { it.expectAsOrNull<FirGetClassCall>()?.coneTypeIfResolved()?.classId }
Expand Down Expand Up @@ -236,76 +233,9 @@ internal class IrContributionMerger(
private fun processRankBasedReplacements(
allScopes: Set<ClassId>,
contributions: Map<ClassId, List<IrType>>,
): Set<ClassId> {
val pendingRankReplacements = mutableSetOf<ClassId>()

val rankedBindings =
contributions.values
.flatten()
.map { it.rawType().parentAsClass }
.distinctBy { it.classIdOrFail }
.flatMap { contributingType ->
contributingType
.annotationsIn(metroSymbols.classIds.contributesBindingAnnotations)
.mapNotNull { annotation ->
val scope = annotation.scopeOrNull() ?: return@mapNotNull null
if (scope !in allScopes) return@mapNotNull null

val explicitBindingMissingMetadata =
annotation.getValueArgument(Symbols.Names.binding)

if (explicitBindingMissingMetadata != null) {
// This is a case where an explicit binding is specified but we receive the argument
// as FirAnnotationImpl without the metadata containing the type arguments so we
// short-circuit since we lack the info to compare it against other bindings.
null
} else {
val (explicitBindingType, ignoreQualifier) = annotation.bindingTypeOrNull()
val boundType =
explicitBindingType
?: contributingType.implicitBoundTypeOrNull()!! // Checked in FIR

ContributedIrBinding(
contributingType = contributingType,
typeKey =
IrTypeKey(
boundType,
if (ignoreQualifier) null else contributingType.qualifierAnnotation(),
),
rank = annotation.rankValue(),
)
}
}
}

val bindingGroups =
rankedBindings
.groupBy { binding -> binding.typeKey }
.filter { bindingGroup -> bindingGroup.value.size > 1 }

for (bindingGroup in bindingGroups.values) {
val topBindings =
bindingGroup
.groupBy { binding -> binding.rank }
.toSortedMap()
.let { it.getValue(it.lastKey()) }

// These are the bindings that were outranked and should not be processed further
bindingGroup.minus(topBindings).forEach {
pendingRankReplacements += it.contributingType.classIdOrFail
}
}

return pendingRankReplacements
}
): Set<ClassId> = IrRankedBindingProcessing.processRankBasedReplacements(allScopes, contributions)
}

private data class ContributedIrBinding(
val contributingType: IrClass,
val typeKey: IrTypeKey,
val rank: Long,
)

internal data class IrContributions(
val primaryScope: ClassId?,
val allScopes: Set<ClassId>,
Expand Down
Loading