Skip to content

Commit d13fb81

Browse files
author
Daniil Tcelikin
committed
added loading models from assets
1 parent 1362c15 commit d13fb81

File tree

6 files changed

+104
-24
lines changed

6 files changed

+104
-24
lines changed

android/android.md

Lines changed: 0 additions & 12 deletions
This file was deleted.

android/app/build.gradle.kts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import org.jetbrains.kotlin.de.undercouch.gradle.tasks.download.Download
2+
13
@Suppress("DSL_SCOPE_VIOLATION") // TODO: Remove once KTIJ-19369 is fixed
24
plugins {
35
alias(libs.plugins.androidApplication)
46
alias(libs.plugins.kotlinAndroid)
7+
alias(libs.plugins.download)
58
}
69

710
android {
@@ -29,9 +32,19 @@ android {
2932
sourceCompatibility = JavaVersion.VERSION_1_8
3033
targetCompatibility = JavaVersion.VERSION_1_8
3134
}
35+
36+
sourceSets {
37+
maybeCreate("main").apply {
38+
assets {
39+
srcDirs("src/main/assets")
40+
}
41+
}
42+
}
43+
3244
kotlinOptions {
3345
jvmTarget = "1.8"
3446
}
47+
3548
externalNativeBuild {
3649
cmake {
3750
path = file("src/main/cpp/CMakeLists.txt")
@@ -43,7 +56,32 @@ android {
4356
}
4457
}
4558

59+
60+
tasks {
61+
val downloadTokenizer by creating(Download::class) {
62+
onlyIf { !file("$projectDir/src/main/assets/tokenizer.bin").exists() }
63+
src("https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin")
64+
dest("$projectDir/src/main/assets/tokenizer.bin")
65+
}
66+
val downloadModel by creating(Download::class) {
67+
onlyIf { !file("$projectDir/src/main/assets/stories15M.bin").exists() }
68+
src("https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin")
69+
dest("$projectDir/src/main/assets/stories15M.bin")
70+
}
71+
whenTaskAdded {
72+
if (name in listOf("assembleDebug", "assembleRelease")) {
73+
dependsOn(downloadTokenizer)
74+
dependsOn(downloadModel)
75+
}
76+
}
77+
}
78+
4679
dependencies {
80+
implementation(libs.androidx.activity.ktx)
81+
implementation(libs.androidx.lifecycle.viewmodel.ktx)
82+
implementation(libs.androidx.lifecycle.livedata.ktx)
83+
implementation(libs.androidx.lifecycle.runtime.ktx)
84+
4785
implementation(libs.coroutines.core)
4886
implementation(libs.core.ktx)
4987
implementation(libs.appcompat)

android/app/src/main/java/com/celikin/llama2/MainActivity.kt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package com.celikin.llama2
22

33
import androidx.appcompat.app.AppCompatActivity
44
import android.os.Bundle
5+
import androidx.lifecycle.lifecycleScope
56
import com.celikin.llama2.databinding.ActivityMainBinding
67
import com.celikin.llama2.wrapper.InferenceRunner
78
import com.celikin.llama2.wrapper.InferenceRunnerManager
9+
import kotlinx.coroutines.launch
810

911
class MainActivity : AppCompatActivity() {
1012

@@ -26,15 +28,21 @@ class MainActivity : AppCompatActivity() {
2628
val prompt = binding.promptEdit.text.toString()
2729
inferenceRunnerManager.run(prompt)
2830
}
29-
initInference()
31+
32+
lifecycleScope.launch {
33+
val assetsFolder = copyAssets(arrayOf("stories15M.bin", "tokenizer.bin"))
34+
initInference(assetsFolder)
35+
}
36+
3037
}
3138

3239
private fun updateText(token: String) {
3340
binding.sampleText.text = "${binding.sampleText.text}$token"
3441
}
3542

36-
private fun initInference() {
37-
inferenceRunnerManager = InferenceRunnerManager().apply { init(callback) }
43+
private fun initInference(assetsFolder: String) {
44+
inferenceRunnerManager = InferenceRunnerManager()
45+
.apply { init(callback, assetsFolder) }
3846
}
3947

4048
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.celikin.llama2
2+
3+
import android.content.Context
4+
import android.util.Log
5+
import java.io.File
6+
import java.io.FileOutputStream
7+
8+
9+
private fun userAssetPath(context: Context?): String {
10+
if (context == null)
11+
return ""
12+
val extDir = context.getExternalFilesDir("assets")
13+
?: return context.getDir("assets", 0).absolutePath
14+
return extDir.absolutePath
15+
}
16+
17+
fun Context.copyAssets(listFiles: Array<String>):String {
18+
val extFolder = userAssetPath(this)
19+
try {
20+
assets.list("")
21+
?.filter { listFiles.contains(it) }
22+
?.filter { !File(extFolder, it).exists() }
23+
?.forEach {
24+
val target = File(extFolder, it)
25+
assets.open(it).use { input ->
26+
FileOutputStream(target).use { output ->
27+
input.copyTo(output)
28+
Log.i("Utils", "Copied from apk assets folder to ${target.absolutePath}")
29+
}
30+
}
31+
}
32+
} catch (e: Exception) {
33+
Log.e("Utils", "asset copy failed", e)
34+
}
35+
return extFolder
36+
}

android/app/src/main/java/com/celikin/llama2/wrapper/InferenceRunnerManager.kt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,30 @@ import kotlinx.coroutines.SupervisorJob
77
import kotlinx.coroutines.launch
88

99
class InferenceRunnerManager {
10+
private lateinit var folderPath: String
1011
private val applicationScope = CoroutineScope(Dispatchers.IO + SupervisorJob())
1112

12-
fun init(callback: InferenceRunner.InferenceCallback) {
13+
fun init(callback: InferenceRunner.InferenceCallback, folderPath: String) {
14+
this.folderPath = folderPath
1315
InferenceRunner.setInferenceCallback(callback)
1416
}
1517

1618
fun run(
1719
prompt: String = "",
1820
temperature: Float = 0.9f,
1921
steps: Int = 256,
20-
checkpoint: String = "/data/local/tmp/stories15M.bin",
21-
tokenizer: String = "/data/local/tmp/tokenizer.bin",
22-
ompthreads: Int = 4,
22+
checkpointFileName: String = "stories15M.bin",
23+
tokenizerFileName: String = "tokenizer.bin",
24+
ompThreads: Int = 4,
2325
) {
2426
applicationScope.launch {
2527
InferenceRunner.run(
26-
checkpoint = checkpoint,
27-
tokenizer = tokenizer,
28+
checkpoint = "$folderPath/$checkpointFileName",
29+
tokenizer = "$folderPath/$tokenizerFileName",
2830
temperature = temperature,
2931
steps = steps,
3032
prompt = prompt,
31-
ompthreads = ompthreads
33+
ompthreads = ompThreads
3234
)
3335
}
3436
}

android/gradle/libs.versions.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
[versions]
2+
activity-ktx = "1.7.2"
23
agp = "8.2.0-alpha12"
34
kotlin = "1.8.21"
4-
core-ktx = "1.9.0"
5+
core-ktx = "1.10.1"
6+
download = "5.3.1"
57
junit = "4.13.2"
68
androidx-test-ext-junit = "1.1.5"
79
espresso-core = "3.5.1"
810
appcompat = "1.6.1"
11+
lifecycle-viewmodel-ktx = "2.6.1"
912
material = "1.9.0"
1013
constraintlayout = "2.1.4"
11-
kotlinx-coroutines-core = "1.6.4"
14+
kotlinx-coroutines-core = "1.7.1"
1215
[libraries]
16+
androidx-activity-ktx = { module = "androidx.activity:activity-ktx", version.ref = "activity-ktx" }
17+
androidx-lifecycle-livedata-ktx = { module = "androidx.lifecycle:lifecycle-livedata-ktx", version.ref = "lifecycle-viewmodel-ktx" }
18+
androidx-lifecycle-runtime-ktx = { module = "androidx.lifecycle:lifecycle-runtime-ktx", version.ref = "lifecycle-viewmodel-ktx" }
19+
androidx-lifecycle-viewmodel-ktx = { module = "androidx.lifecycle:lifecycle-viewmodel-ktx", version.ref = "lifecycle-viewmodel-ktx" }
1320
core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
1421
junit = { group = "junit", name = "junit", version.ref = "junit" }
1522
androidx-test-ext-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-test-ext-junit" }
@@ -22,4 +29,5 @@ coroutines-core = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-
2229
[plugins]
2330
androidApplication = { id = "com.android.application", version.ref = "agp" }
2431
kotlinAndroid = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
32+
download = { id = "de.undercouch.download", version.ref = "download" }
2533

0 commit comments

Comments
 (0)