Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions integration-tests/tests/dict_input.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Test bare dict as predict input type
# The predictor under test is the predict.py section of this archive.

# Build the image
cog build -t $TEST_IMAGE

# Dict input works via JSON
# predict.py sorts the dict items before joining them, so count is
# rendered ahead of greeting even though the JSON lists greeting first.
cog predict $TEST_IMAGE --json '{"data": {"greeting": "hello", "count": 3}}'
stdout '"status": "succeeded"'
stdout '"output": "count=3, greeting=hello"'

# Empty dict also works
# With no items to join, the predictor returns the empty string.
cog predict $TEST_IMAGE --json '{"data": {}}'
stdout '"status": "succeeded"'
stdout '"output": ""'

# Verify the schema has the correct type for dict input
# The schema is read from the run.cog.openapi_schema label of the built image.
# data has no default value in predict.py, so it is expected in the required list.
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
stdout '"data":{"title":"Data","type":"object","x-order":0}'
stdout '"required":\["data"\]'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Fixture predictor: accepts a bare ``dict`` input and echoes its contents
    # as comma-separated key=value pairs.
    def predict(self, data: dict) -> str:
        # Sorting the items keeps the rendered output deterministic.
        ordered_items = sorted(data.items())
        return ", ".join(f"{k}={v}" for k, v in ordered_items)
33 changes: 33 additions & 0 deletions integration-tests/tests/list_dict_input.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Test list[dict] as predict input type (common pattern for chat message arrays)
# The predictor under test is the predict.py section of this archive.

# Build the image
cog build -t $TEST_IMAGE

# List of dicts works as input via JSON
# predict.py renders one role: content line per message and joins them with
# a newline, which shows up as an escaped \n in the JSON output below.
cog predict $TEST_IMAGE --json '{"messages": [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}]}'
stdout '"status": "succeeded"'
stdout '"output": "user: hello\\nassistant: hi"'

# Empty list also works
# With no messages to join, the predictor returns the empty string.
cog predict $TEST_IMAGE --json '{"messages": []}'
stdout '"status": "succeeded"'
stdout '"output": ""'

# Verify the schema has array-of-objects type for list[dict] input
# The schema is read from the run.cog.openapi_schema label of the built image.
# messages has no default value in predict.py, so it is expected in the required list.
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
stdout '"messages":{"items":{"type":"object"},"title":"Messages","type":"array","x-order":0}'
stdout '"required":\["messages"\]'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Fixture predictor: accepts a list of chat-style message dicts and
    # renders them as text.
    def predict(self, messages: list[dict]) -> str:
        # One "role: content" line per message, in the order received.
        return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
34 changes: 34 additions & 0 deletions integration-tests/tests/typing_dict_input.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Test typing.Dict[str, Any] and List[Dict[str, Any]] as predict input types
# Verifies the typing-module variants work (not just bare dict/list[dict])
# The predictor under test is the predict.py section of this archive.

# Build the image
cog build -t $TEST_IMAGE

# Dict[str, Any] input works
# predict.py reports len(data) and len(items): data has one key and items
# has two elements, hence data=1 items=2.
cog predict $TEST_IMAGE --json '{"data": {"name": "alice"}, "items": [{"id": 1}, {"id": 2}]}'
stdout '"status": "succeeded"'
stdout '"output": "data=1 items=2"'

# Verify the schema has correct types for both inputs
# The schema is read from the run.cog.openapi_schema label of the built image.
# data is the first predict parameter and items the second, matching the
# x-order values 0 and 1 asserted below.
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
stdout '"data":{"title":"Data","type":"object","x-order":0}'
stdout '"items":{"items":{"type":"object"},"title":"Items","type":"array","x-order":1}'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Any, Dict, List

from cog import BasePredictor


class Predictor(BasePredictor):
    # Fixture predictor: exercises the typing-module spellings of the dict
    # and list-of-dict input annotations.
    def predict(self, data: Dict[str, Any], items: List[Dict[str, Any]]) -> str:
        # Only the sizes are reported, so the test can assert on a fixed string.
        return f"data={len(data)} items={len(items)}"
203 changes: 203 additions & 0 deletions pkg/schema/python/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2598,3 +2598,206 @@ func indexOf(s, substr string) int {
}
return -1
}

// ---------------------------------------------------------------------------
// dict input types
// ---------------------------------------------------------------------------

// TestBareDictInput checks that a predict input annotated with a bare `dict`
// resolves to schema.TypeAny with Required repetition.
func TestBareDictInput(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: dict) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Required, field.FieldType.Repetition)
}

// TestDictStrAnyInput checks that a predict input annotated with
// typing.Dict[str, Any] resolves to schema.TypeAny with Required repetition.
func TestDictStrAnyInput(t *testing.T) {
	const predictorSrc = `
from typing import Dict, Any
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: Dict[str, Any]) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Required, field.FieldType.Repetition)
}

// TestListOfDictInput checks that a predict input annotated with list[dict]
// resolves to schema.TypeAny with Repeated repetition.
func TestListOfDictInput(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: list[dict]) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Repeated, field.FieldType.Repetition)
}

// TestListOfTypingDictInput checks that the typing-module spelling
// List[Dict[str, Any]] resolves to schema.TypeAny with Repeated repetition.
func TestListOfTypingDictInput(t *testing.T) {
	const predictorSrc = `
from typing import List, Dict, Any
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: List[Dict[str, Any]]) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Repeated, field.FieldType.Repetition)
}

// TestOptionalDictInput checks that Optional[dict] with a None default
// resolves to schema.TypeAny with Optional repetition.
func TestOptionalDictInput(t *testing.T) {
	const predictorSrc = `
from typing import Optional
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: Optional[dict] = None) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Optional, field.FieldType.Repetition)
}

// TestOptionalListOfDictInput checks that Optional[list[dict]] with a None
// default resolves to schema.TypeAny with OptionalRepeated repetition.
func TestOptionalListOfDictInput(t *testing.T) {
	const predictorSrc = `
from typing import Optional
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: Optional[list[dict]] = None) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.OptionalRepeated, field.FieldType.Repetition)
}

// TestDictInputJSONSchema checks the JSON type emitted for a list[dict]
// input: an "array" whose "items" entry is a map with type "object".
func TestDictInputJSONSchema(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: list[dict]) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	jsonType := field.FieldType.JSONType()
	require.Equal(t, "array", jsonType["type"])

	// The element schema must itself be a JSON object description.
	elemSchema, isMap := jsonType["items"].(map[string]any)
	require.True(t, isMap)
	require.Equal(t, "object", elemSchema["type"])
}

// TestPEP604DictOrNoneInput checks that the PEP 604 union `dict | None`
// resolves to schema.TypeAny with Optional repetition.
func TestPEP604DictOrNoneInput(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: dict | None = None) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Optional, field.FieldType.Repetition)
}

// TestPEP604ListDictOrNoneInput checks that the PEP 604 union
// `list[dict] | None` resolves to schema.TypeAny with OptionalRepeated
// repetition.
func TestPEP604ListDictOrNoneInput(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: list[dict] | None = None) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.OptionalRepeated, field.FieldType.Repetition)
}

// TestDictStrIntInputTypeErasure checks that a parameterized dict[str, int]
// input is accepted and resolves to schema.TypeAny with Required repetition.
//
// The key/value type parameters are deliberately dropped: FieldType is a flat
// model with no recursive structure, so inputs are always treated as opaque
// JSON objects. Typed dicts can only be represented precisely on the output
// path, which uses SchemaType.
func TestDictStrIntInputTypeErasure(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: dict[str, int]) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Required, field.FieldType.Repetition)
}

// TestDictInputWithDefault checks a `dict` input declared with a `{}`
// default: the type stays schema.TypeAny / Required, the default is recorded
// with kind schema.DefaultDict, and the input is not reported as required.
func TestDictInputWithDefault(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, data: dict = {}) -> str:
return str(data)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("data")
	require.True(t, found)

	// Resolved type.
	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Required, field.FieldType.Repetition)

	// Captured default.
	require.NotNil(t, field.Default)
	require.Equal(t, schema.DefaultDict, field.Default.Kind)
	require.False(t, field.IsRequired())
}

// TestListDictInputWithDefault checks a `list[dict]` input declared with a
// `[]` default: the type is schema.TypeAny / Repeated, the default is
// recorded as an empty schema.DefaultList, and the input is not reported as
// required.
func TestListDictInputWithDefault(t *testing.T) {
	const predictorSrc = `
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, messages: list[dict] = []) -> str:
return str(messages)
`
	parsed := parse(t, predictorSrc, "Predictor")

	field, found := parsed.Inputs.Get("messages")
	require.True(t, found)

	// Resolved type.
	require.Equal(t, schema.TypeAny, field.FieldType.Primitive)
	require.Equal(t, schema.Repeated, field.FieldType.Repetition)

	// Captured default.
	require.NotNil(t, field.Default)
	require.Equal(t, schema.DefaultList, field.Default.Kind)
	require.Empty(t, field.Default.List)
	require.False(t, field.IsRequired())
}
12 changes: 12 additions & 0 deletions pkg/schema/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ func (ctx *ImportContext) IsBasePredictor(name string) bool {
func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext) (FieldType, error) {
switch ann.Kind {
case TypeAnnotSimple:
// Bare dict / Dict → opaque JSON object (TypeAny)
if ann.Name == "dict" || ann.Name == "Dict" {
return FieldType{Primitive: TypeAny, Repetition: Required}, nil
}
prim, ok := PrimitiveFromName(ann.Name)
if !ok {
return FieldType{}, errUnsupportedType(ann.Name)
Expand All @@ -279,6 +283,14 @@ func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext) (FieldType, error)

case TypeAnnotGeneric:
outer := ann.Name
// dict[K, V] / Dict[K, V] → opaque JSON object (TypeAny).
// Type parameters are intentionally discarded because FieldType is flat
// (PrimitiveType + Repetition only). The output path uses the recursive
// SchemaType model which can represent typed dicts (e.g. dict[str, int])
// precisely; for inputs, all dicts are treated as opaque JSON objects.
if outer == "dict" || outer == "Dict" {
return FieldType{Primitive: TypeAny, Repetition: Required}, nil
}
if outer == "Optional" {
if len(ann.Args) != 1 {
return FieldType{}, errUnsupportedType(fmt.Sprintf("Optional expects exactly 1 type argument, got %d", len(ann.Args)))
Expand Down
23 changes: 23 additions & 0 deletions python/cog/_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ def from_type(tpe: type) -> "FieldType":
if len(t_args) != 1:
raise ValueError("List must have one type argument")
elem_t = t_args[0]
# dict elements in lists → treat as ANY (opaque JSON objects)
if elem_t is dict or typing.get_origin(elem_t) is dict:
return FieldType(
primitive=PrimitiveType.ANY,
repetition=Repetition.REPEATED,
coder=None,
)
nested_t = typing.get_origin(elem_t)
if nested_t is not None:
raise ValueError(
Expand All @@ -259,6 +266,13 @@ def from_type(tpe: type) -> "FieldType":
if len(list_args) != 1:
raise ValueError("List must have one type argument")
elem_t = list_args[0]
# dict elements in optional lists → ANY
if elem_t is dict or typing.get_origin(elem_t) is dict:
return FieldType(
primitive=PrimitiveType.ANY,
repetition=Repetition.OPTIONAL_REPEATED,
coder=None,
)
inner_origin = typing.get_origin(elem_t)
if inner_origin is not None:
raise ValueError(
Expand All @@ -267,6 +281,15 @@ def from_type(tpe: type) -> "FieldType":
else:
elem_t = Any
repetition = Repetition.OPTIONAL_REPEATED
elif nested_t is dict or elem_t is dict:
# Optional[dict] or Optional[Dict[str, Any]] → optional ANY.
# nested_t is dict: elem_t is parameterized (e.g. Dict[str, Any]).
# elem_t is dict: elem_t is bare dict (nested_t is None).
return FieldType(
primitive=PrimitiveType.ANY,
repetition=Repetition.OPTIONAL,
coder=None,
)
elif nested_t is not None:
raise ValueError(
f"Optional cannot have nested type {_type_name(nested_t)}"
Expand Down
Loading
Loading