Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"Assumption",
"Boundary",
"Classification",
"Likelihood",
"Severity",
"TLSVersion",
"Data",
"Dataflow",
Expand Down Expand Up @@ -32,7 +34,7 @@
from .pytm import var

# Import from new Pydantic models
from .enums import Action, Classification, DatastoreType, Lifetime, TLSVersion
from .enums import Action, Classification, DatastoreType, Lifetime, Likelihood, Severity, TLSVersion
from .base import Assumption, Controls
from .element import Element
from .data import Data
Expand Down
24 changes: 24 additions & 0 deletions pytm/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,30 @@ def label(self):
return self.value.lower().replace("_", " ")


class Likelihood(OrderedEnum):
"""Likelihood of a threat occurring."""

LOW = 1
MEDIUM = 2
HIGH = 3

def label(self):
return self.name.capitalize()


class Severity(OrderedEnum):
"""Severity level of a threat."""

VERY_LOW = 1
LOW = 2
MEDIUM = 3
HIGH = 4
VERY_HIGH = 5

def label(self):
return self.name.replace("_", " ").capitalize()


class TLSVersion(OrderedEnum):
"""TLS/SSL version levels."""

Expand Down
88 changes: 51 additions & 37 deletions pytm/threat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ast
import sys
from types import CodeType
from typing import Any, ClassVar, Tuple, List
from typing import Any, ClassVar
from collections.abc import Iterable

import builtins
Expand All @@ -14,6 +14,7 @@
BaseModel,
Field,
ConfigDict,
field_validator,
model_validator,
PrivateAttr,
)
Expand Down Expand Up @@ -130,8 +131,8 @@ def visit_Name(self, node: ast.Name) -> Any: # noqa: D401
return None

@staticmethod
def _attribute_chain(node: ast.Attribute) -> List[str]:
chain: List[str] = [node.attr]
def _attribute_chain(node: ast.Attribute) -> list[str]:
chain: list[str] = [node.attr]
current = node.value
while isinstance(current, ast.Attribute):
if isinstance(current.attr, str) and current.attr.startswith("__"):
Expand Down Expand Up @@ -184,13 +185,21 @@ class Threat(BaseModel):
default="", description="Likelihood of the threat occurring"
)
severity: str = Field(default="", description="Severity level of the threat")

@field_validator("likelihood", "severity", mode="before")
@classmethod
def _coerce_enum_to_str(cls, v: Any) -> str:
"""Accept Likelihood/Severity enum values and coerce them to their label strings."""
if hasattr(v, "label"):
return v.label()
return v
mitigations: str = Field(
default="", description="Possible mitigations for the threat"
)
prerequisites: str = Field(default="", description="Prerequisites for the threat")
example: str = Field(default="", description="Example of the threat")
references: str = Field(default="", description="References for the threat")
target: Tuple = Field(default=(), description="Target classes for this threat")
target: tuple = Field(default=(), description="Target classes for this threat")

_compiled_condition: CodeType | None = PrivateAttr(default=None)
_eval_globals: ClassVar[dict[str, Any] | None] = None
Expand All @@ -210,26 +219,33 @@ def _normalize_input(cls, data: Any) -> Any:
if "Likelihood Of Attack" in data:
data.setdefault("likelihood", data.pop("Likelihood Of Attack"))

# Normalise target to a tuple
target = data.get("target", "Element")
if isinstance(target, str) or not isinstance(target, Iterable):
target = (target,)
else:
target = tuple(target)

# Resolve target name strings to actual Python classes
resolved = []
for name in target:
if isinstance(name, type):
resolved.append(name)
# Normalise target to a tuple — only when explicitly passed (e.g. JSON
# loading). Class-level tuple defaults on Python-native Threat subclasses
# are already correct types and must not be overridden here.
if "target" in data:
target = data["target"]
if isinstance(target, str) or not isinstance(target, Iterable):
target = (target,)
else:
klass = getattr(sys.modules.get("pytm"), name, None)
resolved.append(klass if klass is not None else name)
data["target"] = tuple(resolved)
target = tuple(target)

# Resolve target name strings to actual Python classes
resolved = []
for name in target:
if isinstance(name, type):
resolved.append(name)
else:
klass = getattr(sys.modules.get("pytm"), name, None)
resolved.append(klass if klass is not None else name)
data["target"] = tuple(resolved)

return data

def model_post_init(self, __context: Any) -> None: # noqa: D401
# Skip string compilation when _check_condition is overridden in a subclass.
if type(self)._check_condition is not Threat._check_condition:
return

if not self.condition:
self._compiled_condition = None
return
Expand All @@ -248,13 +264,6 @@ def model_post_init(self, __context: Any) -> None: # noqa: D401
f"Invalid syntax in condition for threat {self.id}: {exc}"
) from exc

def _safeset(self, attr: str, value) -> None:
"""Safely set an attribute value."""
try:
setattr(self, attr, value)
except (ValueError, TypeError):
pass

def __repr__(self):
return (
f"<{self.__module__}.{type(self).__name__}({self.id}) at {hex(id(self))}>"
Expand Down Expand Up @@ -300,32 +309,37 @@ def _allowed_global_names(cls) -> set[str]:
globals_dict = cls._build_eval_globals()
return {key for key in globals_dict.keys() if key != "__builtins__"}

def apply(self, target):
"""Apply the threat condition to a target."""
# Check if target matches any of the target types
def _check_condition(self, target) -> bool:
"""Evaluate whether this threat applies to the given target.

Override this method in subclasses to define conditions natively in Python
instead of using string eval. The base implementation uses the compiled
string condition (for JSON-loaded threats).
"""
if self._compiled_condition is None:
return False

globals_dict = dict(self._build_eval_globals())
return bool(eval(self._compiled_condition, globals_dict, {"target": target}))

def apply(self, target) -> bool:
"""Return True if this threat applies to the given target element."""
if self.target:
target_matches = False
for target_type in self.target:
if isinstance(target_type, str):
# String comparison for backward compatibility
if target_type == type(target).__name__:
target_matches = True
break
elif isinstance(target_type, type):
# Class type comparison
if isinstance(target, target_type):
target_matches = True
break

if not target_matches:
return False

if self._compiled_condition is None:
return False

try:
globals_dict = dict(self._build_eval_globals())
locals_dict = {"target": target}
return bool(eval(self._compiled_condition, globals_dict, locals_dict))
return bool(self._check_condition(target))
except Exception:
return False
19 changes: 19 additions & 0 deletions pytm/threatlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Threat library — auto-exports all Threat subclasses from category modules.

Import threat classes directly from this package without needing to know
which file they live in::

from pytm.threatlib import INP01, CR01, AA01
"""

import inspect
import importlib
import pkgutil

from pytm.threat import Threat

for _finder, _mod_name, _ispkg in pkgutil.iter_modules(__path__, prefix=__name__ + "."):
_module = importlib.import_module(_mod_name)
for _cls_name, _cls in inspect.getmembers(_module, inspect.isclass):
if issubclass(_cls, Threat) and _cls is not Threat and _cls.__module__ == _module.__name__:
globals()[_cls_name] = _cls
Loading
Loading