11 changes: 9 additions & 2 deletions README.md
@@ -226,10 +226,17 @@ Generates discrete values according to a user-defined probability set.
* a list (array) such as `[0.05, 0.8, 0.15]`, in which case the (zero-based) indices are the integer values generated
* or a dictionary (key-value) structure such as `{ "1":0.05, "3":0.8, "7":0.15 }` with integer keys (specified as strings in order to be valid JSON), in which case the keys are the integers generated

You may actually specify string or float keys in your `probabilities` dict to generate those values instead of integers, however you must specify the additional parameter `keys_type="varchar"` (or similar) so the value types are correct. For example:
You may also specify `string`, `number`, or `boolean` keys in your `probabilities` dict to generate those values instead of integers. In that case you must specify the additional parameter `keys_type="number"` (or similar) so that the value types are correct. For example:
```python
synth_distribution_discrete_probabilities(probabilities={"cat":0.3, "dog":0.5, "parrot":0.2}, keys_type="varchar")
synth_distribution_discrete_probabilities(probabilities={"97":0.3, "85":0.5, "64":0.2}, keys_type="number")
```
The default `keys_type` is `string`.
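
As a rough illustration of what `keys_type` controls, the Python sketch below mirrors the quoting rule visible in the macro: `number` and `boolean` keys are emitted bare, `string` keys are wrapped in single quotes, and anything else is rejected. The helper name `render_key` is hypothetical and not part of the package; the real logic lives in the Jinja macro.

```python
def render_key(key, keys_type="string"):
    """Render one `probabilities` key as a SQL literal (illustrative only)."""
    if keys_type in ("number", "boolean"):
        wrap = ""   # emitted bare, e.g. 97 or true
    elif keys_type == "string":
        wrap = "'"  # emitted single-quoted, e.g. 'cat'
    else:
        raise ValueError("`keys_type` must be `string`, `number`, or `boolean`")
    return f"{wrap}{key}{wrap}"


print(render_key("cat"))           # 'cat'
print(render_key("97", "number"))  # 97
```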

No matter what `keys_type` you choose, you may optionally cast the values to a different type in your database engine with `cast_to`, for example:
```python
synth_distribution_discrete_probabilities(probabilities={"2024-12-31":0.5, "2025-01-01":0.5}, cast_to="date")
```
(This will compile to something like `case ... when ... then CAST('2024-12-31' AS date) when ... then CAST('2025-01-01' AS date) ... end`.)
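
The effect of `cast_to` on each generated value can be sketched in Python as follows. This is a minimal sketch, not the macro itself: `maybe_cast` is a hypothetical helper, and the whitespace of the real compiled SQL will differ.

```python
def maybe_cast(literal, cast_to=None):
    """Wrap an already-quoted SQL literal in CAST(...) when `cast_to` is set."""
    if cast_to:
        return f"CAST({literal} AS {cast_to})"
    return literal


# With the default keys_type ("string"), each key is single-quoted before casting.
branches = [maybe_cast(f"'{d}'", cast_to="date") for d in ("2024-12-31", "2025-01-01")]
print(branches[0])  # CAST('2024-12-31' AS date)
```

When `cast_to` is left unset, the literal passes through unchanged, which matches the `{% if cast_to %}` guards in the macro.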

`probabilities` must sum to `1.0`.
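
A quick way to pre-check a `probabilities` argument before passing it to the macro is sketched below. The tolerance of `0.00001` is borrowed from the macro's `epsilon` variable, but the exact comparison the macro performs is not shown here, so treat the `abs(total - 1.0) > epsilon` test as an assumption; `check_probabilities` is a hypothetical helper.

```python
def check_probabilities(probabilities, epsilon=0.00001):
    """Raise if the probabilities (list or dict) do not sum to 1.0 within epsilon."""
    if isinstance(probabilities, dict):
        values = list(probabilities.values())
    else:
        values = list(probabilities)
    total = sum(values)
    # Assumed tolerance check; floating-point sums are rarely exactly 1.0.
    if abs(total - 1.0) > epsilon:
        raise ValueError("`probabilities` must sum to 1.0, not " + str(total))
    return total


check_probabilities([0.05, 0.8, 0.15])                  # passes
check_probabilities({"1": 0.05, "3": 0.8, "7": 0.15})   # passes
```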

18 changes: 12 additions & 6 deletions macros/distributions/discrete/probabilities.sql
@@ -1,4 +1,4 @@
{% macro synth_distribution_discrete_probabilities(probabilities) %}
{% macro synth_distribution_discrete_probabilities(probabilities, keys_type="string", cast_to=None) %}
{# Set up some variables: #}
{%- set epsilon = 0.00001 -%}{# "close enough" to zero #}
{%- set ns = namespace(max_prob_digits=1, keys=[], values=[], curr_idx=0, curr_threshold=0.0) -%}
@@ -8,6 +8,7 @@
{%- set ns.keys = probabilities.keys()|list -%}
{%- set ns.values = probabilities.values()|list -%}
{%- elif probabilities is iterable -%}{#- list -#}
{% set keys_type = "number" %}
{%- set ns.keys = range(probabilities|length) -%}
{%- set ns.values = probabilities -%}
{%- else -%}
@@ -18,12 +19,12 @@
{{ exceptions.raise_compiler_error("`probabilities` must sum to 1.0, not " + ns.values|sum|string) }}
{%- endif -%}

{%- if ns.keys[0] is number -%}
{%- if keys_type in ["number", "boolean"] -%}
{% set wrap = "" %}
{% elif ns.keys[0] is string %}
{% elif keys_type=="string" %}
{% set wrap = "'" %}
{% else %}
{{ exceptions.raise_compiler_error("`probabilities` keys must be strings or numbers") }}
{{ exceptions.raise_compiler_error("`keys_type` must be `string`, `number`, or `boolean`") }}
{% endif %}

{%- set ns.curr_threshold = ns.values[0] -%}
@@ -57,7 +58,9 @@
)
) }}
],
{{wrap}}{{value_list[value_list|length - 1]}}{{wrap}}
{% if cast_to %}CAST({% endif %}
{{wrap}}{{value_list[value_list|length - 1]}}{{wrap}}
{% if cast_to %} AS {{cast_to}} ){% endif %}
)
{% else %}
{# Case statement on uniformly-distributed range: #}
@@ -67,7 +70,10 @@
{%- set ns.curr_idx = ns.curr_idx + 1 -%}
{%- set ns.curr_threshold = ns.curr_threshold + ns.values[ns.curr_idx] -%}
{%- endif -%}
when {{i}} then {{wrap}}{{ns.keys[ns.curr_idx]}}{{wrap}}
when {{i}} then
{% if cast_to %}CAST({% endif %}
{{wrap}}{{ns.keys[ns.curr_idx]}}{{wrap}}
{% if cast_to %} AS {{cast_to}} ){% endif %}
{% endfor %}
end
{% endif %}