ghost-dao-contracts/dependencies/forge-std-1.9.2/scripts/vm.py
Uncle Fatso 46b33b4c75
initial push for smart-contracts
Signed-off-by: Uncle Fatso <uncle.fatso@ghostchain.io>
2025-04-28 14:17:04 +03:00

636 lines
17 KiB
Python

#!/usr/bin/env python3
import copy
import json
import re
import subprocess
from enum import Enum as PyEnum
from typing import Callable
from urllib import request
# Zero-argument callback type used by CheatcodesPrinter's helper methods.
VoidFn = Callable[[], None]
# cheatcodes.json on Foundry's `master` branch; main() downloads it on every run.
CHEATCODES_JSON_URL = "https://raw.githubusercontent.com/foundry-rs/foundry/master/crates/cheatcodes/assets/cheatcodes.json"
# Generated Solidity file, relative to the current working directory.
OUT_PATH = "src/Vm.sol"
# Doc comment written immediately above `interface VmSafe` in the generated file.
VM_SAFE_DOC = """\
/// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may
/// result in Script simulations differing from on-chain execution. It is recommended to only use
/// these cheats in scripts.
"""
# Doc comment written immediately above `interface Vm` in the generated file.
VM_DOC = """\
/// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used
/// in tests, but it is not recommended to use these cheats in scripts.
"""
def main():
    """Regenerate ``OUT_PATH`` from Foundry's cheatcodes.json.

    Steps:
      1. Download the JSON from ``CHEATCODES_JSON_URL``.
      2. Drop cheatcodes whose status is "experimental" or "internal".
      3. Split the rest into "safe" and "unsafe" lists. Each list is sorted by
         group, status, safety and function id, and gets group header
         comments inserted.
      4. Render the ``VmSafe`` interface (safe cheatcodes plus all events,
         enums and structs) and the ``Vm`` interface (unsafe cheatcodes,
         inheriting ``VmSafe``).
      5. Write the file and run ``forge fmt`` on it.

    Raises:
        RuntimeError: if a cheatcode's safety is neither "safe" nor "unsafe",
            or if ``forge fmt`` exits with a non-zero status.
    """
    # Close the HTTP response deterministically instead of leaking it.
    with request.urlopen(CHEATCODES_JSON_URL) as response:
        json_str = response.read().decode("utf-8")
    contract = Cheatcodes.from_json(json_str)

    ccs = [cc for cc in contract.cheatcodes if cc.status not in ("experimental", "internal")]
    ccs.sort(key=lambda cc: cc.func.id)
    safe = [cc for cc in ccs if cc.safety == "safe"]
    safe.sort(key=CmpCheatcode)
    unsafe = [cc for cc in ccs if cc.safety == "unsafe"]
    unsafe.sort(key=CmpCheatcode)
    # Every cheatcode must land in one of the two interfaces. This is an
    # explicit check rather than `assert`, which `python -O` strips.
    if len(safe) + len(unsafe) != len(ccs):
        raise RuntimeError("cheatcodes with a safety other than 'safe'/'unsafe' found")
    prefix_with_group_headers(safe)
    prefix_with_group_headers(unsafe)

    out = "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n"

    pp = CheatcodesPrinter(
        spdx_identifier="MIT OR Apache-2.0",
        solidity_requirement=">=0.6.2 <0.9.0",
        abicoder_pragma=True,
    )
    pp.p_prelude()
    # The prelude is printed once for the whole file, not once per interface.
    pp.prelude = False
    out += pp.finish()
    out += "\n\n"

    out += VM_SAFE_DOC
    vm_safe = Cheatcodes(
        # TODO: Custom errors were introduced in 0.8.4
        errors=[],  # contract.errors
        events=contract.events,
        enums=contract.enums,
        structs=contract.structs,
        cheatcodes=safe,
    )
    pp.p_contract(vm_safe, "VmSafe")
    out += pp.finish()
    out += "\n\n"

    out += VM_DOC
    vm_unsafe = Cheatcodes(
        errors=[],
        events=[],
        enums=[],
        structs=[],
        cheatcodes=unsafe,
    )
    pp.p_contract(vm_unsafe, "Vm", "VmSafe")
    out += pp.finish()

    # Compatibility with <0.8.0. On each line, the first " memory " that has a
    # "returns" later on the same line becomes " calldata ". The greedy `.*`
    # consumes the rest of the line up to the last "returns", so at most one
    # replacement happens per line.
    out = re.sub(r" memory (.*returns)", r" calldata \1", out)

    # Explicit encoding: descriptions may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write(out)

    forge_fmt = ["forge", "fmt", OUT_PATH]
    res = subprocess.run(forge_fmt)
    if res.returncode != 0:
        raise RuntimeError(f"command failed: {forge_fmt}")
    print(f"Wrote to {OUT_PATH}")
class CmpCheatcode:
    """Sort key that orders cheatcodes with ``cmp_cheatcode``.

    Intended for ``list.sort(key=CmpCheatcode)``. Ordering is by group, then
    status, then safety, then function id. Comparing against an object that
    is not a ``CmpCheatcode`` returns ``NotImplemented``. Python then applies
    its default fallback: identity for ``==``/``!=``, ``TypeError`` for
    ordering. The original code raised ``AttributeError`` instead.
    """

    cheatcode: "Cheatcode"

    def __init__(self, cheatcode: "Cheatcode"):
        self.cheatcode = cheatcode

    def _cmp(self, other: object):
        """Return ``cmp_cheatcode`` of both wrapped cheatcodes, or NotImplemented."""
        if not isinstance(other, CmpCheatcode):
            return NotImplemented
        return cmp_cheatcode(self.cheatcode, other.cheatcode)

    def __lt__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c < 0

    def __le__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c <= 0

    def __eq__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c == 0

    def __ge__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c >= 0

    def __gt__(self, other: object) -> bool:
        c = self._cmp(other)
        return c if c is NotImplemented else c > 0
def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int:
    """Three-way comparison of two cheatcodes.

    Keys are compared lexicographically in the order group, status, safety,
    function id. Returns -1 if ``a`` sorts first, 1 if ``b`` sorts first,
    and 0 if all four keys are equal.
    """
    left = (a.group, a.status, a.safety, a.func.id)
    right = (b.group, b.status, b.safety, b.func.id)
    # bool arithmetic yields the int -1, 0 or 1.
    return (left > right) - (left < right)
# HACK: A way to add group header comments without having to modify printer code
def prefix_with_group_headers(cheats: list["Cheatcode"]):
    """Insert a header pseudo-cheatcode before the first cheatcode of each group.

    Each header is a deep copy of its group's first cheatcode. The copy's
    function description is cleared and its declaration is replaced with
    ``// ======== <Group> ========``, so the printer emits the header like an
    ordinary function line.

    The list is modified in place and is also returned.
    """
    seen = set()
    result = []
    for cheat in cheats:
        if cheat.group not in seen:
            seen.add(cheat.group)
            header = copy.deepcopy(cheat)
            header.func.description = ""
            header.func.declaration = f"// ======== {group(header.group)} ========"
            result.append(header)
        result.append(cheat)
    # Build the new list in one pass instead of inserting while iterating,
    # which was O(n^2) and relied on re-visiting shifted elements.
    # Slice assignment keeps the caller's list object.
    cheats[:] = result
    return cheats
def group(s: str) -> str:
    """Return the display name for a cheatcode group id.

    The acronyms "evm" and "json" are fully upper-cased. Any other id has
    its first character upper-cased and the rest left unchanged. An empty
    string is returned as-is instead of raising ``IndexError``.
    """
    acronyms = {"evm": "EVM", "json": "JSON"}
    if s in acronyms:
        return acronyms[s]
    # Slicing (not indexing) so "" does not raise.
    return s[:1].upper() + s[1:]
class Visibility(PyEnum):
    """Solidity function visibility; values are the strings read from cheatcodes.json."""

    EXTERNAL = "external"
    PUBLIC = "public"
    INTERNAL = "internal"
    PRIVATE = "private"

    def __str__(self):
        # Render as the bare Solidity keyword rather than "Visibility.X".
        return f"{self.value}"
class Mutability(PyEnum):
    """Solidity state mutability; ``NONE`` ("") means no mutability keyword."""

    PURE = "pure"
    VIEW = "view"
    NONE = ""

    def __str__(self):
        # Render as the bare Solidity keyword (possibly empty).
        return f"{self.value}"
class Function:
    """The Solidity function behind a cheatcode, as described in cheatcodes.json."""

    id: str
    description: str
    declaration: str
    visibility: Visibility
    mutability: Mutability
    signature: str
    selector: str
    selector_bytes: bytes

    def __init__(
        self,
        id: str,
        description: str,
        declaration: str,
        visibility: Visibility,
        mutability: Mutability,
        signature: str,
        selector: str,
        selector_bytes: bytes,
    ):
        # Plain attribute storage; no validation or conversion happens here.
        self.id, self.description, self.declaration = id, description, declaration
        self.visibility, self.mutability = visibility, mutability
        self.signature, self.selector, self.selector_bytes = signature, selector, selector_bytes

    @staticmethod
    def from_dict(d: dict) -> "Function":
        """Build a Function from a ``func`` object of cheatcodes.json.

        ``visibility`` and ``mutability`` are converted to their enums, and
        ``selectorBytes`` is passed through ``bytes()``.
        """
        return Function(
            id=d["id"],
            description=d["description"],
            declaration=d["declaration"],
            visibility=Visibility(d["visibility"]),
            mutability=Mutability(d["mutability"]),
            signature=d["signature"],
            selector=d["selector"],
            selector_bytes=bytes(d["selectorBytes"]),
        )
class Cheatcode:
    """A cheatcode entry: its Solidity function plus group/status/safety labels."""

    func: Function
    group: str
    status: str
    safety: str

    def __init__(self, func: Function, group: str, status: str, safety: str):
        self.func, self.group, self.status, self.safety = func, group, status, safety

    @staticmethod
    def from_dict(d: dict) -> "Cheatcode":
        """Build a Cheatcode from one element of the ``cheatcodes`` array.

        ``group``, ``status`` and ``safety`` are coerced with ``str()``.
        """
        func = Function.from_dict(d["func"])
        grp, status, safety = (str(d[key]) for key in ("group", "status", "safety"))
        return Cheatcode(func, grp, status, safety)
class Error:
    """A custom Solidity error declared in the cheatcodes interface."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name = name
        self.description = description
        self.declaration = declaration

    @staticmethod
    def from_dict(d: dict) -> "Error":
        """Build an Error from one element of the ``errors`` array.

        Only the ``name``, ``description`` and ``declaration`` keys are read.
        Additional keys are ignored, so new upstream fields do not break
        generation.

        Raises:
            KeyError: if a required key is missing.
        """
        return Error(d["name"], d["description"], d["declaration"])
class Event:
    """A Solidity event declared in the cheatcodes interface."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name = name
        self.description = description
        self.declaration = declaration

    @staticmethod
    def from_dict(d: dict) -> "Event":
        """Build an Event from one element of the ``events`` array.

        Only the ``name``, ``description`` and ``declaration`` keys are read.
        Additional keys are ignored, so new upstream fields do not break
        generation.

        Raises:
            KeyError: if a required key is missing.
        """
        return Event(d["name"], d["description"], d["declaration"])
class EnumVariant:
    """One named variant of an ``Enum``, with its description."""

    name: str
    description: str

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description
class Enum:
    """A Solidity enum declared in the cheatcodes interface."""

    name: str
    description: str
    variants: list[EnumVariant]

    def __init__(self, name: str, description: str, variants: list[EnumVariant]):
        self.name = name
        self.description = description
        self.variants = variants

    @staticmethod
    def from_dict(d: dict) -> "Enum":
        """Build an Enum from one element of the ``enums`` array.

        Each variant object is read by its ``name`` and ``description`` keys.
        Additional keys are ignored instead of breaking ``EnumVariant(**v)``.

        Raises:
            KeyError: if a required key is missing.
        """
        return Enum(
            d["name"],
            d["description"],
            [EnumVariant(v["name"], v["description"]) for v in d["variants"]],
        )
class StructField:
    """One field of a ``Struct``: name, Solidity type string and description."""

    name: str
    ty: str
    description: str

    def __init__(self, name: str, ty: str, description: str):
        self.name, self.ty, self.description = name, ty, description
class Struct:
    """A Solidity struct declared in the cheatcodes interface."""

    name: str
    description: str
    fields: list[StructField]

    def __init__(self, name: str, description: str, fields: list[StructField]):
        self.name = name
        self.description = description
        self.fields = fields

    @staticmethod
    def from_dict(d: dict) -> "Struct":
        """Build a Struct from one element of the ``structs`` array.

        Each field object is read by its ``name``, ``ty`` and ``description``
        keys. Additional keys are ignored instead of breaking
        ``StructField(**f)``.

        Raises:
            KeyError: if a required key is missing.
        """
        return Struct(
            d["name"],
            d["description"],
            [StructField(f["name"], f["ty"], f["description"]) for f in d["fields"]],
        )
class Cheatcodes:
    """Model of cheatcodes.json: errors, events, enums, structs and cheatcodes."""

    errors: list[Error]
    events: list[Event]
    enums: list[Enum]
    structs: list[Struct]
    cheatcodes: list[Cheatcode]

    def __init__(
        self,
        errors: list[Error],
        events: list[Event],
        enums: list[Enum],
        structs: list[Struct],
        cheatcodes: list[Cheatcode],
    ):
        self.errors = errors
        self.events = events
        self.enums = enums
        self.structs = structs
        self.cheatcodes = cheatcodes

    @staticmethod
    def from_dict(d: dict) -> "Cheatcodes":
        """Build the model from the decoded JSON object.

        Reads the "errors", "events", "enums", "structs" and "cheatcodes"
        arrays. A missing key raises ``KeyError``.
        """
        return Cheatcodes(
            errors=[Error.from_dict(e) for e in d["errors"]],
            events=[Event.from_dict(e) for e in d["events"]],
            enums=[Enum.from_dict(e) for e in d["enums"]],
            structs=[Struct.from_dict(s) for s in d["structs"]],
            cheatcodes=[Cheatcode.from_dict(c) for c in d["cheatcodes"]],
        )

    @staticmethod
    def from_json(s: str | bytes | bytearray) -> "Cheatcodes":
        """Parse a cheatcodes.json document given as text or bytes."""
        return Cheatcodes.from_dict(json.loads(s))

    @staticmethod
    def from_json_file(file_path: str) -> "Cheatcodes":
        """Load and parse a cheatcodes.json file from disk."""
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return Cheatcodes.from_dict(json.load(f))
class Item(PyEnum):
    """Kinds of declarations a printer can emit inside an interface."""

    ERROR = "error"
    EVENT = "event"
    ENUM = "enum"
    STRUCT = "struct"
    FUNCTION = "function"
class ItemOrder:
    """Order in which the kinds of interface items are printed.

    Holds ``Item`` members without duplicates. Kinds left out of the list
    are not printed, because the printer only iterates over this list.
    """

    _list: list[Item]

    def __init__(self, list: list[Item]) -> None:
        """Validate and store the order.

        The parameter is named ``list`` (shadowing the builtin inside this
        method) to stay compatible with keyword callers.

        Raises:
            TypeError: if an element is not an ``Item``.
            ValueError: if there are more entries than ``Item`` members, or
                duplicates.
        """
        # Explicit checks instead of `assert`, which `python -O` strips.
        if any(not isinstance(item, Item) for item in list):
            raise TypeError("list must only contain Item members")
        if len(list) > len(Item):
            raise ValueError("list must not contain more items than Item")
        if len(list) != len(set(list)):
            raise ValueError("list must not contain duplicates")
        self._list = list

    def get_list(self) -> list[Item]:
        """Return the stored order (the internal list itself, not a copy)."""
        return self._list

    @staticmethod
    def default() -> "ItemOrder":
        """Return a new order: errors, events, enums, structs, functions."""
        return ItemOrder(
            [
                Item.ERROR,
                Item.EVENT,
                Item.ENUM,
                Item.STRUCT,
                Item.FUNCTION,
            ]
        )
class CheatcodesPrinter:
    """Render ``Cheatcodes`` models as Solidity source text.

    All output is appended to ``self.buffer``. ``finish()`` returns the
    right-stripped buffer and resets it to "".

    The item printers (``p_error``, ``p_event``, ``p_enum``, ``p_struct``,
    ``p_function`` and ``p_struct_field``) are invoked through ``_p_line``.
    The current line is therefore already indented when they start, and
    they only re-indent after a comment has ended a line.
    """

    buffer: str
    prelude: bool
    spdx_identifier: str
    solidity_requirement: str
    abicoder_v2: bool
    block_doc_style: bool
    indent_level: int
    _indent_str: str
    nl_str: str
    items_order: ItemOrder

    def __init__(
        self,
        buffer: str = "",
        prelude: bool = True,
        spdx_identifier: str = "UNLICENSED",
        solidity_requirement: str = "",
        abicoder_pragma: bool = False,
        block_doc_style: bool = False,
        indent_level: int = 0,
        indent_with: int | str = 4,
        nl_str: str = "\n",
        items_order: ItemOrder | None = None,
    ):
        """Configure the printer.

        Args:
            buffer: initial output text.
            prelude: when true, ``p_contract`` prints the SPDX/pragma prelude
                before the interface.
            spdx_identifier: value of the ``SPDX-License-Identifier`` header.
            solidity_requirement: ``pragma solidity`` constraint. When empty,
                ``p_prelude`` derives one from the contract.
            abicoder_pragma: also emit ``pragma experimental ABIEncoderV2;``.
            block_doc_style: use ``/* ... */`` comments instead of ``//`` and
                ``///`` line comments.
            indent_level: initial indentation depth.
            indent_with: spaces per indentation level (int), or the literal
                string for one level (str).
            nl_str: line terminator.
            items_order: order of item kinds inside an interface. ``None``
                (the default) means a fresh ``ItemOrder.default()``.

        Raises:
            ValueError: if ``indent_with`` is a negative int.
            TypeError: if ``indent_with`` is neither int nor str.
        """
        self.prelude = prelude
        self.spdx_identifier = spdx_identifier
        self.solidity_requirement = solidity_requirement
        self.abicoder_v2 = abicoder_pragma
        self.block_doc_style = block_doc_style
        self.buffer = buffer
        self.indent_level = indent_level
        self.nl_str = nl_str
        # Explicit checks instead of `assert`, which `python -O` strips.
        if isinstance(indent_with, int):
            if indent_with < 0:
                raise ValueError("indent_with must be non-negative")
            self._indent_str = " " * indent_with
        elif isinstance(indent_with, str):
            self._indent_str = indent_with
        else:
            raise TypeError("indent_with must be int or str")
        # Build the default per instance. A default argument value is
        # evaluated once at definition time and would be shared by every
        # printer.
        self.items_order = ItemOrder.default() if items_order is None else items_order

    def finish(self) -> str:
        """Return the buffered text with trailing whitespace removed, and clear the buffer."""
        ret = self.buffer.rstrip()
        self.buffer = ""
        return ret

    def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""):
        """Print ``interface <name> [is <inherits>] { ... }`` for ``contract``.

        The prelude is printed first when ``self.prelude`` is set. Items are
        printed one indentation level deeper, in ``self.items_order``.
        """
        if self.prelude:
            self.p_prelude(contract)
        self._p_str("interface ")
        name = name.strip()
        if name != "":
            self._p_str(name)
            self._p_str(" ")
        if inherits != "":
            self._p_str("is ")
            self._p_str(inherits)
            self._p_str(" ")
        self._p_str("{")
        self._p_nl()
        self._with_indent(lambda: self._p_items(contract))
        self._p_str("}")
        self._p_nl()

    def _p_items(self, contract: Cheatcodes):
        """Print the contract's items grouped by kind, in ``self.items_order``.

        Raises:
            ValueError: if the order contains something that is not an ``Item``.
        """
        for item in self.items_order.get_list():
            if item == Item.ERROR:
                self.p_errors(contract.errors)
            elif item == Item.EVENT:
                self.p_events(contract.events)
            elif item == Item.ENUM:
                self.p_enums(contract.enums)
            elif item == Item.STRUCT:
                self.p_structs(contract.structs)
            elif item == Item.FUNCTION:
                self.p_functions(contract.cheatcodes)
            else:
                raise ValueError(f"unknown item {item}")

    def p_prelude(self, contract: Cheatcodes | None = None):
        """Print the SPDX header and pragma lines, followed by a blank line.

        The Solidity constraint is ``self.solidity_requirement`` if it is
        non-empty. Otherwise it is ">=0.8.4 <0.9.0" when ``contract`` has
        custom errors, and ">=0.6.0 <0.9.0" in all other cases.
        """
        self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}")
        self._p_nl()
        if self.solidity_requirement != "":
            req = self.solidity_requirement
        elif contract and len(contract.errors) > 0:
            req = ">=0.8.4 <0.9.0"
        else:
            req = ">=0.6.0 <0.9.0"
        self._p_str(f"pragma solidity {req};")
        self._p_nl()
        if self.abicoder_v2:
            self._p_str("pragma experimental ABIEncoderV2;")
            self._p_nl()
        self._p_nl()

    def p_errors(self, errors: list[Error]):
        """Print each error, each followed by a blank line."""
        for error in errors:
            self._p_line(lambda: self.p_error(error))

    def p_error(self, error: Error):
        """Print an error's doc comment (if any) and its declaration line."""
        self._p_commented_line(error.description, error.declaration, doc=True)

    def p_events(self, events: list[Event]):
        """Print each event, each followed by a blank line."""
        for event in events:
            self._p_line(lambda: self.p_event(event))

    def p_event(self, event: Event):
        """Print an event's doc comment (if any) and its declaration line."""
        self._p_commented_line(event.description, event.declaration, doc=True)

    def p_enums(self, enums: list[Enum]):
        """Print each enum, each followed by a blank line."""
        for enum in enums:
            self._p_line(lambda: self.p_enum(enum))

    def p_enum(self, enum: Enum):
        """Print an enum's doc comment (if any), header, variants and closing brace."""
        self._p_commented_line(enum.description, f"enum {enum.name} {{", doc=True)
        self._with_indent(lambda: self.p_enum_variants(enum.variants))
        self._p_line(lambda: self._p_str("}"))

    def p_enum_variants(self, variants: list[EnumVariant]):
        """Print one variant per line (comma-separated), each preceded by its comment."""
        for i, variant in enumerate(variants):
            self._p_indent()
            # Re-indent only if a comment ended the already-indented line.
            # Without this check an empty description doubled the indent.
            if self._p_comment(variant.description):
                self._p_indent()
            self._p_str(variant.name)
            if i < len(variants) - 1:
                self._p_str(",")
            self._p_nl()

    def p_structs(self, structs: list[Struct]):
        """Print each struct, each followed by a blank line."""
        for struct in structs:
            self._p_line(lambda: self.p_struct(struct))

    def p_struct(self, struct: Struct):
        """Print a struct's doc comment (if any), header, fields and closing brace."""
        self._p_commented_line(struct.description, f"struct {struct.name} {{", doc=True)
        self._with_indent(lambda: self.p_struct_fields(struct.fields))
        self._p_line(lambda: self._p_str("}"))

    def p_struct_fields(self, fields: list[StructField]):
        """Print each field on its own line."""
        for field in fields:
            self._p_line(lambda: self.p_struct_field(field))

    def p_struct_field(self, field: StructField):
        """Print a field's comment (if any) and ``<ty> <name>;`` with no trailing newline."""
        if self._p_comment(field.description):
            self._p_indent()
        self._p_str(f"{field.ty} {field.name};")

    def p_functions(self, cheatcodes: list[Cheatcode]):
        """Print each cheatcode's function, each followed by a blank line."""
        for cheatcode in cheatcodes:
            self._p_line(lambda: self.p_function(cheatcode.func))

    def p_function(self, func: Function):
        """Print a function's doc comment (if any) and its declaration line."""
        self._p_commented_line(func.description, func.declaration, doc=True)

    def _p_commented_line(self, description: str, text: str, doc: bool = False):
        """Print an optional comment, then ``text`` and a newline.

        Assumes the current line is already indented. A printed comment ends
        its own line, so ``text`` is re-indented in that case. When the
        comment is empty, the existing indentation is reused instead of
        being doubled.
        """
        if self._p_comment(description, doc=doc):
            self._p_indent()
        self._p_str(text)
        self._p_nl()

    def _p_comment(self, s: str, doc: bool = False) -> bool:
        """Print ``s`` as a comment (a doc comment if ``doc``), ending with a newline.

        Assumes the first line is already indented; continuation lines are
        indented here. Surrounding whitespace and each line's leading
        whitespace are stripped. Returns True if something was printed and
        False if ``s`` is blank.
        """
        s = s.strip()
        if s == "":
            return False
        lines = [line.lstrip() for line in s.split("\n")]
        if self.block_doc_style:
            self._p_str("/*")
            if doc:
                self._p_str("*")
            self._p_nl()
            for line in lines:
                self._p_indent()
                self._p_str(" ")
                if doc:
                    self._p_str("* ")
                self._p_str(line)
                self._p_nl()
            self._p_indent()
            self._p_str(" */")
            self._p_nl()
        else:
            prefix = "/// " if doc else "// "
            for idx, line in enumerate(lines):
                if idx > 0:
                    self._p_indent()
                self._p_str(prefix)
                self._p_str(line)
                self._p_nl()
        return True

    def _with_indent(self, f: VoidFn):
        """Run ``f`` one indentation level deeper, restoring the level even if ``f`` raises."""
        self._inc_indent()
        try:
            f()
        finally:
            self._dec_indent()

    def _p_line(self, f: VoidFn):
        """Indent, run ``f``, then end the line."""
        self._p_indent()
        f()
        self._p_nl()

    def _p_indented(self, f: VoidFn):
        """Indent, then run ``f`` (no trailing newline)."""
        self._p_indent()
        f()

    def _p_indent(self):
        """Append the indentation string once per current indent level."""
        for _ in range(self.indent_level):
            self._p_str(self._indent_str)

    def _p_nl(self):
        """Append the configured line terminator."""
        self._p_str(self.nl_str)

    def _p_str(self, txt: str):
        """Append raw text to the buffer."""
        self.buffer += txt

    def _inc_indent(self):
        self.indent_level += 1

    def _dec_indent(self):
        self.indent_level -= 1
if __name__ == "__main__":
    # Generate Vm.sol only when executed as a script, not when imported.
    main()