diff --git a/alembic/versions/a1b2c3d4e5f6_gemmgemm_config_transC_transO_gemmO.py b/alembic/versions/a1b2c3d4e5f6_gemmgemm_config_transC_transO_gemmO.py new file mode 100644 index 000000000..cb3a98731 --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_gemmgemm_config_transC_transO_gemmO.py @@ -0,0 +1,35 @@ +"""gemmgemm_config transC transO gemmO + +Revision ID: a1b2c3d4e5f6 +Revises: 4ce656722c5d +Create Date: 2026-03-05 + +Add transpose_C, transpose_O, gemm_o to rocmlir_gemmgemm_config for compatibility +with rocMLIR tier1-gemmgemm-configs format (-transC, -transO, -gemmO). +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'a1b2c3d4e5f6' +down_revision = '4ce656722c5d' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + 'rocmlir_gemmgemm_config', + sa.Column('transpose_C', sa.Boolean(), nullable=False, server_default='0')) + op.add_column( + 'rocmlir_gemmgemm_config', + sa.Column('transpose_O', sa.Boolean(), nullable=False, server_default='0')) + op.add_column( + 'rocmlir_gemmgemm_config', + sa.Column('gemm_o', sa.Integer(), nullable=False, server_default='0')) + + +def downgrade() -> None: + op.drop_column('rocmlir_gemmgemm_config', 'gemm_o') + op.drop_column('rocmlir_gemmgemm_config', 'transpose_O') + op.drop_column('rocmlir_gemmgemm_config', 'transpose_C') diff --git a/tuna/rocmlir/config_type.py b/tuna/rocmlir/config_type.py index f55fe4ac4..e4c965d21 100644 --- a/tuna/rocmlir/config_type.py +++ b/tuna/rocmlir/config_type.py @@ -33,6 +33,7 @@ class ConfigType(Enum): convolution: str = 'convolution' gemm: str = 'gemm' attention: str = 'attention' + gemm_gemm: str = 'gemm_gemm' def __str__(self) -> str: return self.value diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py index e3d62877d..bd8197dbc 100644 --- a/tuna/rocmlir/rocmlir_tables.py +++ b/tuna/rocmlir/rocmlir_tables.py @@ -786,6 +786,159 @@ class AttentionResults(BASE, ResultsMixin): # pylint: disable=too-many-instance index=True) +class GemmGemmJob(BASE, JobMixin): + """Represents gemm_gemm job table""" + __tablename__ = "rocmlir_gemmgemm_job" + __table_args__ = (UniqueConstraint('config', 'session', name="uq_idx"),) + + config = Column(Integer, + ForeignKey("rocmlir_gemmgemm_config.id"), + nullable=False, + index=True) + + +class GemmGemmConfig(BASE, SimpleCSVMixin): + """Represents GemmGemm config table""" + __tablename__ = "rocmlir_gemmgemm_config" + + data_type = Column(String(length=60), nullable=False, server_default="") + out_data_type = Column(String(length=60), nullable=False, server_default="") + group_size = Column(Integer, nullable=False, server_default="0") + m = Column(Integer, nullable=False, server_default="0") + n = Column(Integer, nullable=False, server_default="0") + k = Column(Integer, nullable=False, server_default="0") + transpose_A = Column(Boolean, nullable=False, server_default="0") + transpose_B = Column(Boolean, nullable=False, server_default="0") + transpose_C = Column(Boolean, nullable=False, server_default="0") + transpose_O = Column(Boolean, nullable=False, server_default="0") + gemm_o = Column(Integer, nullable=False, server_default="0") + kernel_repeats = Column(Integer, nullable=False, server_default="0") + + def __repr__(self) -> str: + return f"GemmGemmConfig {self.to_dict()}" + + options = { + 'data_type': '-t', + 'out_data_type': '-out_datatype', + 'transpose_A': '-transA', + 'transpose_B': '-transB', + 'transpose_C': '-transC', + 'transpose_O': '-transO', + 'group_size': '-g', + 'm': '-m', + 'n': '-n', + 'k': '-k', + 'gemm_o': '-gemmO', + 'kernel_repeats': None, + 'id': None, + 'valid': None + } + + def config_string(self): + """Return config as a flag/value string suitable for tuningRunner.py.""" + string = "" + for field, flag in self.options.items(): + value = getattr(self, field, None) + if value is not None and flag is not None: + string += f"{flag} {value} " + return string.strip() + + def parse_line(self, line): + """Parse a command-line-style gemm_gemm config into a GemmGemmConfig object.""" + + print(f"Parsing line {line}") + + i = iter(line.split()) + options = dict(zip(i, i)) + + fields = { + '-transA': 'transpose_A', + '-transB': 'transpose_B', + '-transC': 'transpose_C', + '-transO': 'transpose_O', + '-g': 'group_size', + '-m': 'm', + '-n': 'n', + '-k': 'k', + '-t': 'data_type', + '-out_datatype': 'out_data_type', + '-gemmO': 'gemm_o', + } + + self.kernel_repeats = 1 + for flag, value in options.items(): + if flag not in fields: + continue + if value in ["true", "True"]: + value = 1 + if value in ["false", "False"]: + value = 0 + field = fields[flag] + if field: + setattr(self, field, value) + + def get_configurations(self, filename): + """Read gemm_gemm-configs from filename and expand into all combinations of + type and transpose. + """ + + DATA_TYPES = ['f32', 'f16', 'i8'] + + configs = [] + with open(filename, 'r', encoding='utf8') as config_file: + lines = config_file.readlines() + + for datatype, transA, transB, line in \ + itertools.product(DATA_TYPES, ['false', 'true'], + ['false', 'true'], lines): + line = line.strip() + + if len(line) == 0 or line[0] == '#': + continue + + dataTypeString = "" + if "-t " not in line: + dataTypeString = f"-t {datatype} " + + transAString = "" + if "-transA " not in line: + transAString = f"-transA {transA} " + + transBString = "" + if "-transB " not in line: + transBString = f"-transB {transB} " + + outDataTypeString = "" + if "-out_datatype" not in line: + outDataTypeString = f"-out_datatype {datatype} " + + one_config = f"{dataTypeString}{outDataTypeString}\ + {transAString}{transBString}{line}".strip() + if one_config not in configs: + configs.append(one_config) + + if "-out_datatype" not in line and datatype == 'i8': + outDataTypeString = "-out_datatype i32 " + one_config = f"{dataTypeString}{outDataTypeString}\ + {transAString}{transBString}{line}".strip() + if one_config not in configs: + configs.append(one_config) + + return configs + + +class GemmGemmResults(BASE, ResultsMixin): # pylint: disable=too-many-instance-attributes + """Collects the results of GemmGemm tuning.""" + + __tablename__ = "rocmlir_gemmgemm_results" + __table_args__ = (UniqueConstraint("config_str", "session", name="uq_idx"),) + + config = Column(Integer, + ForeignKey("rocmlir_gemmgemm_config.id"), + nullable=False, + index=True) + + #pylint: disable=too-few-public-methods class RocMLIRDBTables(DBTablesInterface): """Represents db tables for rocMLIR lib""" @@ -819,6 +972,10 @@ def set_tables(self, sess_class=None): self.job_table = AttentionJob self.config_table = AttentionConfig self.results = AttentionResults + elif self.config_type == ConfigType.gemm_gemm: + self.job_table = GemmGemmJob + self.config_table = GemmGemmConfig + self.results = GemmGemmResults else: raise ValueError(f"Config type {self.config_type} not yet supported.") @@ -846,6 +1003,9 @@ def append_if_not_exists(table): append_if_not_exists(AttentionConfig()) append_if_not_exists(AttentionJob()) append_if_not_exists(AttentionResults()) + append_if_not_exists(GemmGemmConfig()) + append_if_not_exists(GemmGemmJob()) + append_if_not_exists(GemmGemmResults()) return tables diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py index 72733c666..a94643702 100644 --- a/tuna/rocmlir/rocmlir_worker.py +++ b/tuna/rocmlir/rocmlir_worker.py @@ -180,6 +180,8 @@ def run_cmd(self): special_args = "--operation gemm" elif self.dbt.config_type == ConfigType.attention: special_args = "--operation attention --verify-mode none" + elif self.dbt.config_type == ConfigType.gemm_gemm: + special_args = "--operation gemm_gemm" else: raise ValueError(f"Config type {self.dbt.config_type} not yet supported.") if self.dbt.session.tuning_space: