Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/dit-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,12 @@ static struct ggml_tensor * dit_ggml_build_self_attn(
// K/V come in F32 from mul_mat (no KV cache here). Cast to F16 before FA,
// mirroring llama.cpp build_attn_mha for graphs without a KV cache.
if (m->use_flash_attn) {
if (k->type == GGML_TYPE_F32) k = ggml_cast(ctx, k, GGML_TYPE_F16);
if (v->type == GGML_TYPE_F32) v = ggml_cast(ctx, v, GGML_TYPE_F16);
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx, k, GGML_TYPE_F16);
}
if (v->type == GGML_TYPE_F32) {
v = ggml_cast(ctx, v, GGML_TYPE_F16);
}
}

struct ggml_tensor * attn = m->use_flash_attn ? ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f) :
Expand Down Expand Up @@ -333,8 +337,12 @@ static struct ggml_tensor * dit_ggml_build_cross_attn(struct ggml_context * ctx,
// K/V come in F32 from mul_mat (no KV cache here). Cast to F16 before FA,
// mirroring llama.cpp build_attn_mha for graphs without a KV cache.
if (m->use_flash_attn) {
if (k->type == GGML_TYPE_F32) k = ggml_cast(ctx, k, GGML_TYPE_F16);
if (v->type == GGML_TYPE_F32) v = ggml_cast(ctx, v, GGML_TYPE_F16);
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx, k, GGML_TYPE_F16);
}
if (v->type == GGML_TYPE_F32) {
v = ggml_cast(ctx, v, GGML_TYPE_F16);
}
}

struct ggml_tensor * attn = m->use_flash_attn ? ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f) :
Expand Down
49 changes: 37 additions & 12 deletions src/pipeline-synth-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "task-types.h"
#include "vae-enc.h"

#include <cerrno>
#include <charconv>
#include <cstdio>
#include <cstdlib>
Expand All @@ -22,22 +23,23 @@

static const int FRAMES_PER_SECOND = 25;

// CSV list parser tolerant to any whitespace around commas. Locale-immune via
// std::from_chars (C++17 charconv, overloaded on the numeric type). Used for
// audio_codes (int) and custom_timesteps (float). Bails on first parse error
// or overflow, returning the values consumed so far.
template <typename T> static std::vector<T> parse_csv(const std::string & s) {
std::vector<T> out;
const char * first = s.data();
const char * last = first + s.size();
// CSV list parsers tolerant to any whitespace around commas. Bail on first
// parse error or overflow, returning the values consumed so far.
// Integer variant uses std::from_chars (locale-immune, C++17 charconv).
// Float variant uses std::strtof for portability (Apple Clang lacks the
// floating-point overload of std::from_chars on some SDK versions).
static std::vector<int> parse_csv_int(const std::string & s) {
std::vector<int> out;
const char * first = s.data();
const char * last = first + s.size();
while (first < last) {
while (first < last && (*first == ',' || *first == ' ')) {
first++;
++first;
}
if (first == last) {
break;
}
T v{};
int v{};
auto r = std::from_chars(first, last, v);
if (r.ec != std::errc{}) {
break;
Expand All @@ -48,6 +50,29 @@ template <typename T> static std::vector<T> parse_csv(const std::string & s) {
return out;
}

// Float CSV variant: skips runs of commas/spaces between tokens and converts
// each token with std::strtof (chosen over std::from_chars for portability --
// Apple Clang lacks the floating-point charconv overload on some SDK
// versions). Stops at the first token that fails to parse or is out of range,
// returning whatever was successfully consumed up to that point.
static std::vector<float> parse_csv_float(const std::string & s) {
    std::vector<float> values;
    const char * cur = s.data();
    const char * const stop = cur + s.size();
    for (;;) {
        // Skip any run of separators (commas and spaces) before the token.
        while (cur < stop && (*cur == ',' || *cur == ' ')) {
            ++cur;
        }
        if (cur >= stop) {
            break;
        }
        // strtof is safe here: a std::string buffer is null-terminated, so
        // the conversion cannot scan past the end of the data.
        errno = 0;
        char * parsed_end = nullptr;
        const float value = std::strtof(cur, &parsed_end);
        if (parsed_end == cur || errno == ERANGE) {
            break; // unparsable token or overflow/underflow: bail out
        }
        values.push_back(value);
        cur = parsed_end;
    }
    return values;
}

int ops_encode_src(const AceSynth * ctx,
const float * src_audio,
int src_len,
Expand Down Expand Up @@ -199,7 +224,7 @@ int ops_resolve_params(const AceSynth * ctx, const AceRequest * reqs, int batch_
s.max_codes_len = 0;
s.have_codes = false;
for (int b = 0; b < batch_n; b++) {
s.per_codes[b] = parse_csv<int>(reqs[b].audio_codes);
s.per_codes[b] = parse_csv_int(reqs[b].audio_codes);
int sz = (int) s.per_codes[b].size();
if (sz > s.max_codes_len) {
s.max_codes_len = sz;
Expand All @@ -222,7 +247,7 @@ void ops_build_schedule(SynthState & s) {
// endpoint handled implicitly by the sampler, so we drop it and take
// schedule = first N-1 entries, num_steps = N-1.
if (!s.rr.custom_timesteps.empty()) {
std::vector<float> ts = parse_csv<float>(s.rr.custom_timesteps);
std::vector<float> ts = parse_csv_float(s.rr.custom_timesteps);
if (ts.size() >= 2) {
s.num_steps = (int) ts.size() - 1;
s.schedule.assign(ts.begin(), ts.end() - 1);
Expand Down
8 changes: 6 additions & 2 deletions src/qwen3-enc.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,12 @@ static struct ggml_tensor * qwen3_build_self_attn(struct ggml_context * ctx,
// K/V come in F32 from mul_mat (encoder, no KV cache). Cast to F16 before FA,
// mirroring llama.cpp build_attn_mha for graphs without a KV cache.
if (use_flash_attn) {
if (k->type == GGML_TYPE_F32) k = ggml_cast(ctx, k, GGML_TYPE_F16);
if (v->type == GGML_TYPE_F32) v = ggml_cast(ctx, v, GGML_TYPE_F16);
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx, k, GGML_TYPE_F16);
}
if (v->type == GGML_TYPE_F32) {
v = ggml_cast(ctx, v, GGML_TYPE_F16);
}
}

struct ggml_tensor * attn = use_flash_attn ? ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f) :
Expand Down
Loading