Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
@Test
@Throws(IOException::class, URISyntaxException::class)
fun testGenerate() {
val loadResult = llmModule.load()
// Check that the model can be load successfully
assertEquals(OK.toLong(), loadResult.toLong())
llmModule.load()

llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
Expand Down Expand Up @@ -273,11 +271,26 @@ class LlmModuleInstrumentationTest : LlmCallback {
}
}

// --- Lifecycle tests ---

@Test
fun testUseAfterCloseThrows() {
llmModule.close()
assertThrows(IllegalStateException::class.java) {
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
}
}

@Test
fun testCloseIsIdempotent() {
llmModule.close()
llmModule.close()
}

companion object {
private const val TEST_FILE_NAME = "/stories.pte"
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
private const val TEST_PROMPT = "Hello"
private const val OK = 0x00
private const val SEQ_LEN = 32
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
Expand All @@ -40,49 +39,42 @@ class ModuleInstrumentationTest {
inputStream.close()
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class, URISyntaxException::class)
fun testModuleLoadAndForward() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val results = module.forward()
Assert.assertTrue(results[0].isTensor)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testMethodMetadata() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class)
fun testModuleLoadMethodAndForward() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
module.loadMethod(FORWARD_METHOD)

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), OK.toLong())

val results = module.forward()
Assert.assertTrue(results[0].isTensor)
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(IOException::class)
fun testModuleLoadForwardExplicit() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val results = module.execute(FORWARD_METHOD)
Assert.assertTrue(results[0].isTensor)
try {
val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
}

@Test(expected = RuntimeException::class)
Expand All @@ -95,18 +87,26 @@ class ModuleInstrumentationTest {
@Throws(IOException::class)
fun testModuleLoadMethodNonExistantMethod() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val loadMethod = module.loadMethod(NONE_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
try {
val exception =
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
module.loadMethod(NONE_METHOD)
}
Assert.assertEquals(
ExecutorchRuntimeException.INVALID_ARGUMENT,
exception.getErrorCode(),
)
} finally {
module.destroy()
}
}

@Test(expected = RuntimeException::class)
@Throws(IOException::class)
fun testNonPteFile() {
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
module.loadMethod(FORWARD_METHOD)
}

@Test
Expand All @@ -116,27 +116,21 @@ class ModuleInstrumentationTest {

module.destroy()

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong())
Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) }
}

@Test
@Throws(IOException::class)
fun testForwardOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
module.loadMethod(FORWARD_METHOD)

module.destroy()

val results = module.forward()
Assert.assertEquals(0, results.size.toLong())
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
}

@Ignore(
"The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward "
)
@Test
@Throws(InterruptedException::class, IOException::class)
fun testForwardFromMultipleThreads() {
Expand All @@ -150,7 +144,7 @@ class ModuleInstrumentationTest {
try {
latch.countDown()
latch.await(5000, TimeUnit.MILLISECONDS)
val results = module.forward()
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
completed.incrementAndGet()
} catch (_: InterruptedException) {}
Expand All @@ -167,6 +161,139 @@ class ModuleInstrumentationTest {
}

Assert.assertEquals(numThreads.toLong(), completed.get().toLong())
module.destroy()
}

// --- Load mode tests ---

@Test
@Throws(IOException::class)
fun testLoadWithMmapMode() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP)
try {
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
}

@Test
@Throws(IOException::class)
fun testLoadWithFileMode() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE)
try {
val results = module.forward(EValue.from(dummyInput()))
Assert.assertTrue(results[0].isTensor)
} finally {
module.destroy()
}
}

// --- getMethods / getMethodMetadata tests ---

@Test
@Throws(IOException::class)
fun testGetMethods() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val methods = module.getMethods()
Assert.assertNotNull(methods)
Assert.assertTrue(methods.contains(FORWARD_METHOD))
} finally {
module.destroy()
}
}

@Test
@Throws(IOException::class)
fun testGetMethodMetadata() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val metadata = module.getMethodMetadata(FORWARD_METHOD)
Assert.assertNotNull(metadata)
Assert.assertEquals(FORWARD_METHOD, metadata.name)
Assert.assertNotNull(metadata.backends)
} finally {
module.destroy()
}
}

// --- Log buffer tests ---

@Test
@Throws(IOException::class)
fun testReadLogBuffer() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
val logs = module.readLogBuffer()
Assert.assertNotNull(logs)
} finally {
module.destroy()
}
}

@Test
fun testReadLogBufferStatic() {
val logs = Module.readLogBufferStatic()
Assert.assertNotNull(logs)
}

// --- etdump test ---

@Test
@Throws(IOException::class)
fun testEtdump() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
try {
module.etdump()
} finally {
module.destroy()
}
}

// --- Destroyed-state tests for remaining methods ---

@Test
@Throws(IOException::class)
fun testGetMethodsOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
Assert.assertThrows(IllegalStateException::class.java) { module.getMethods() }
}

@Test
@Throws(IOException::class)
fun testGetMethodMetadataOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
Assert.assertThrows(IllegalStateException::class.java) {
module.getMethodMetadata(FORWARD_METHOD)
}
}

@Test
@Throws(IOException::class)
fun testReadLogBufferOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
Assert.assertThrows(IllegalStateException::class.java) { module.readLogBuffer() }
}

@Test
@Throws(IOException::class)
fun testEtdumpOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
Assert.assertThrows(IllegalStateException::class.java) { module.etdump() }
}

@Test
@Throws(IOException::class)
fun testDoubleDestroyIsSafe() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))
module.destroy()
module.destroy()
}

companion object {
Expand All @@ -175,9 +302,8 @@ class ModuleInstrumentationTest {
private const val NON_PTE_FILE_NAME = "/test.txt"
private const val FORWARD_METHOD = "forward"
private const val NONE_METHOD = "none"
private const val OK = 0x00
private const val INVALID_STATE = 0x2
private const val INVALID_ARGUMENT = 0x12
private const val ACCESS_FAILED = 0x22
private val inputShape = longArrayOf(1, 3, 224, 224)

private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,25 @@ public static ExecuTorchRuntime getRuntime() {
/**
* Validates that the given path points to a readable file.
*
* @throws RuntimeException if the file does not exist or is not readable.
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
* readable.
*/
public static void validateFilePath(String path, String description) {
if (path == null) {
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
}
File file = new File(path);
if (!file.canRead() || !file.isFile()) {
throw new RuntimeException("Cannot load " + description + " " + path);
if (!file.exists()) {
throw new IllegalArgumentException(
"Cannot load " + description + ": path does not exist: " + path);
}
if (!file.isFile()) {
throw new IllegalArgumentException(
"Cannot load " + description + ": path is not a file: " + path);
}
if (!file.canRead()) {
throw new IllegalArgumentException(
"Cannot load " + description + ": path is not readable: " + path);
}
}

Expand Down
Loading
Loading