Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 270 additions & 4 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -7617,6 +7617,8 @@ struct server {
ds4_engine *engine;
ds4_session *session;
int default_tokens;
bool batching;
int batch_size;
kv_disk_cache kv;
tool_memory tool_mem;
live_tool_state responses_live;
Expand Down Expand Up @@ -7650,6 +7652,22 @@ struct job {
job *next;
};

/* Per-request decode state for one slot of the batching worker.
 * A slot is filled by batch_decode_start(), advanced one token at a time by
 * batch_decode_step(), and released by batch_decode_cleanup(). */
typedef struct {
/* Queued job this slot is serving. */
job *j;
/* Session created for this job in batch_decode_start() and freed in
 * batch_decode_cleanup(). */
ds4_session *session;
/* Completion id, formatted as "chatcmpl-<seq>" or "cmpl-<seq>". */
char id[96];
/* Error text from prompt sync, token eval, or a failed client write. */
char err[160];
/* Text of all tokens generated so far. */
buf text;
/* Byte offset into text up to which content has been sent to the client. */
size_t plain_stream_pos;
/* Byte offset into text from which the next stop-sequence scan begins. */
size_t stop_scan_from;
/* Finish reason reported to the client: "length", "stop" or "error". */
const char *finish;
/* Prompt length, copied from the request's prompt.len. */
int prompt_tokens;
/* Number of tokens generated so far. */
int completion;
/* Generation limit: the request's max_tokens, floored at 0 and capped at the
 * room left in the session context. */
int max_tokens;
/* Sampler RNG state, seeded from the request seed or, when that is zero,
 * from a mix of time, sequence number and job address. */
uint64_t rng;
/* now_sec() timestamp taken when the slot was started; used for the elapsed
 * time in the completion log line. */
double t0;
} batch_decode_job;

/* =========================================================================
* Tool Call Text Memory.
* =========================================================================
Expand Down Expand Up @@ -10582,6 +10600,13 @@ static bool enqueue(server *s, job *j) {
return true;
}

/* Mark a job as finished and wake the thread waiting on its condition
 * variable.  Both the flag update and the signal happen while j->mu is held. */
static void job_signal_done(job *j) {
    pthread_mutex_t *lock = &j->mu;

    pthread_mutex_lock(lock);
    j->done = true;
    pthread_cond_signal(&j->cv);
    pthread_mutex_unlock(lock);
}

static job *dequeue(server *s) {
pthread_mutex_lock(&s->mu);
while (!s->head && !s->stopping) pthread_cond_wait(&s->cv, &s->mu);
Expand All @@ -10597,16 +10622,244 @@ static job *dequeue(server *s) {
return j;
}

/* Non-blocking pop from the server's job queue.
 *
 * Takes s->mu, detaches the head job if there is one (clearing s->tail when
 * the queue becomes empty) and returns it with its next link reset.
 * Returns NULL immediately when the queue is empty; never waits on s->cv. */
static job *dequeue_ready(server *s) {
    job *front;

    pthread_mutex_lock(&s->mu);
    front = s->head;
    if (front != NULL) {
        s->head = front->next;
        if (s->head == NULL) {
            s->tail = NULL;
        }
    }
    pthread_mutex_unlock(&s->mu);

    if (front == NULL) {
        return NULL;
    }
    front->next = NULL;
    return front;
}

/* Decide whether a job may be served by the batched decode path.
 *
 * Returns true only for a non-NULL job whose request is streaming, uses the
 * OpenAI API flavour, and is of kind REQ_CHAT or REQ_COMPLETION.
 *
 * NOTE(review): exclusions for tool requests (r->has_tools) and for thinking
 * mode (ds4_think_mode_enabled(r->think_mode)) existed here only as
 * commented-out code, so such requests are currently admitted to the batched
 * path.  Confirm that this is intended. */
static bool batch_job_supported(const job *j) {
    if (j == NULL) {
        return false;
    }

    const request *req = &j->req;
    if (!req->stream) {
        return false;
    }
    if (req->api != API_OPENAI) {
        return false;
    }
    if (req->kind == REQ_CHAT) {
        return true;
    }
    return req->kind == REQ_COMPLETION;
}

/* Initialise slot `b` so that job `j` can be decoded by the batching worker.
 *
 * Steps, in order:
 *   1. zero the slot and record the job, start time and prompt length;
 *   2. create a dedicated session with the same context size as s->session;
 *   3. sync the request prompt into that session;
 *   4. assign a completion id and send the SSE headers (plus, for chat
 *      requests, an initial chunk with neither delta text nor finish reason);
 *   5. clamp max_tokens to [0, room left in the session context] and seed
 *      the sampler RNG.
 *
 * Returns true when the slot is ready for batch_decode_step().  Returns false
 * on any failure; an HTTP 500 is sent for session-creation and prompt-sync
 * failures, and the reason is logged in every failure case.  The slot is NOT
 * released here: the caller must still run batch_decode_cleanup() on it,
 * which also wipes b->err -- hence the logging below. */
static bool batch_decode_start(server *s, batch_decode_job *b, job *j) {
    memset(b, 0, sizeof(*b));
    b->j = j;
    b->finish = "length";
    b->t0 = now_sec();
    b->prompt_tokens = j->req.prompt.len;

    if (ds4_session_create(&b->session, s->engine, ds4_session_ctx(s->session)) != 0) {
        snprintf(b->err, sizeof(b->err), "failed to create batched session");
        http_error(j->fd, s->enable_cors, 500, "failed to create batched session");
        goto fail;
    }

    j->req.cache_read_tokens = 0;
    j->req.cache_write_tokens = b->prompt_tokens;
    if (ds4_session_sync(b->session, &j->req.prompt, b->err, sizeof(b->err)) != 0) {
        /* Fall back to a generic message when the sync reported no text. */
        if (!b->err[0]) snprintf(b->err, sizeof(b->err), "prefill failed");
        http_error(j->fd, s->enable_cors, 500, b->err);
        goto fail;
    }

    snprintf(b->id, sizeof(b->id), "%s-%llu",
             j->req.kind == REQ_CHAT ? "chatcmpl" : "cmpl",
             (unsigned long long)++s->seq);

    if (!sse_headers(j->fd, s->enable_cors)) {
        snprintf(b->err, sizeof(b->err), "client stream write failed");
        goto fail;
    }
    if (j->req.kind == REQ_CHAT && !sse_chunk(j->fd, &j->req, b->id, NULL, NULL)) {
        snprintf(b->err, sizeof(b->err), "client stream write failed");
        goto fail;
    }

    /* Never generate past the end of the session context. */
    int room = ds4_session_ctx(b->session) - ds4_session_pos(b->session);
    b->max_tokens = j->req.max_tokens;
    if (b->max_tokens < 0) b->max_tokens = 0;
    if (b->max_tokens > room) b->max_tokens = room;

    /* A zero request seed selects a seed mixed from time, seq and address. */
    b->rng = j->req.seed ? j->req.seed :
        (((uint64_t)time(NULL) << 32) ^ ((uint64_t)s->seq << 1) ^ (uint64_t)(uintptr_t)j);

    server_log(DS4_LOG_GENERATION,
               "ds4-server: batching start %s prompt=%d max=%d active_limit=%d",
               j->req.kind == REQ_CHAT ? "chat" : "completion",
               b->prompt_tokens,
               b->max_tokens,
               s->batch_size);
    return true;

fail:
    /* The caller's cleanup zeroes the slot, so report the reason now. */
    server_log(DS4_LOG_GENERATION,
               "ds4-server: batching start failed error=\"%s\"",
               b->err);
    return false;
}

/* Advance one batched job by a single token.
 *
 * Samples a token, evaluates it into the job's session, appends its text to
 * b->text, and streams whatever prefix of the new text is safe to send (not
 * part of a possible stop sequence and not cutting a UTF-8 sequence, as
 * decided by stop_list_stream_safe_len / utf8_stream_safe_len).
 *
 * Returns true when the job is finished, with b->finish set to:
 *   "length" - token budget or session context exhausted;
 *   "stop"   - EOS token sampled or a stop sequence found;
 *   "error"  - token evaluation or a client write failed (b->err is set).
 * Returns false when the job needs further steps. */
static bool batch_decode_step(server *s, batch_decode_job *b) {
    job *owner = b->j;

    if (b->completion >= b->max_tokens ||
        ds4_session_pos(b->session) >= ds4_session_ctx(b->session)) {
        b->finish = "length";
        return true;
    }

    int tok = ds4_session_sample(b->session,
                                 owner->req.temperature,
                                 owner->req.top_k,
                                 owner->req.top_p,
                                 owner->req.min_p,
                                 &b->rng);
    if (tok == ds4_token_eos(s->engine)) {
        b->finish = "stop";
        return true;
    }
    if (ds4_session_eval(b->session, tok, b->err, sizeof(b->err)) != 0) {
        b->finish = "error";
        return true;
    }

    /* Record the token's text; only evaluated tokens count as completion. */
    size_t frag_len = 0;
    char *frag = ds4_token_text(s->engine, tok, &frag_len);
    b->completion++;
    buf_append(&b->text, frag, frag_len);
    free(frag);

    /* Look for a stop sequence and work out how far it is safe to stream. */
    size_t cut_at = 0;
    size_t cut_len = 0;
    bool stopped = stop_list_find_from(&owner->req.stops, b->text.ptr,
                                       b->stop_scan_from,
                                       &cut_at, &cut_len);
    size_t safe_end;
    if (stopped) {
        safe_end = cut_at;
    } else {
        safe_end = stop_list_stream_safe_len(&owner->req.stops, b->text.len);
    }
    if (safe_end > b->text.len) {
        safe_end = b->text.len;
    }
    safe_end = utf8_stream_safe_len(b->text.ptr, b->plain_stream_pos,
                                    safe_end, stopped);

    /* Next scan re-examines the last (max_len - 1) bytes of the text. */
    if (!stopped && owner->req.stops.max_len > 1) {
        const size_t hold = owner->req.stops.max_len - 1;
        b->stop_scan_from = b->text.len > hold ? b->text.len - hold : 0;
    }

    /* Send the newly safe part of the text, if any. */
    if (safe_end > b->plain_stream_pos) {
        char *delta = xstrndup(b->text.ptr + b->plain_stream_pos,
                               safe_end - b->plain_stream_pos);
        bool sent = sse_chunk(owner->fd, &owner->req, b->id, delta, NULL);
        free(delta);
        if (!sent) {
            b->finish = "error";
            snprintf(b->err, sizeof(b->err), "client stream write failed");
            return true;
        }
        b->plain_stream_pos = safe_end;
    }

    if (stopped) {
        (void)cut_len;
        b->finish = "stop";
        /* Drop the stop sequence and anything after it from the text. */
        b->text.len = cut_at;
        if (b->text.ptr) {
            b->text.ptr[b->text.len] = '\0';
        }
        ds4_session_invalidate(b->session);
        return true;
    }

    return b->completion >= b->max_tokens;
}

/* Release everything a decode slot owns (its session and its text buffer)
 * and reset the slot to all-zero so it can be reused. */
static void batch_decode_cleanup(batch_decode_job *b) {
    ds4_session *owned_session = b->session;

    ds4_session_free(owned_session);
    buf_free(&b->text);
    memset(b, 0, sizeof *b);
}

/* Complete a batched job: deliver the end of the response, log the outcome,
 * and release the slot.
 *
 * For streaming requests any text not yet sent is flushed, followed by the
 * finish-reason chunk and the terminating done event; a failed flush turns
 * the finish reason into "error".  Non-streaming requests get a single
 * final_response().  The slot is always cleaned up before returning, so `b`
 * must not be used afterwards without being restarted. */
static void batch_decode_finish(server *s, batch_decode_job *b) {
    job *owner = b->j;

    /* Flush text that was held back while decoding. */
    if (owner->req.stream && b->text.len > b->plain_stream_pos) {
        const size_t pending = b->text.len - b->plain_stream_pos;
        char *tail = xstrndup(b->text.ptr + b->plain_stream_pos, pending);
        if (!sse_chunk(owner->fd, &owner->req, b->id, tail, NULL)) {
            b->finish = "error";
        }
        free(tail);
    }

    if (!owner->req.stream) {
        final_response(owner->fd, s->enable_cors, &owner->req, b->id,
                       b->text.ptr ? b->text.ptr : "", NULL, NULL,
                       b->finish, b->prompt_tokens, b->completion);
    } else {
        bool closed_ok = sse_chunk(owner->fd, &owner->req, b->id, NULL, b->finish);
        if (closed_ok) {
            closed_ok = sse_done(owner->fd, &owner->req, b->id,
                                 b->prompt_tokens, b->completion);
        }
        if (!closed_ok) {
            server_log(DS4_LOG_DEFAULT,
                       "ds4-server: batching final stream failed");
        }
    }

    /* The error text is only logged when the job ended in "error". */
    const char *kind_label = owner->req.kind == REQ_CHAT ? "chat" : "completion";
    if (strcmp(b->finish, "error") == 0 && b->err[0] != '\0') {
        server_log(DS4_LOG_GENERATION,
                   "ds4-server: batching %s gen=%d finish=%s error=\"%s\" %.3fs",
                   kind_label,
                   b->completion,
                   b->finish,
                   b->err,
                   now_sec() - b->t0);
    } else {
        server_log(DS4_LOG_GENERATION,
                   "ds4-server: batching %s gen=%d finish=%s %.3fs",
                   kind_label,
                   b->completion,
                   b->finish,
                   now_sec() - b->t0);
    }

    batch_decode_cleanup(b);
}

/* Delete entry `idx` from the active-slot array by shifting every later
 * entry down one position, then decrement *nactive.  The removed entry's
 * contents are simply overwritten; nothing is freed here. */
static void batch_decode_remove(batch_decode_job *active, int *nactive, int idx) {
    const int trailing = *nactive - idx - 1;

    if (trailing > 0) {
        /* Source and destination overlap, so memmove rather than memcpy. */
        memmove(&active[idx], &active[idx + 1],
                (size_t)trailing * sizeof(active[0]));
    }
    *nactive -= 1;
}

/* Pull jobs off the server queue until the batch holds `cap` active slots or
 * no job is available.
 *
 * When the batch is empty and block_if_empty is set, the blocking dequeue()
 * is used; otherwise the queue is only polled with dequeue_ready().
 *
 * Each job taken is handled in one of three ways:
 *   - not batchable: run to completion right here with generate_job() and
 *     signalled done (it never occupies a slot);
 *   - batchable but batch_decode_start() fails: the slot is cleaned up and
 *     the job is signalled done;
 *   - batchable and started: it occupies active[*nactive] and *nactive grows.
 *
 * Returns true if at least one job was taken from the queue, whatever its
 * outcome. */
static bool worker_batch_admit(server *s, batch_decode_job *active, int *nactive,
                               int cap, bool block_if_empty) {
    bool took_any = false;

    for (;;) {
        if (*nactive >= cap) {
            break;
        }

        job *next;
        if (*nactive == 0 && block_if_empty) {
            next = dequeue(s);
        } else {
            next = dequeue_ready(s);
        }
        if (next == NULL) {
            break;
        }
        took_any = true;

        if (!batch_job_supported(next)) {
            generate_job(s, next);
            job_signal_done(next);
            continue;
        }

        batch_decode_job *slot = &active[*nactive];
        if (batch_decode_start(s, slot, next)) {
            (*nactive)++;
        } else {
            batch_decode_cleanup(slot);
            job_signal_done(next);
        }
    }

    return took_any;
}

/* Worker thread body used when batching is enabled.
 *
 * Keeps up to `cap` jobs active at once (s->batch_size, or 2 when that is not
 * positive).  Each loop iteration tops the batch up from the queue, then
 * gives every active job one decode step; jobs that report completion are
 * finished, signalled done and removed from the array.
 *
 * The loop ends once nothing is active, the server is stopping and the queue
 * is empty.  Always returns NULL. */
static void *worker_main_batched(void *arg) {
    server *s = arg;
    int cap = s->batch_size > 0 ? s->batch_size : 2;
    batch_decode_job *active = xmalloc((size_t)cap * sizeof(active[0]));
    memset(active, 0, (size_t)cap * sizeof(active[0]));
    int nactive = 0;
    server_log(DS4_LOG_DEFAULT,
               "ds4-server: continuous batching enabled batch_size=%d", cap);
    for (;;) {
        /* One admission pass per iteration is enough: it already drains the
         * queue until the batch is full or no job is ready, and it blocks
         * only when there is nothing to decode. */
        worker_batch_admit(s, active, &nactive, cap, nactive == 0);
        if (nactive == 0) {
            pthread_mutex_lock(&s->mu);
            bool stopping = s->stopping && !s->head;
            pthread_mutex_unlock(&s->mu);
            if (stopping) break;
            /* NOTE(review): presumably reached only when dequeue() returned
             * NULL; confirm this cannot spin while jobs remain queued. */
            continue;
        }
        /* `i` advances only when slot i is kept: removing a slot shifts the
         * next one into position i. */
        for (int i = 0; i < nactive;) {
            bool done = batch_decode_step(s, &active[i]);
            if (!done) {
                i++;
                continue;
            }
            /* Save the job pointer first: finishing zeroes the slot. */
            job *j = active[i].j;
            batch_decode_finish(s, &active[i]);
            job_signal_done(j);
            batch_decode_remove(active, &nactive, i);
        }
    }
    free(active);
    return NULL;
}

/* Worker thread entry point.
 *
 * With batching enabled the whole thread is handed to worker_main_batched().
 * Otherwise jobs are taken from the queue one at a time, run to completion
 * with generate_job(), and signalled done.  The loop ends when dequeue()
 * returns NULL.  Always returns NULL. */
static void *worker_main(void *arg) {
    server *s = arg;
    if (s->batching) return worker_main_batched(arg);
    for (;;) {
        job *j = dequeue(s);
        if (!j) break;
        generate_job(s, j);
        /* Signal completion exactly once.  After the waiter sees done it
         * presumably may reclaim the job, so `j` is not touched again. */
        job_signal_done(j);
    }
    return NULL;
}
Expand Down Expand Up @@ -10914,6 +11167,8 @@ typedef struct {
bool disable_exact_dsml_tool_replay;
int tool_memory_max_ids;
bool enable_cors;
bool batching;
int batch_size;
} server_config;

static int parse_int_arg(const char *s, const char *opt) {
Expand Down Expand Up @@ -11029,6 +11284,10 @@ static void usage(FILE *fp) {
" Add Access-Control-Allow-* headers for browser JS clients. Does not change --host.\n"
" --trace FILE\n"
" Write a human-readable session trace: prompts, cache decisions, output, tool calls.\n"
" --batching\n"
" Enable continuous batching for simple OpenAI streaming requests.\n"
" --batch-size N\n"
" Maximum concurrently active batched requests. Default: 2\n"
"\n"
"Thinking and sampling:\n"
" DeepSeek-compatible chat requests default to thinking mode with high effort.\n"
Expand Down Expand Up @@ -11110,6 +11369,7 @@ static server_config parse_options(int argc, char **argv) {
.ctx_size = 32768,
.default_tokens = 393216,
.tool_memory_max_ids = DS4_TOOL_MEMORY_DEFAULT_MAX_IDS,
.batch_size = 2,
};
c.kv_cache = kv_cache_default_options();

Expand Down Expand Up @@ -11143,6 +11403,10 @@ static server_config parse_options(int argc, char **argv) {
c.enable_cors = true;
} else if (!strcmp(arg, "--trace")) {
c.trace_path = need_arg(&i, argc, argv, arg);
} else if (!strcmp(arg, "--batching")) {
c.batching = true;
} else if (!strcmp(arg, "--batch-size")) {
c.batch_size = parse_int_arg(need_arg(&i, argc, argv, arg), arg);
} else if (!strcmp(arg, "--kv-disk-dir")) {
c.kv_disk_dir = need_arg(&i, argc, argv, arg);
} else if (!strcmp(arg, "--kv-disk-space-mb")) {
Expand Down Expand Up @@ -11237,6 +11501,8 @@ int main(int argc, char **argv) {
s.engine = engine;
s.session = session;
s.default_tokens = cfg.default_tokens;
s.batching = cfg.batching;
s.batch_size = cfg.batch_size;
s.disable_exact_dsml_tool_replay = cfg.disable_exact_dsml_tool_replay;
s.tool_mem.max_entries = cfg.tool_memory_max_ids;
s.enable_cors = cfg.enable_cors;
Expand Down
Loading