Skip to content
This repository was archived by the owner on Jun 21, 2020. It is now read-only.

DI as a Kotlin compiler plugin #2

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
95 changes: 95 additions & 0 deletions PROPOSAL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
This note explores compile-time "safe" DI container implementation relying on language itself, as opposed to annotation processing.

Topics to explore for MVP:
1. Exposing dependencies from container
2. Wiring(binding) dependency graph
3. Injection (constructor and instance fields)
4. Scoping (local singletons)
5. Subscoping and providing parts of the parent graph to children.

## Exposing dependencies:
The problem of exposing dependencies can be reduced to constructor injection.
Similarly to dagger, which is generating implementation for an interface, we can imagine the container for
exposed types as a class.
The framework task in this scenario is to wire chain of the constructors with correct dependencies.
```kotlin
class FooComponent(
val bar: Bar,
val otherDependency: Other.Dependency
)

// or

interface FooComponent {
// scoped
val bar: Bar
// unscoped
fun otherDependency(): Other.Dependency
}
```
Component definition can be done using Kotlin DSL and compiler transformations. `TODO: provide more details`
```kotlin
// Library
fun <T> component(vararg dependencies: Any?): T = TODO()
fun modules(vararg modules: Module, block: () -> Unit)
interface Module // Marker

// Client
object Module : k.Module {
fun providesInt(): Int = 0L
}

class Module1() : k.Module {
fun bar(int: Int): Bar = Bar(int)
}

fun init() {
val long = 0
modules(Module, Module1()) { // available as 'this'
val fooComponent = component<FooComponent>(
instance<Bar>(long),
instance<Bar1>(::bar1)
)
}

val instance: Bar = fooComponent.bar
}
```
This way, the dependency graph can be defined in multiple ways. The framework task is to ensure that
all the types are known in the compile time and validate the graph.

Food for thought:
- Is such "dynamic" definition providing better user experience rather than "static graph" that Dagger use?

## Wiring

`TODO: module definition`

`TODO: explore macwire`

`TODO: explore boost di`

## Injection

`TODO: more details`

Field injection:
```kotlin
// Library
interface Injectable<T> {
fun <R> inject(): ReadOnlyProperty<Injectable<T>, R> = // delegate
}

// Client
class SomeClass : Injectable<Foo> {
val someField: Foo by inject() // Can we make it compile time safe?
}
```

## Scoping

`TODO: local scoping and their relation`

## Subscoping

`TODO: parent - children relationship`
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ compileTestKotlin {
}

application {
mainClassName = 'MainKt'
mainClassName = 'ClientKt'
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package me.shika.di

import me.shika.di.resolver.COMPONENT_CALLS
import me.shika.di.resolver.GENERATED_CALL_NAME
import me.shika.di.resolver.findTopLevelFunctionForName
import me.shika.di.resolver.varargArguments
import org.jetbrains.kotlin.codegen.JvmKotlinType
import org.jetbrains.kotlin.codegen.StackValue
import org.jetbrains.kotlin.codegen.extensions.ExpressionCodegenExtension
import org.jetbrains.kotlin.resolve.DelegatingBindingTrace
import org.jetbrains.kotlin.resolve.calls.model.ExpressionValueArgument
import org.jetbrains.kotlin.resolve.calls.model.MutableDataFlowInfoForArguments
import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall
import org.jetbrains.kotlin.resolve.calls.model.ResolvedCallImpl
import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowInfo
import org.jetbrains.kotlin.resolve.calls.tasks.TracingStrategy
import org.jetbrains.kotlin.resolve.descriptorUtil.module

class DiCodegenExtension : ExpressionCodegenExtension {
override fun applyFunction(
receiver: StackValue,
resolvedCall: ResolvedCall<*>,
c: ExpressionCodegenExtension.Context
): StackValue? {
val isComponentCall = c.codegen.bindingContext[COMPONENT_CALLS, resolvedCall] == true
if (isComponentCall) {
val module = c.codegen.context.functionDescriptor.module
val function = module.findTopLevelFunctionForName(c.codegen.bindingContext[GENERATED_CALL_NAME, resolvedCall]!!)!!
val clsType = c.typeMapper.mapType(function.returnType!!)

val arguments = resolvedCall.varargArguments()
function.valueParameters.zip(arguments).forEachIndexed { i, pair ->
val (param, argument) = pair
val type = c.codegen.bindingContext.getType(argument.getArgumentExpression()!!)!!
c.codegen.defaultCallGenerator.genValueAndPut(
param,
argument.getArgumentExpression()!!,
JvmKotlinType(c.typeMapper.mapType(type)),
i
)
}

val ctorResolvedCall = ResolvedCallImpl(
resolvedCall.call,
function,
resolvedCall.dispatchReceiver,
resolvedCall.extensionReceiver,
resolvedCall.explicitReceiverKind,
null,
DelegatingBindingTrace(c.codegen.bindingContext, "test"),
TracingStrategy.EMPTY,
MutableDataFlowInfoForArguments.WithoutArgumentsCheck(DataFlowInfo.EMPTY)
)

resolvedCall.valueArguments.entries.first().value.arguments.forEachIndexed { index, valueArgument ->
ctorResolvedCall.recordValueArgument(function.valueParameters[index], ExpressionValueArgument(valueArgument))
}

val method = c.typeMapper.mapToCallableMethod(function, false, resolvedCall = ctorResolvedCall)
return StackValue.functionCall(clsType, function.returnType) {
c.codegen.invokeMethodWithArguments(method, ctorResolvedCall, StackValue.none())
}
} else {
return super.applyFunction(receiver, resolvedCall, c)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package me.shika.test
package me.shika.di

import com.google.auto.service.AutoService
import org.jetbrains.kotlin.compiler.plugin.AbstractCliOption
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package me.shika.di

import me.shika.di.model.Binding
import me.shika.di.render.GraphToFunctionRenderer
import me.shika.di.resolver.COMPONENT_CALLS
import me.shika.di.resolver.ComponentDescriptor
import me.shika.di.resolver.GENERATED_CALL_NAME
import me.shika.di.resolver.ResolverContext
import me.shika.di.resolver.generateCallNames
import me.shika.di.resolver.resolveGraph
import me.shika.di.resolver.resultType
import me.shika.di.resolver.validation.ExtractAnonymousTypes
import me.shika.di.resolver.validation.ExtractFunctions
import me.shika.di.resolver.validation.ParseParameters
import me.shika.di.resolver.validation.ReportBindingDuplicates
import org.jetbrains.kotlin.analyzer.AnalysisResult
import org.jetbrains.kotlin.com.intellij.openapi.project.Project
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.resolve.BindingTrace
import org.jetbrains.kotlin.resolve.jvm.extensions.AnalysisHandlerExtension
import org.jetbrains.kotlin.storage.LockBasedStorageManager
import java.io.File

class DiCompilerAnalysisExtension(
private val sourcesDir: File
) : AnalysisHandlerExtension {
private var generatedFiles = false

override fun analysisCompleted(
project: Project,
module: ModuleDescriptor,
bindingTrace: BindingTrace,
files: Collection<KtFile>
): AnalysisResult? {
val calls = bindingTrace.getKeys(COMPONENT_CALLS)
calls.generateCallNames(bindingTrace)
.forEach { (call, name) ->
bindingTrace.record(GENERATED_CALL_NAME, call, FqName(name))
}

if (generatedFiles) {
return null
}
generatedFiles = true

val processors = listOf(
ParseParameters(),
ExtractFunctions(),
ExtractAnonymousTypes(),
ReportBindingDuplicates()
)

calls.forEach { resolvedCall ->
val context = ResolverContext(bindingTrace, LockBasedStorageManager.NO_LOCKS, resolvedCall)
val resultType = resolvedCall.resultType

val bindings = processors.fold(emptySequence<Binding>(), { bindings, processor ->
with(processor) {
context.process(bindings)
}
})

val descriptor = ComponentDescriptor(resultType!!, bindings.toList())
val graph = context.resolveGraph(descriptor)

val fileSpec = GraphToFunctionRenderer(context).invoke(graph)
fileSpec.writeTo(sourcesDir)
}
return AnalysisResult.RetryWithAdditionalRoots(
bindingContext = bindingTrace.bindingContext,
moduleDescriptor = module,
additionalJavaRoots = emptyList(),
additionalKotlinRoots = listOf(sourcesDir)
) // Repeat with my files pls
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package me.shika.di

import me.shika.di.resolver.COMPONENT_CALLS
import me.shika.di.resolver.COMPONENT_FUN_FQ_NAME
import me.shika.di.resolver.MODULE_FUN_FQ_NAME
import me.shika.di.resolver.MODULE_GET
import me.shika.di.resolver.classDescriptor
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.ValueParameterDescriptor
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.resolve.bindingContextUtil.getReferenceTargets
import org.jetbrains.kotlin.resolve.calls.checkers.CallChecker
import org.jetbrains.kotlin.resolve.calls.checkers.CallCheckerContext
import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.scopes.receivers.ExpressionReceiver

class DiCompilerCallChecker : CallChecker {
override fun check(resolvedCall: ResolvedCall<*>, reportOn: PsiElement, context: CallCheckerContext) {
val bContext = context.trace.bindingContext
if (resolvedCall.candidateDescriptor.fqNameSafe == COMPONENT_FUN_FQ_NAME) {
context.trace.record(COMPONENT_CALLS, resolvedCall)
}

if (resolvedCall.candidateDescriptor.fqNameSafe == MODULE_FUN_FQ_NAME) {

}

if (resolvedCall.candidateDescriptor.fqNameSafe == MODULE_GET) {
val receiver = (resolvedCall.extensionReceiver as ExpressionReceiver).expression
val receiverRef = receiver.getReferenceTargets(bContext)
}

if (resolvedCall.hasModuleArgument()) {
val modules = resolvedCall.valueArguments.filter { it.key.isModule() }
println(modules)
}
}
}

fun ResolvedCall<*>.hasModuleArgument() =
valueArguments.keys.any { it.isModule() }

fun ValueParameterDescriptor.isModule() =
type.classDescriptor()?.fqNameSafe == FqName("lib.Module")
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package me.shika.test
package me.shika.di

import com.google.auto.service.AutoService
import me.shika.test.DiCommandLineProcessor.Companion.KEY_ENABLED
import me.shika.test.DiCommandLineProcessor.Companion.KEY_SOURCES
import me.shika.di.DiCommandLineProcessor.Companion.KEY_ENABLED
import me.shika.di.DiCommandLineProcessor.Companion.KEY_SOURCES
import org.jetbrains.kotlin.cli.common.CLIConfigurationKeys
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.codegen.extensions.ExpressionCodegenExtension
import org.jetbrains.kotlin.com.intellij.mock.MockProject
import org.jetbrains.kotlin.compiler.plugin.ComponentRegistrar
import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.config.JVMConfigurationKeys.IR
import org.jetbrains.kotlin.extensions.CompilerConfigurationExtension
import org.jetbrains.kotlin.extensions.StorageComponentContainerContributor
import org.jetbrains.kotlin.resolve.jvm.extensions.AnalysisHandlerExtension

@AutoService(ComponentRegistrar::class)
Expand All @@ -24,17 +24,27 @@ class DiCompilerComponentRegistrar: ComponentRegistrar {
sourcesDir.deleteRecursively()
sourcesDir.mkdirs()

CompilerConfigurationExtension.registerExtension(
project,
object : CompilerConfigurationExtension {
override fun updateConfiguration(configuration: CompilerConfiguration) {
configuration.put(IR, false)
}
})
// CompilerConfigurationExtension.registerExtension(
// project,
// object : CompilerConfigurationExtension {
// override fun updateConfiguration(configuration: CompilerConfiguration) {
// configuration.put(IR, false)
// }
// })

AnalysisHandlerExtension.registerExtension(
project,
DiCompilerAnalysisExtension(sourcesDir = sourcesDir, reporter = reporter)
DiCompilerAnalysisExtension(sourcesDir = sourcesDir)
)

ExpressionCodegenExtension.registerExtension(
project,
DiCodegenExtension()
)

StorageComponentContainerContributor.registerExtension(
project,
DiCompilerStorageContributor()
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package me.shika.di

import org.jetbrains.kotlin.container.StorageComponentContainer
import org.jetbrains.kotlin.container.useInstance
import org.jetbrains.kotlin.descriptors.ModuleDescriptor
import org.jetbrains.kotlin.extensions.StorageComponentContainerContributor
import org.jetbrains.kotlin.platform.TargetPlatform

class DiCompilerStorageContributor : StorageComponentContainerContributor {
override fun registerModuleComponents(
container: StorageComponentContainer,
platform: TargetPlatform,
moduleDescriptor: ModuleDescriptor
) {
container.useInstance(DiCompilerCallChecker())
}
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package me.shika.test
package me.shika.di

import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package me.shika.di.model

import org.jetbrains.kotlin.types.KotlinType

data class GraphNode(val value: Binding, val dependencies: List<GraphNode>)

sealed class Binding {
abstract val type: KotlinType

data class Instance(override val type: KotlinType): Binding()
data class Function(val from: List<KotlinType>, override val type: KotlinType): Binding()
data class Constructor(val from: List<KotlinType>, override val type: KotlinType): Binding()
}
Loading