636 lines
17 KiB
Python
636 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import copy
|
|
import json
|
|
import re
|
|
import subprocess
|
|
from enum import Enum as PyEnum
|
|
from typing import Callable
|
|
from urllib import request
|
|
|
|
# Signature of the zero-argument callbacks passed to the printer helpers
# (`_with_indent`, `_p_line`, `_p_indented`).
VoidFn = Callable[[], None]

# Upstream machine-readable description of every Foundry cheatcode.
CHEATCODES_JSON_URL = "https://raw.githubusercontent.com/foundry-rs/foundry/master/crates/cheatcodes/assets/cheatcodes.json"
# Destination of the generated Solidity interfaces (relative to the working directory).
OUT_PATH = "src/Vm.sol"

# NatSpec text written directly above `interface VmSafe`.
VM_SAFE_DOC = """\
/// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may
/// result in Script simulations differing from on-chain execution. It is recommended to only use
/// these cheats in scripts.
"""

# NatSpec text written directly above `interface Vm`.
VM_DOC = """\
/// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used
/// in tests, but it is not recommended to use these cheats in scripts.
"""
|
|
|
|
|
|
def main():
    """Regenerate ``OUT_PATH`` from the upstream cheatcodes JSON and format it.

    Downloads ``CHEATCODES_JSON_URL``, drops ``experimental`` and ``internal``
    cheatcodes, and splits the rest into the ``VmSafe`` (safety ``safe``) and
    ``Vm`` (safety ``unsafe``) interfaces. It then writes the Solidity source
    to ``OUT_PATH`` and runs ``forge fmt`` on it.

    Raises:
        AssertionError: if some cheatcode is neither ``safe`` nor ``unsafe``,
            or if ``forge fmt`` exits with a non-zero status.
    """
    # Use the response as a context manager so the connection is closed.
    with request.urlopen(CHEATCODES_JSON_URL) as response:
        json_str = response.read().decode("utf-8")
    contract = Cheatcodes.from_json(json_str)

    ccs = [cc for cc in contract.cheatcodes if cc.status not in ("experimental", "internal")]
    ccs.sort(key=lambda cc: cc.func.id)

    safe = [cc for cc in ccs if cc.safety == "safe"]
    safe.sort(key=CmpCheatcode)
    unsafe = [cc for cc in ccs if cc.safety == "unsafe"]
    unsafe.sort(key=CmpCheatcode)
    # Every remaining cheatcode must land in exactly one of the two interfaces.
    assert len(safe) + len(unsafe) == len(ccs)

    prefix_with_group_headers(safe)
    prefix_with_group_headers(unsafe)

    out = ""
    out += "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n"

    pp = CheatcodesPrinter(
        spdx_identifier="MIT OR Apache-2.0",
        solidity_requirement=">=0.6.2 <0.9.0",
        abicoder_pragma=True,
    )
    # Emit the SPDX/pragma prelude once for the whole file, not per interface.
    pp.p_prelude()
    pp.prelude = False
    out += pp.finish()

    out += "\n\n"
    out += VM_SAFE_DOC
    vm_safe = Cheatcodes(
        # TODO: Custom errors were introduced in 0.8.4
        errors=[],  # contract.errors
        events=contract.events,
        enums=contract.enums,
        structs=contract.structs,
        cheatcodes=safe,
    )
    pp.p_contract(vm_safe, "VmSafe")
    out += pp.finish()

    out += "\n\n"
    out += VM_DOC
    # Shared events/enums/structs are declared in VmSafe; Vm inherits from it
    # and only adds the unsafe cheatcodes.
    vm_unsafe = Cheatcodes(
        errors=[],
        events=[],
        enums=[],
        structs=[],
        cheatcodes=unsafe,
    )
    pp.p_contract(vm_unsafe, "Vm", "VmSafe")
    out += pp.finish()

    # Compatibility with <0.8.0: rewrite every ` memory ` that is followed by
    # `returns` on the same line (i.e. parameters, not return values) to
    # ` calldata `. The lookahead does not consume the rest of the line, so
    # all such parameters are rewritten, not only the first one.
    out = re.sub(r" memory (?=.*returns)", " calldata ", out)

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write(out)

    forge_fmt = ["forge", "fmt", OUT_PATH]
    res = subprocess.run(forge_fmt)
    assert res.returncode == 0, f"command failed: {forge_fmt}"

    print(f"Wrote to {OUT_PATH}")
|
|
|
|
|
|
class CmpCheatcode:
    """Sort key that orders ``Cheatcode`` objects using ``cmp_cheatcode``.

    Intended for ``list.sort(key=CmpCheatcode)``. Comparing against anything
    that is not a ``CmpCheatcode`` returns ``NotImplemented``. Python then
    applies its default handling (``==`` falls back to identity; ``<`` and
    ``>`` raise ``TypeError``) instead of failing with ``AttributeError``.
    """

    cheatcode: "Cheatcode"

    def __init__(self, cheatcode: "Cheatcode"):
        self.cheatcode = cheatcode

    def _cmp(self, other: "CmpCheatcode") -> int:
        # Three-way result of comparing the wrapped cheatcodes.
        return cmp_cheatcode(self.cheatcode, other.cheatcode)

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, CmpCheatcode):
            return NotImplemented
        return self._cmp(other) < 0

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CmpCheatcode):
            return NotImplemented
        return self._cmp(other) == 0

    def __gt__(self, other: object) -> bool:
        if not isinstance(other, CmpCheatcode):
            return NotImplemented
        return self._cmp(other) > 0
|
|
|
|
|
|
def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int:
    """Three-way comparison of two cheatcodes.

    Keys are compared in priority order ``group``, ``status``, ``safety`` and
    ``func.id``. A later key is only read when every earlier key is equal.
    Returns ``-1`` when ``a`` sorts first, ``1`` when ``b`` does, and ``0``
    when all four keys are equal.
    """
    key_getters = (
        lambda cc: cc.group,
        lambda cc: cc.status,
        lambda cc: cc.safety,
        lambda cc: cc.func.id,
    )
    for get_key in key_getters:
        left, right = get_key(a), get_key(b)
        if left != right:
            return -1 if left < right else 1
    return 0
|
|
|
|
|
|
# HACK: A way to add group header comments without having to modify printer code
def prefix_with_group_headers(cheats: list["Cheatcode"]):
    """Insert a header pseudo-cheatcode before the first cheatcode of each group.

    Each header is a deep copy of its group's first cheatcode, with an empty
    description and the declaration replaced by
    ``// ======== <Group> ========``, so the printer emits it like any other
    function. A group that reappears later in the list gets no second header.
    ``cheats`` is modified in place and is also returned.
    """
    seen: set[str] = set()
    result: list["Cheatcode"] = []
    for cheat in cheats:
        if cheat.group not in seen:
            seen.add(cheat.group)
            header = copy.deepcopy(cheat)
            header.func.description = ""
            header.func.declaration = f"// ======== {group(header.group)} ========"
            result.append(header)
        result.append(cheat)
    # Build a new list and assign it back through a slice instead of inserting
    # into the list while iterating over it; callers still see the change in place.
    cheats[:] = result
    return cheats
|
|
|
|
|
|
def group(s: str) -> str:
    """Return the display name of a cheatcode group for section headers.

    ``evm`` and ``json`` are acronyms and are fully upper-cased. For any other
    name only the first character is upper-cased and the rest is kept as is.
    An empty name is returned unchanged instead of raising ``IndexError``.
    """
    acronyms = {"evm": "EVM", "json": "JSON"}
    if s in acronyms:
        return acronyms[s]
    # Slicing (not s[0]) keeps this safe for the empty string.
    return s[:1].upper() + s[1:]
|
|
|
|
|
|
class Visibility(PyEnum):
    """Solidity function visibility; ``str()`` yields the bare keyword."""

    EXTERNAL = "external"
    PUBLIC = "public"
    INTERNAL = "internal"
    PRIVATE = "private"

    def __str__(self):
        return f"{self.value}"
|
|
|
|
|
|
class Mutability(PyEnum):
    """Solidity state mutability; ``NONE`` is the empty (non-payable) keyword."""

    PURE = "pure"
    VIEW = "view"
    NONE = ""

    def __str__(self):
        return f"{self.value}"
|
|
|
|
|
|
class Function:
    """The Solidity function behind a cheatcode, as listed in the JSON.

    Stores the id, description, full Solidity declaration, visibility,
    mutability, signature, selector string and selector as raw bytes.
    """

    id: str
    description: str
    declaration: str
    visibility: Visibility
    mutability: Mutability
    signature: str
    selector: str
    selector_bytes: bytes

    def __init__(
        self,
        id: str,
        description: str,
        declaration: str,
        visibility: Visibility,
        mutability: Mutability,
        signature: str,
        selector: str,
        selector_bytes: bytes,
    ):
        self.id, self.description, self.declaration = id, description, declaration
        self.visibility, self.mutability = visibility, mutability
        self.signature, self.selector, self.selector_bytes = signature, selector, selector_bytes

    @staticmethod
    def from_dict(d: dict) -> "Function":
        """Build a ``Function`` from one ``func`` object of the JSON.

        ``visibility`` and ``mutability`` are converted to their enums, and
        ``selectorBytes`` is passed to ``bytes()``. It is presumably a list
        of byte values; confirm against the JSON schema.
        """
        return Function(
            id=d["id"],
            description=d["description"],
            declaration=d["declaration"],
            visibility=Visibility(d["visibility"]),
            mutability=Mutability(d["mutability"]),
            signature=d["signature"],
            selector=d["selector"],
            selector_bytes=bytes(d["selectorBytes"]),
        )
|
|
|
|
|
|
class Cheatcode:
    """One cheatcode entry: its ``Function`` plus group, status and safety labels."""

    func: Function
    group: str
    status: str
    safety: str

    def __init__(self, func: Function, group: str, status: str, safety: str):
        self.func = func
        self.group, self.status, self.safety = group, status, safety

    @staticmethod
    def from_dict(d: dict) -> "Cheatcode":
        """Build a ``Cheatcode`` from one element of the JSON ``cheatcodes`` list.

        ``func`` is parsed with ``Function.from_dict``. ``group``, ``status``
        and ``safety`` are coerced with ``str()``.
        """
        parsed_func = Function.from_dict(d["func"])
        labels = [str(d[key]) for key in ("group", "status", "safety")]
        return Cheatcode(parsed_func, *labels)
|
|
|
|
|
|
class Error:
    """A Solidity custom error from the JSON: name, description and declaration."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Error":
        """Build an ``Error`` by passing the object's keys as keyword arguments.

        Missing or extra keys raise ``TypeError`` from the constructor.
        """
        return Error(**d)
|
|
|
|
|
|
class Event:
    """A Solidity event from the JSON: name, description and declaration."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Event":
        """Build an ``Event`` by passing the object's keys as keyword arguments.

        Missing or extra keys raise ``TypeError`` from the constructor.
        """
        return Event(**d)
|
|
|
|
|
|
class EnumVariant:
    """A single variant of an ``Enum``: its name and description."""

    name: str
    description: str

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description
|
|
|
|
|
|
class Enum:
    """A Solidity enum from the JSON: name, description and ordered variants."""

    name: str
    description: str
    variants: list[EnumVariant]

    def __init__(self, name: str, description: str, variants: list[EnumVariant]):
        self.name, self.description, self.variants = name, description, variants

    @staticmethod
    def from_dict(d: dict) -> "Enum":
        """Build an ``Enum``; each ``variants`` element becomes ``EnumVariant(**v)``."""
        enum_name, enum_description = d["name"], d["description"]
        parsed_variants = [EnumVariant(**v) for v in d["variants"]]
        return Enum(enum_name, enum_description, parsed_variants)
|
|
|
|
|
|
class StructField:
    """A single field of a ``Struct``: its name, Solidity type and description."""

    name: str
    ty: str
    description: str

    def __init__(self, name: str, ty: str, description: str):
        self.name, self.ty, self.description = name, ty, description
|
|
|
|
|
|
class Struct:
    """A Solidity struct from the JSON: name, description and ordered fields."""

    name: str
    description: str
    fields: list[StructField]

    def __init__(self, name: str, description: str, fields: list[StructField]):
        self.name, self.description, self.fields = name, description, fields

    @staticmethod
    def from_dict(d: dict) -> "Struct":
        """Build a ``Struct``; each ``fields`` element becomes ``StructField(**f)``."""
        struct_name, struct_description = d["name"], d["description"]
        parsed_fields = [StructField(**f) for f in d["fields"]]
        return Struct(struct_name, struct_description, parsed_fields)
|
|
|
|
|
|
class Cheatcodes:
    """The parsed cheatcodes document: errors, events, enums, structs and cheatcodes."""

    errors: list[Error]
    events: list[Event]
    enums: list[Enum]
    structs: list[Struct]
    cheatcodes: list[Cheatcode]

    def __init__(
        self,
        errors: list[Error],
        events: list[Event],
        enums: list[Enum],
        structs: list[Struct],
        cheatcodes: list[Cheatcode],
    ):
        self.errors = errors
        self.events = events
        self.enums = enums
        self.structs = structs
        self.cheatcodes = cheatcodes

    @staticmethod
    def from_dict(d: dict) -> "Cheatcodes":
        """Build from the decoded JSON root object.

        The keys ``errors``, ``events``, ``enums``, ``structs`` and
        ``cheatcodes`` must all be present (``KeyError`` otherwise). Each must
        hold a list of objects accepted by the matching ``from_dict``.
        """
        return Cheatcodes(
            errors=[Error.from_dict(e) for e in d["errors"]],
            events=[Event.from_dict(e) for e in d["events"]],
            enums=[Enum.from_dict(e) for e in d["enums"]],
            structs=[Struct.from_dict(e) for e in d["structs"]],
            cheatcodes=[Cheatcode.from_dict(e) for e in d["cheatcodes"]],
        )

    @staticmethod
    def from_json(s: str | bytes | bytearray) -> "Cheatcodes":
        """Parse a JSON document given as text or bytes (anything ``json.loads`` accepts)."""
        return Cheatcodes.from_dict(json.loads(s))

    @staticmethod
    def from_json_file(file_path: str) -> "Cheatcodes":
        """Load and parse a UTF-8 encoded JSON file."""
        # JSON interchange text is UTF-8 (RFC 8259); don't depend on the
        # platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return Cheatcodes.from_dict(json.load(f))
|
|
|
|
|
|
class Item(PyEnum):
    """Kinds of declarations the printer can emit inside an interface body."""

    ERROR = "error"
    EVENT = "event"
    ENUM = "enum"
    STRUCT = "struct"
    FUNCTION = "function"
|
|
|
|
|
|
class ItemOrder:
    """Order in which the printer emits the item kinds of an interface body.

    Each ``Item`` may appear at most once. Kinds that are not listed are not
    printed at all.
    """

    _list: list[Item]

    def __init__(self, list: list[Item]) -> None:
        # NOTE: the parameter shadows the builtin `list`; the name is kept so
        # keyword callers (`ItemOrder(list=...)`) keep working.
        assert len(list) <= len(Item), "list must not contain more items than Item"
        assert len(list) == len(set(list)), "list must not contain duplicates"
        self._list = list

    def get_list(self) -> list[Item]:
        """Return the configured order (the list given at construction, not a copy)."""
        return self._list

    @staticmethod
    def default() -> "ItemOrder":
        """Errors, then events, enums, structs and finally functions."""
        standard = [Item.ERROR, Item.EVENT, Item.ENUM, Item.STRUCT, Item.FUNCTION]
        return ItemOrder(standard)
|
|
|
|
|
|
class CheatcodesPrinter:
    """Renders ``Cheatcodes`` as Solidity ``interface`` source text.

    Text accumulates in ``buffer``. ``finish()`` returns it right-stripped
    and clears it, so one printer can render several pieces in sequence.
    Every item printer (``p_error``, ``p_function``, ...) indents its own
    lines to ``indent_level``; callers do not need to pre-indent.
    """

    buffer: str

    prelude: bool
    spdx_identifier: str
    solidity_requirement: str
    abicoder_v2: bool

    block_doc_style: bool

    indent_level: int
    _indent_str: str

    nl_str: str

    items_order: ItemOrder

    def __init__(
        self,
        buffer: str = "",
        prelude: bool = True,
        spdx_identifier: str = "UNLICENSED",
        solidity_requirement: str = "",
        abicoder_pragma: bool = False,
        block_doc_style: bool = False,
        indent_level: int = 0,
        indent_with: int | str = 4,
        nl_str: str = "\n",
        items_order: ItemOrder | None = None,
    ):
        """Configure the printer.

        Args:
            buffer: initial contents of the output buffer.
            prelude: if true, ``p_contract`` prints the SPDX/pragma prelude first.
            spdx_identifier: license written in the SPDX comment.
            solidity_requirement: version range for ``pragma solidity``. When
                empty, ``p_prelude`` picks one.
            abicoder_pragma: emit ``pragma experimental ABIEncoderV2;``.
            block_doc_style: use ``/* */``-style comments instead of ``//``/``///``.
            indent_level: starting indentation depth.
            indent_with: number of spaces per level, or the literal indent string.
            nl_str: line terminator.
            items_order: order of item kinds in an interface body. ``None``
                (the default) means a fresh ``ItemOrder.default()``.

        Raises:
            TypeError: if ``indent_with`` is neither ``int`` nor ``str``.
            ValueError: if ``indent_with`` is a negative ``int``.
        """
        self.prelude = prelude
        self.spdx_identifier = spdx_identifier
        self.solidity_requirement = solidity_requirement
        self.abicoder_v2 = abicoder_pragma
        self.block_doc_style = block_doc_style
        self.buffer = buffer
        self.indent_level = indent_level
        self.nl_str = nl_str

        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if isinstance(indent_with, int):
            if indent_with < 0:
                raise ValueError("indent_with must be non-negative")
            self._indent_str = " " * indent_with
        elif isinstance(indent_with, str):
            self._indent_str = indent_with
        else:
            raise TypeError("indent_with must be int or str")

        # A default of `ItemOrder.default()` in the signature would be evaluated
        # once and shared by every printer; build a fresh one per instance.
        self.items_order = ItemOrder.default() if items_order is None else items_order

    def finish(self) -> str:
        """Return the buffered text without trailing whitespace and clear the buffer."""
        ret = self.buffer.rstrip()
        self.buffer = ""
        return ret

    def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""):
        """Print ``contract`` as ``interface <name> [is <inherits>] { ... }``.

        The prelude is printed first when ``self.prelude`` is true.
        """
        if self.prelude:
            self.p_prelude(contract)

        self._p_str("interface ")
        name = name.strip()
        if name != "":
            self._p_str(name)
            self._p_str(" ")
        if inherits != "":
            self._p_str("is ")
            self._p_str(inherits)
            self._p_str(" ")
        self._p_str("{")
        self._p_nl()
        self._with_indent(lambda: self._p_items(contract))
        self._p_str("}")
        self._p_nl()

    def _p_items(self, contract: Cheatcodes):
        """Print the contract's items grouped by kind, in ``items_order``."""
        for item in self.items_order.get_list():
            if item == Item.ERROR:
                self.p_errors(contract.errors)
            elif item == Item.EVENT:
                self.p_events(contract.events)
            elif item == Item.ENUM:
                self.p_enums(contract.enums)
            elif item == Item.STRUCT:
                self.p_structs(contract.structs)
            elif item == Item.FUNCTION:
                self.p_functions(contract.cheatcodes)
            else:
                raise ValueError(f"unknown item {item}")

    def p_prelude(self, contract: Cheatcodes | None = None):
        """Print the SPDX comment, ``pragma solidity`` and optional ABIEncoderV2 pragma.

        If no ``solidity_requirement`` was configured, ``>=0.8.4`` is required
        when ``contract`` has errors (custom errors need 0.8.4), otherwise
        ``>=0.6.0``. A blank line follows the prelude.
        """
        self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}")
        self._p_nl()

        if self.solidity_requirement != "":
            req = self.solidity_requirement
        elif contract and len(contract.errors) > 0:
            req = ">=0.8.4 <0.9.0"
        else:
            req = ">=0.6.0 <0.9.0"
        self._p_str(f"pragma solidity {req};")
        self._p_nl()

        if self.abicoder_v2:
            self._p_str("pragma experimental ABIEncoderV2;")
            self._p_nl()

        self._p_nl()

    def p_errors(self, errors: list[Error]):
        """Print each error followed by a blank line."""
        for error in errors:
            self.p_error(error)
            self._p_nl()

    def p_error(self, error: Error):
        """Print an error's doc comment (if any) and its declaration line."""
        self._p_comment(error.description, doc=True)
        self._p_line(lambda: self._p_str(error.declaration))

    def p_events(self, events: list[Event]):
        """Print each event followed by a blank line."""
        for event in events:
            self.p_event(event)
            self._p_nl()

    def p_event(self, event: Event):
        """Print an event's doc comment (if any) and its declaration line."""
        self._p_comment(event.description, doc=True)
        self._p_line(lambda: self._p_str(event.declaration))

    def p_enums(self, enums: list[Enum]):
        """Print each enum followed by a blank line."""
        for enum in enums:
            self.p_enum(enum)
            self._p_nl()

    def p_enum(self, enum: Enum):
        """Print ``enum <name> { ... }`` with its doc comment and variants."""
        self._p_comment(enum.description, doc=True)
        self._p_line(lambda: self._p_str(f"enum {enum.name} {{"))
        self._with_indent(lambda: self.p_enum_variants(enum.variants))
        self._p_line(lambda: self._p_str("}"))

    def p_enum_variants(self, variants: list[EnumVariant]):
        """Print the variants one per line, comma-separated (no trailing comma)."""
        for i, variant in enumerate(variants):
            self._p_comment(variant.description)
            self._p_indent()
            self._p_str(variant.name)
            if i < len(variants) - 1:
                self._p_str(",")
            self._p_nl()

    def p_structs(self, structs: list[Struct]):
        """Print each struct followed by a blank line."""
        for struct in structs:
            self.p_struct(struct)
            self._p_nl()

    def p_struct(self, struct: Struct):
        """Print ``struct <name> { ... }`` with its doc comment and fields."""
        self._p_comment(struct.description, doc=True)
        self._p_line(lambda: self._p_str(f"struct {struct.name} {{"))
        self._with_indent(lambda: self.p_struct_fields(struct.fields))
        self._p_line(lambda: self._p_str("}"))

    def p_struct_fields(self, fields: list[StructField]):
        """Print the fields one after another (no blank lines between them)."""
        for field in fields:
            self.p_struct_field(field)

    def p_struct_field(self, field: StructField):
        """Print a field's comment (if any) and ``<type> <name>;``."""
        self._p_comment(field.description)
        self._p_line(lambda: self._p_str(f"{field.ty} {field.name};"))

    def p_functions(self, cheatcodes: list[Cheatcode]):
        """Print each cheatcode's function followed by a blank line."""
        for cheatcode in cheatcodes:
            self.p_function(cheatcode.func)
            self._p_nl()

    def p_function(self, func: Function):
        """Print a function's doc comment (if any) and its declaration line."""
        self._p_comment(func.description, doc=True)
        self._p_line(lambda: self._p_str(func.declaration))

    def _p_comment(self, s: str, doc: bool = False):
        """Print ``s`` as a comment, one comment line per line of ``s``.

        Every emitted line is indented to the current level and ends with a
        newline. Leading whitespace of each source line is dropped. ``doc``
        selects NatSpec (``///`` or ``/** */``) over plain (``//`` or ``/* */``).
        A blank or whitespace-only ``s`` prints nothing, so no stray indentation
        is left behind for the next line.
        """
        s = s.strip()
        if s == "":
            return

        lines = [line.lstrip() for line in s.split("\n")]
        if self.block_doc_style:
            self._p_indent()
            self._p_str("/*")
            if doc:
                self._p_str("*")
            self._p_nl()
            for line in lines:
                self._p_indent()
                self._p_str(" ")
                if doc:
                    self._p_str("* ")
                self._p_str(line)
                self._p_nl()
            self._p_indent()
            self._p_str(" */")
            self._p_nl()
        else:
            prefix = "/// " if doc else "// "
            for line in lines:
                self._p_indent()
                self._p_str(prefix)
                self._p_str(line)
                self._p_nl()

    def _with_indent(self, f: VoidFn):
        """Run ``f`` one indentation level deeper, restoring the level afterwards."""
        self._inc_indent()
        try:
            f()
        finally:
            self._dec_indent()

    def _p_line(self, f: VoidFn):
        """Indent, run ``f`` to print the line's content, then end the line."""
        self._p_indent()
        f()
        self._p_nl()

    def _p_indented(self, f: VoidFn):
        """Indent, then run ``f`` (no newline is added)."""
        self._p_indent()
        f()

    def _p_indent(self):
        """Append the indent string once per current indentation level."""
        for _ in range(self.indent_level):
            self._p_str(self._indent_str)

    def _p_nl(self):
        """Append the configured line terminator."""
        self._p_str(self.nl_str)

    def _p_str(self, txt: str):
        """Append raw text to the buffer."""
        self.buffer += txt

    def _inc_indent(self):
        self.indent_level += 1

    def _dec_indent(self):
        self.indent_level -= 1
|
|
|
|
|
|
# Regenerate the interfaces only when run as a script, never on import.
if __name__ == "__main__":
    main()
|