Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,4 @@ include("skainet-test:skainet-test-java")
include("skainet-apps:skainet-grayscale-cli")
include("skainet-apps:skainet-tensor-tools")
include("skainet-io:skainet-io-safetensors")
include("skainet-io:skainet-io-iree-params")
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ public final class sk/ainet/compile/hlo/ConversionContext {
public final fun getModuleDeclarations ()Ljava/lang/String;
public final fun getTypeMapper ()Lsk/ainet/compile/hlo/TypeMapper;
public final fun getValueName (Ljava/lang/String;)Ljava/lang/String;
public final fun getValueType (Ljava/lang/String;)Ljava/lang/String;
public final fun nextTempValue ()Ljava/lang/String;
public final fun registerExternalParameter (Lsk/ainet/compile/hlo/ExternalParameterRef;)V
public final fun setGraph (Lsk/ainet/lang/graph/ComputeGraph;)V
public final fun setValueName (Ljava/lang/String;Ljava/lang/String;)V
public final fun setValueType (Ljava/lang/String;Ljava/lang/String;)V
}

public abstract class sk/ainet/compile/hlo/ConversionResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,17 @@ public class MlirValidator {
inFunction = true
}

// Check for basic SSA value format
if (trimmed.contains(" = ") && !validateSSAValue(trimmed)) {
// Check for basic SSA value format. Module-scope global
// declarations (`util.global private @name = #flow.parameter.named<...> : ...`)
// look syntactically like assignments but bind `@`-prefixed
// symbols, not `%`-prefixed SSA values. Treated as a
// separate syntactic category — the IREE external-weight
// path (issue #523) depends on these lines passing
// validation.
if (trimmed.contains(" = ") &&
!trimmed.startsWith("util.global") &&
!validateSSAValue(trimmed)
) {
errors.add("Line ${lineNum + 1}: Invalid SSA value format")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,18 @@ public class ConstantOperationsConverter : StableHloOperationConverter {
source = BufferHandle.Owned(bytes)
)
)
context.emitModuleDeclaration("util.global private @${key} : $outputType")
// Bind the global to an archive entry via the IREE flow-dialect
// parameter attribute. Without this initializer, iree-compile
// treats the util.global as uninitialized and does not pull
// bytes from the .irpa file at --iree-opt-import-parameters
// time. Scope is carried in the MLIR reference (left of `::`)
// and resolved at iree-compile / iree-run-module time against
// `--parameters=<scope>=<file>.irpa`. The .irpa file itself
// stores a flat key table with no scope column.
context.emitModuleDeclaration(
"util.global private @${key} = " +
"#flow.parameter.named<\"${scope}\"::\"${key}\"> : $outputType"
)

val resultValue = context.nextTempValue()
val operation = "$resultValue = util.global.load @${key} : $outputType"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ class ConstantMaterializationPolicyTest {
ConstantMaterializationPolicy.ExternalAlways()
).convert(buildTensorConstantGraph(), "external_policy")

// The util.global must carry a #flow.parameter.named initializer
// so iree-compile binds the declaration to an archive entry at
// --iree-opt-import-parameters time (see issue #523).
assertTrue(
module.content.contains("util.global private @weights : tensor<2x2xf32>"),
module.content.contains(
"util.global private @weights = " +
"#flow.parameter.named<\"model\"::\"weights\"> : tensor<2x2xf32>"
),
"module decl missing:\n${module.content}"
)
assertTrue(
Expand Down Expand Up @@ -110,7 +116,10 @@ class ConstantMaterializationPolicyTest {
)
// Large is externalized.
assertTrue(
module.content.contains("util.global private @large_w : tensor<4x4xf32>"),
module.content.contains(
"util.global private @large_w = " +
"#flow.parameter.named<\"model\"::\"large_w\"> : tensor<4x4xf32>"
),
"large tensor must externalize under SizeThreshold:\n${module.content}"
)

Expand Down
69 changes: 69 additions & 0 deletions skainet-io/skainet-io-iree-params/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
alias(libs.plugins.kotlinMultiplatform)
alias(libs.plugins.androidMultiplatformLibrary)
alias(libs.plugins.vanniktech.mavenPublish)
id("sk.ainet.dokka")
}

kotlin {
targets.configureEach {
compilations.configureEach {
compileTaskProvider.get().compilerOptions {
freeCompilerArgs.add("-Xexpect-actual-classes")
}
}
}

jvm()
android {
namespace = "sk.ainet.io.irpa"
compileSdk = libs.versions.android.compileSdk.get().toInt()
minSdk = libs.versions.android.minSdk.get().toInt()
compilerOptions {
jvmTarget.set(JvmTarget.JVM_1_8)
}
}

iosArm64()
iosSimulatorArm64()
macosArm64()
linuxX64()
linuxArm64()

js {
browser()
}

@OptIn(ExperimentalWasmDsl::class)
wasmJs {
browser()
}

@OptIn(ExperimentalWasmDsl::class)
wasmWasi {
nodejs()
}

sourceSets {
val commonMain by getting {
dependencies {
implementation(libs.kotlinx.io.core)
implementation(project(":skainet-lang:skainet-lang-core"))
implementation(project(":skainet-compile:skainet-compile-hlo"))
}
}
val commonTest by getting {
dependencies {
implementation(libs.kotlin.test)
}
}
val jvmTest by getting {
dependencies {
implementation(libs.junit)
}
}
}
}
2 changes: 2 additions & 0 deletions skainet-io/skainet-io-iree-params/gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
POM_ARTIFACT_ID=skainet-io-iree-params
POM_NAME=skainet IO IREE Params
Loading
Loading