Skip to content

Commit 575ef3d

Browse files
authored
Local Model form dropdown selection (#156)
1 parent 570247a commit 575ef3d

File tree

15 files changed

+127
-25
lines changed

15 files changed

+127
-25
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package com.shifthackz.aisdv1.core.common.model
2+
3+
import java.io.Serializable
4+
5+
data class Quintuple<out A, out B, out C, out D, out E>(
6+
val first: A,
7+
val second: B,
8+
val third: C,
9+
val fourth: D,
10+
val fifth: E,
11+
) : Serializable {
12+
13+
override fun toString(): String = "($first, $second, $third, $fourth, $fifth)"
14+
}

data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager
1111
import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao
1212
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
1313
import io.reactivex.rxjava3.core.Completable
14+
import io.reactivex.rxjava3.core.Flowable
1415
import io.reactivex.rxjava3.core.Observable
1516
import io.reactivex.rxjava3.core.Single
1617
import java.io.File
@@ -21,6 +22,7 @@ internal class DownloadableModelLocalDataSource(
2122
private val preferenceManager: PreferenceManager,
2223
private val buildInfoProvider: BuildInfoProvider,
2324
) : DownloadableModelDataSource.Local {
25+
2426
override fun getAll(): Single<List<LocalAiModel>> = dao.query()
2527
.map(List<LocalModelEntity>::mapEntityToDomain)
2628
.map { models ->
@@ -45,6 +47,17 @@ internal class DownloadableModelLocalDataSource(
4547
.flatMap(::getById)
4648
.onErrorResumeNext { Single.error(Throwable("No selected model")) }
4749

50+
override fun observeAll(): Flowable<List<LocalAiModel>> = dao
51+
.observe()
52+
.map(List<LocalModelEntity>::mapEntityToDomain)
53+
.map { models ->
54+
buildList {
55+
addAll(models)
56+
if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM)
57+
}
58+
}
59+
.flatMap { models -> models.withLocalData().toFlowable() }
60+
4861
override fun select(id: String): Completable = Completable.fromAction {
4962
preferenceManager.localModelId = id
5063
}

data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,8 @@ internal class DownloadableModelRepositoryImpl(
2525
.onErrorResumeNext { localDataSource.getAll() }
2626

2727
override fun getById(id: String) = localDataSource.getById(id)
28+
29+
override fun observeAll() = localDataSource.observeAll()
30+
2831
override fun select(id: String) = localDataSource.select(id)
2932
}

domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.datasource
33
import com.shifthackz.aisdv1.domain.entity.DownloadState
44
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
55
import io.reactivex.rxjava3.core.Completable
6+
import io.reactivex.rxjava3.core.Flowable
67
import io.reactivex.rxjava3.core.Observable
78
import io.reactivex.rxjava3.core.Single
89

@@ -17,6 +18,7 @@ sealed interface DownloadableModelDataSource {
1718
fun getAll(): Single<List<LocalAiModel>>
1819
fun getById(id: String): Single<LocalAiModel>
1920
fun getSelected(): Single<LocalAiModel>
21+
fun observeAll(): Flowable<List<LocalAiModel>>
2022
fun select(id: String): Completable
2123
fun save(list: List<LocalAiModel>): Completable
2224
fun isDownloaded(id: String): Single<Boolean>

domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
3434
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
3535
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
3636
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl
37+
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
38+
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl
3739
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase
3840
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCaseImpl
3941
import com.shifthackz.aisdv1.domain.usecase.gallery.GetAllGalleryUseCase
@@ -136,6 +138,7 @@ internal val useCasesModule = module {
136138
factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class
137139
factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class
138140
factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class
141+
factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class
139142
factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class
140143
factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class
141144
factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class

domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.domain.repository
33
import com.shifthackz.aisdv1.domain.entity.DownloadState
44
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
55
import io.reactivex.rxjava3.core.Completable
6+
import io.reactivex.rxjava3.core.Flowable
67
import io.reactivex.rxjava3.core.Observable
78
import io.reactivex.rxjava3.core.Single
89

@@ -12,5 +13,6 @@ interface DownloadableModelRepository {
1213
fun delete(id: String): Completable
1314
fun getAll(): Single<List<LocalAiModel>>
1415
fun getById(id: String): Single<LocalAiModel>
16+
fun observeAll(): Flowable<List<LocalAiModel>>
1517
fun select(id: String): Completable
1618
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.shifthackz.aisdv1.domain.usecase.downloadable
2+
3+
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
4+
import io.reactivex.rxjava3.core.Flowable
5+
6+
interface ObserveLocalAiModelsUseCase {
7+
operator fun invoke(): Flowable<List<LocalAiModel>>
8+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.shifthackz.aisdv1.domain.usecase.downloadable
2+
3+
import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
4+
5+
internal class ObserveLocalAiModelsUseCaseImpl(
6+
private val repository: DownloadableModelRepository,
7+
) : ObserveLocalAiModelsUseCase {
8+
9+
override fun invoke() = repository.observeAll()
10+
}
Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package com.shifthackz.aisdv1.presentation.screen.setup.steps
22

3-
import androidx.compose.foundation.layout.Column
43
import androidx.compose.foundation.layout.Spacer
54
import androidx.compose.foundation.layout.fillMaxWidth
65
import androidx.compose.foundation.layout.height
76
import androidx.compose.foundation.layout.padding
7+
import androidx.compose.foundation.lazy.LazyColumn
8+
import androidx.compose.foundation.lazy.items
9+
import androidx.compose.foundation.lazy.rememberLazyListState
810
import androidx.compose.runtime.Composable
11+
import androidx.compose.runtime.LaunchedEffect
912
import androidx.compose.ui.Modifier
1013
import androidx.compose.ui.unit.dp
14+
import com.shifthackz.aisdv1.domain.entity.ServerSource
1115
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
1216
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
1317
import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton
@@ -18,21 +22,31 @@ fun SourceSelectionStep(
1822
state: ServerSetupState,
1923
processIntent: (ServerSetupIntent) -> Unit = {},
2024
) {
21-
BaseServerSetupStateWrapper(modifier) {
22-
Column {
23-
Spacer(modifier = Modifier.height(12.dp))
24-
state.allowedModes.forEach { mode ->
25-
ConfigurationModeButton(
26-
modifier = Modifier
27-
.fillMaxWidth()
28-
.padding(horizontal = 16.dp, vertical = 4.dp),
29-
state = state,
30-
mode = mode,
31-
onClick = {
32-
processIntent(ServerSetupIntent.UpdateServerMode(it))
33-
},
34-
)
35-
}
25+
val lazyListState = rememberLazyListState()
26+
LaunchedEffect(state.mode) {
27+
// Adding 1 here, because item with index == 0 is top spacer
28+
lazyListState.animateScrollToItem(state.mode.ordinal + 1)
29+
}
30+
LazyColumn(
31+
modifier = modifier,
32+
state = lazyListState,
33+
) {
34+
item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) }
35+
items(
36+
items = state.allowedModes,
37+
key = ServerSource::key,
38+
) { mode ->
39+
ConfigurationModeButton(
40+
modifier = Modifier
41+
.fillMaxWidth()
42+
.padding(horizontal = 16.dp, vertical = 4.dp),
43+
state = state,
44+
mode = mode,
45+
onClick = {
46+
processIntent(ServerSetupIntent.UpdateServerMode(it))
47+
},
48+
)
3649
}
50+
item(key = "SPACER_BOTTOM") { Spacer(modifier = Modifier.height(32.dp)) }
3751
}
3852
}

presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,18 @@ fun EngineSelectionComponent(
4545
onItemSelected = { intentHandler(EngineSelectionIntent(it)) },
4646
)
4747

48-
else -> Unit
48+
ServerSource.LOCAL -> DropdownTextField(
49+
label = R.string.hint_sd_model.asUiText(),
50+
loading = state.loading,
51+
modifier = modifier,
52+
value = state.localAiModels.firstOrNull { it.id == state.selectedLocalAiModelId },
53+
items = state.localAiModels,
54+
onItemSelected = { intentHandler(EngineSelectionIntent(it.id)) },
55+
displayDelegate = { it.name.asUiText() },
56+
)
57+
58+
ServerSource.HORDE -> Unit
59+
ServerSource.OPEN_AI -> Unit
4960
}
5061
}
5162
}

0 commit comments

Comments
 (0)