Propagate Composable requirement to implementors of lambda interfaces.

Propagates requirement of composable invoke operator override to classes and interfaces that are extending lambda interfaces with corresponding annotations.

Test: ComposableDeclarationCheckerTests
Change-Id: Id2757562f10a2f0da198c7cce602d8ea8726b809
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableDeclarationCheckerTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableDeclarationCheckerTests.kt
index 3239060..07aa897 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableDeclarationCheckerTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/analysis/ComposableDeclarationCheckerTests.kt
@@ -286,4 +286,74 @@
             """
         )
     }
+
+    @Test
+    fun testOverrideComposableLambda() {
+        check(
+            """
+                import androidx.compose.runtime.Composable
+
+                class Impl : @Composable () -> Unit {
+                    @Composable
+                    override fun invoke() {}
+                }
+            """
+        )
+    }
+
+    @Test
+    fun testTransitiveOverrideComposableLambda() {
+        check(
+            """
+                import androidx.compose.runtime.Composable
+
+                interface ComposableFunction : @Composable () -> Unit
+
+                class Impl : ComposableFunction {
+                    @Composable
+                    override fun invoke() {}
+                }
+            """
+        )
+    }
+
+    @Test
+    fun testMissingOverrideComposableLambda() {
+        check(
+            """
+                import androidx.compose.runtime.Composable
+
+                class Impl : @Composable () -> Unit {
+                    override fun invoke() {}
+                }
+            """
+        )
+    }
+
+    @Test
+    fun testWrongOverrideLambda() {
+        check(
+            """
+                import androidx.compose.runtime.Composable
+
+                class Impl : () -> Unit {
+                    @Composable override fun invoke() {}
+                }
+            """
+        )
+    }
+
+    @Test
+    fun testMultipleOverrideLambda() {
+        check(
+            """
+                import androidx.compose.runtime.Composable
+
+                class Impl : () -> Unit, @Composable (Int) -> Unit {
+                    @Composable override fun invoke() {}
+                    @Composable override fun invoke(p0: Int) {}
+                }
+            """
+        )
+    }
 }
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableDeclarationChecker.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableDeclarationChecker.kt
index c0aad2e..2cc17dc 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableDeclarationChecker.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposableDeclarationChecker.kt
@@ -22,9 +22,11 @@
 import androidx.compose.compiler.plugins.kotlin.ComposeErrors.COMPOSABLE_SUSPEND_FUN
 import androidx.compose.compiler.plugins.kotlin.ComposeErrors.COMPOSABLE_VAR
 import com.intellij.psi.PsiElement
+import org.jetbrains.kotlin.builtins.isFunctionType
 import org.jetbrains.kotlin.builtins.isSuspendFunctionType
 import org.jetbrains.kotlin.container.StorageComponentContainer
 import org.jetbrains.kotlin.container.useInstance
+import org.jetbrains.kotlin.descriptors.ClassDescriptor
 import org.jetbrains.kotlin.descriptors.DeclarationDescriptor
 import org.jetbrains.kotlin.descriptors.FunctionDescriptor
 import org.jetbrains.kotlin.descriptors.Modality
@@ -41,6 +43,7 @@
 import org.jetbrains.kotlin.resolve.checkers.DeclarationChecker
 import org.jetbrains.kotlin.resolve.checkers.DeclarationCheckerContext
 import org.jetbrains.kotlin.types.KotlinType
+import org.jetbrains.kotlin.types.typeUtil.supertypes
 import org.jetbrains.kotlin.util.OperatorNameConventions
 
 class ComposableDeclarationChecker : DeclarationChecker, StorageComponentContainerContributor {
@@ -79,7 +82,22 @@
         val hasComposableAnnotation = descriptor.hasComposableAnnotation()
         if (descriptor.overriddenDescriptors.isNotEmpty()) {
             val override = descriptor.overriddenDescriptors.first()
-            if (override.hasComposableAnnotation() != hasComposableAnnotation) {
+            val overrideFunctionIsComposable =
+                if (descriptor.isOperator && descriptor.name == OperatorNameConventions.INVOKE) {
+                    override.hasComposableAnnotation() || descriptor.let {
+                        val cls = descriptor.containingDeclaration as? ClassDescriptor
+                        cls?.run {
+                            defaultType.supertypes().any {
+                                it.isFunctionType &&
+                                    it.arguments.size == descriptor.arity + 1 &&
+                                    it.hasComposableAnnotation()
+                            }
+                        } ?: false
+                    }
+                } else {
+                    override.hasComposableAnnotation()
+                }
+            if (overrideFunctionIsComposable != hasComposableAnnotation) {
                 context.trace.report(
                     ComposeErrors.CONFLICTING_OVERLOADS.on(
                         declaration,
@@ -234,4 +252,9 @@
             context.trace.report(COMPOSABLE_VAR.on(name))
         }
     }
+
+    private val FunctionDescriptor.arity get(): Int =
+        if (extensionReceiverParameter != null) 1 else 0 +
+            contextReceiverParameters.size +
+            valueParameters.size
 }