diff --git a/Dockerfile b/Dockerfile index 1fe62e71..36673778 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,8 @@ FROM ciimage/python:3.7 RUN apt update -RUN apt install -y make libgmp3-dev g++ python3-pip python3.7-dev python3.7-venv npm +RUN apt -y -o Dpkg::Options::="--force-overwrite" install python3.7-dev +RUN apt install -y make libgmp3-dev g++ python3-pip python3.7-venv npm # Installing cmake via apt doesn't bring the most up-to-date version. RUN pip install cmake==3.22 diff --git a/README.md b/README.md index e936b99a..6ec3b7e7 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.8.1.zip . +> docker cp ${container_id}:/app/cairo-lang-0.8.2.zip . > docker rm -v ${container_id} ``` diff --git a/src/starkware/cairo/common/CMakeLists.txt b/src/starkware/cairo/common/CMakeLists.txt index 09ffe717..8cb3d945 100644 --- a/src/starkware/cairo/common/CMakeLists.txt +++ b/src/starkware/cairo/common/CMakeLists.txt @@ -8,8 +8,15 @@ python_lib(cairo_common_lib cairo_blake2s/blake2s_utils.py cairo_blake2s/packed_blake2s.cairo cairo_builtins.cairo + cairo_keccak/keccak.cairo cairo_keccak/keccak_utils.py + cairo_keccak/packed_keccak.cairo + cairo_secp/bigint.cairo + cairo_secp/constants.cairo + cairo_secp/ec.cairo + cairo_secp/field.cairo cairo_secp/secp_utils.py + cairo_secp/signature.cairo cairo_sha256/sha256_utils.py default_dict.cairo dict_access.cairo diff --git a/src/starkware/cairo/common/cairo_keccak/keccak.cairo b/src/starkware/cairo/common/cairo_keccak/keccak.cairo new file mode 100644 index 00000000..fbc41859 --- /dev/null +++ b/src/starkware/cairo/common/cairo_keccak/keccak.cairo @@ -0,0 +1,471 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.cairo_keccak.packed_keccak import BLOCK_SIZE, packed_keccak_func +from starkware.cairo.common.math import ( + assert_lt, + assert_nn, + assert_nn_le, + assert_not_zero, + split_felt, + unsigned_div_rem, +) +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.memset import memset +from starkware.cairo.common.pow import pow +from starkware.cairo.common.uint256 import Uint256, uint256_reverse_endian + +const KECCAK_STATE_SIZE_FELTS = 25 +const KECCAK_FULL_RATE_IN_BYTES = 136 +const KECCAK_FULL_RATE_IN_WORDS = 17 +const KECCAK_CAPACITY_IN_WORDS = 8 +const BYTES_IN_WORD = 8 + +# Computes the keccak hash of multiple uint256 numbers. +func keccak_uint256s{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + n_elements : felt, elements : Uint256* +) -> (res : Uint256): + alloc_locals + + let (inputs) = alloc() + let inputs_start = inputs + + keccak_add_uint256s{inputs=inputs}(n_elements=n_elements, elements=elements, bigend=0) + + return keccak(inputs=inputs_start, n_bytes=n_elements * 32) +end + +# Computes the keccak hash of multiple uint256 numbers (big-endian). +# Note that both the output and the input are in big endian representation. +func keccak_uint256s_bigend{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + n_elements : felt, elements : Uint256* +) -> (res : Uint256): + alloc_locals + + let (inputs) = alloc() + let inputs_start = inputs + + keccak_add_uint256s{inputs=inputs}(n_elements=n_elements, elements=elements, bigend=1) + + return keccak_bigend(inputs=inputs_start, n_bytes=n_elements * 32) +end + +# Computes the keccak hash of multiple field elements. +func keccak_felts{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + n_elements : felt, elements : felt* +) -> (res : Uint256): + alloc_locals + + let (inputs) = alloc() + let inputs_start = inputs + + keccak_add_felts{inputs=inputs}(n_elements=n_elements, elements=elements, bigend=0) + + return keccak(inputs=inputs_start, n_bytes=n_elements * 32) +end + +# Computes the keccak hash of multiple field elements (big-endian). +# Note that both the output and the input are in big endian representation. +func keccak_felts_bigend{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + n_elements : felt, elements : felt* +) -> (res : Uint256): + alloc_locals + + let (inputs) = alloc() + let inputs_start = inputs + + keccak_add_felts{inputs=inputs}(n_elements=n_elements, elements=elements, bigend=1) + + return keccak_bigend(inputs=inputs_start, n_bytes=n_elements * 32) +end + +# Helper functions. +# These functions serialize input to an array of 64-bit little endian words +# to be used with keccak() or keccak_as_words(). +# Note: You must call finalize_keccak() at the end of the program, where the range of the input +# is checked. Otherwise, these functions are not sound. + +# Serializes a uint256 number in a keccak compatible way. +# The argument 'bigend' is either 0 or 1, representing the endianness of the given number. +func keccak_add_uint256{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, inputs : felt*}( + num : Uint256, bigend : felt +): + if bigend != 0: + let (num_reversed) = uint256_reverse_endian(num=num) + tempvar bitwise_ptr = bitwise_ptr + tempvar high = num_reversed.high + tempvar low = num_reversed.low + else: + tempvar bitwise_ptr = bitwise_ptr + tempvar high = num.high + tempvar low = num.low + end + + %{ + segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64]) + segments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64]) + %} + + assert inputs[1] * 2 ** 64 + inputs[0] = low + assert inputs[3] * 2 ** 64 + inputs[2] = high + + let inputs = inputs + 4 + return () +end + +# Serializes multiple uint256 numbers in a keccak compatible way. +# The argument 'bigend' is either 0 or 1, representing the endianness of the given numbers. +# Note: This function does not serialize the number of elements. If desired, this is the caller's +# responsibility. +func keccak_add_uint256s{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, inputs : felt*}( + n_elements : felt, elements : Uint256*, bigend : felt +): + if n_elements == 0: + return () + end + + keccak_add_uint256(num=elements[0], bigend=bigend) + return keccak_add_uint256s(n_elements=n_elements - 1, elements=&elements[1], bigend=bigend) +end + +# Serializes a field element in a keccak compatible way. +# The argument 'bigend' is either 0 or 1, representing the endianness of the given element. +func keccak_add_felt{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, inputs : felt*}( + num : felt, bigend : felt +): + let (high, low) = split_felt(value=num) + keccak_add_uint256(num=Uint256(low=low, high=high), bigend=bigend) + + return () +end + +# Serializes multiple field elements in a keccak compatible way. +# The argument 'bigend' is either 0 or 1, representing the endianness of the given elements. +# Note: This function does not serialize the number of elements. If desired, this is the caller's +# responsibility. +func keccak_add_felts{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, inputs : felt*}( + n_elements : felt, elements : felt*, bigend : felt +): + if n_elements == 0: + return () + end + + keccak_add_felt(num=elements[0], bigend=bigend) + return keccak_add_felts(n_elements=n_elements - 1, elements=&elements[1], bigend=bigend) +end + +# Computes the keccak of 'input'. +# To use this function, split the input into words of 64 bits (little endian). +# For example, to compute keccak('Hello world!'), use: +# inputs = [8031924123371070792, 560229490] +# where: +# 8031924123371070792 == int.from_bytes(b'Hello wo', 'little') +# 560229490 == int.from_bytes(b'rld!', 'little') +# +# Returns the hash as a Uint256. +# +# Note: You must call finalize_keccak() at the end of the program. Otherwise, this function +# is not sound and a malicious prover may return a wrong result. +func keccak{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + inputs : felt*, n_bytes : felt +) -> (res : Uint256): + let (output) = keccak_as_words(inputs=inputs, n_bytes=n_bytes) + + let res_low = output[1] * 2 ** 64 + output[0] + let res_high = output[3] * 2 ** 64 + output[2] + + return (res=Uint256(low=res_low, high=res_high)) +end + +# Same as keccak, but outputs the hash in big endian representation. +# Note that the input is still treated as little endian. +func keccak_bigend{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + inputs : felt*, n_bytes : felt +) -> (res : Uint256): + let (hash) = keccak(inputs=inputs, n_bytes=n_bytes) + let (res) = uint256_reverse_endian(num=hash) + return (res=res) +end + +# Same as keccak, but outputs a pointer to 4 64-bit little endian words instead. +func keccak_as_words{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + inputs : felt*, n_bytes : felt +) -> (output : felt*): + alloc_locals + + let (local state) = alloc() + memset(dst=state, value=0, n=KECCAK_STATE_SIZE_FELTS) + + return _keccak(inputs=inputs, n_bytes=n_bytes, state=state) +end + +# Prepares a block for the block permutation: adds padding (of the form 100...001, see the +# _padding function) and capacity (8 64-bits words of zeros) to the input, xors the result +# with the previous block permutation's output, and writes it to keccak_ptr. +# +# This function is called for every block that is sent to the _block_permutation +# function. Each time it is called with a chunk of the input of at +# most 17 64-bit words. That is, with n_bytes <= 136. When it is called with exactly 136 +# bytes, no padding is added. Only the last block is padded. +# +# Arguments: +# inputs - chunk of the input, in little endian. +# n_bytes - the length of inputs in bytes. Must be in the range [0, 136]. +# state - the output of the previous block permutation that contains 25 64-bits words. +func _prepare_block{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + inputs : felt*, n_bytes : felt, state : felt* +): + alloc_locals + + let inputs_start = inputs + _copy_inputs{inputs=inputs, n_bytes=n_bytes, state=state}() + # n_words_written is the number of words written to keccak_ptr. + let n_words_written = inputs - inputs_start + + tempvar padding_len = (KECCAK_FULL_RATE_IN_WORDS - n_words_written) + local input_word + if n_bytes == 0: + input_word = 0 + else: + input_word = inputs[0] + end + + _padding(input_word=input_word, n_bytes=n_bytes, state=state, padding_len=padding_len) + let state = state + padding_len + + # Since the capacity part consists of zeros, we simply copy the state. + memcpy(dst=keccak_ptr, src=state, len=KECCAK_CAPACITY_IN_WORDS) + let keccak_ptr = keccak_ptr + KECCAK_CAPACITY_IN_WORDS + + return () +end + +# Xors full words from the input with the corresponding words from the output of the +# previous block permutation, and writes the restult to keccak_ptr. +func _copy_inputs{ + range_check_ptr, + bitwise_ptr : BitwiseBuiltin*, + keccak_ptr : felt*, + inputs : felt*, + n_bytes : felt, + state : felt*, +}(): + if nondet %{ ids.n_bytes < ids.BYTES_IN_WORD %} != 0: + assert_nn_le(n_bytes, BYTES_IN_WORD - 1) + return () + end + + let (next_word) = bitwise_xor(inputs[0], state[0]) + assert keccak_ptr[0] = next_word + + let inputs = &inputs[1] + let state = &state[1] + let keccak_ptr = &keccak_ptr[1] + let n_bytes = n_bytes - BYTES_IN_WORD + + return _copy_inputs() +end + +# Adds padding of the form 100...001 to the last bytes of the input, to a total of +# padding_len words, xors the result with the output of the last block permutation, +# from the corresponding offset, and writes it to keccak_ptr. +# +# Arguments: +# input_word - the last word of the input to keccak (given in little endian) +# if it has less than 8 bytes, otherwise 0. +# n_bytes - the number of bytes in input_word. Must be in the range [0, 8). +# state - the output of the last block permutation, from the word that corresponds +# to the input_word. (i.e., if input_word is the i-th word in the current block, +# then state points to the i-th word of the last block permutation's output). +# padding_len - the length of the required padding (in words). Must be in the range [0, 17]. +func _padding{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + input_word : felt, n_bytes : felt, state : felt*, padding_len : felt +): + if padding_len == 0: + return () + end + + let (first_one) = pow(256, n_bytes) + # The beginning of the padding with the last bytes of the input and the first 1. + let input_word_with_initial_padding = input_word + first_one + + if padding_len == 1: + let both_ones = 2 ** 63 + input_word_with_initial_padding + let (word) = bitwise_xor(both_ones, state[0]) + + assert keccak_ptr[0] = word + let keccak_ptr = &keccak_ptr[1] + + return () + end + + return _long_padding( + input_word_with_initial_padding=input_word_with_initial_padding, + state=state, + padding_len=padding_len, + ) +end + +# Padding of more than 1 word. +func _long_padding{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + input_word_with_initial_padding : felt, state : felt*, padding_len : felt +): + alloc_locals + + # First word. + let (first_one) = bitwise_xor(input_word_with_initial_padding, state[0]) + assert keccak_ptr[0] = first_one + let keccak_ptr = &keccak_ptr[1] + let state = state + 1 + + # The padding of the inner words is zero, so we should simply copy them. + memcpy(dst=keccak_ptr, src=state, len=padding_len - 2) + let keccak_ptr = keccak_ptr + padding_len - 2 + let state = state + padding_len - 2 + + # Last word. + let (second_one) = bitwise_xor(2 ** 63, state[0]) + assert keccak_ptr[0] = second_one + let keccak_ptr = &keccak_ptr[1] + + return () +end + +func _block_permutation{keccak_ptr : felt*}(): + %{ + from starkware.cairo.common.cairo_keccak.keccak_utils import keccak_func + _keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS) + assert 0 <= _keccak_state_size_felts < 100 + + output_values = keccak_func(memory.get_range( + ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts)) + segments.write_arg(ids.keccak_ptr, output_values) + %} + let keccak_ptr = keccak_ptr + KECCAK_STATE_SIZE_FELTS + + return () +end + +func _keccak{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + inputs : felt*, n_bytes : felt, state : felt* +) -> (output : felt*): + if nondet %{ ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES %} != 0: + _prepare_block(inputs=inputs, n_bytes=KECCAK_FULL_RATE_IN_BYTES, state=state) + _block_permutation() + + return _keccak( + inputs=inputs + KECCAK_FULL_RATE_IN_WORDS, + n_bytes=n_bytes - KECCAK_FULL_RATE_IN_BYTES, + state=keccak_ptr - KECCAK_STATE_SIZE_FELTS, + ) + end + + assert_nn_le(n_bytes, KECCAK_FULL_RATE_IN_BYTES - 1) + + _prepare_block(inputs=inputs, n_bytes=n_bytes, state=state) + _block_permutation() + + return (keccak_ptr - KECCAK_STATE_SIZE_FELTS) +end + +# Verifies that the results of keccak() are valid. +func finalize_keccak{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}( + keccak_ptr_start : felt*, keccak_ptr_end : felt* +): + alloc_locals + + tempvar n = (keccak_ptr_end - keccak_ptr_start) / (2 * KECCAK_STATE_SIZE_FELTS) + if n == 0: + return () + end + + %{ + # Add dummy pairs of input and output. + _keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS) + _block_size = int(ids.BLOCK_SIZE) + assert 0 <= _keccak_state_size_felts < 100 + assert 0 <= _block_size < 10 + inp = [0] * _keccak_state_size_felts + padding = (inp + keccak_func(inp)) * _block_size + segments.write_arg(ids.keccak_ptr_end, padding) + %} + + # Compute the amount of blocks (rounded up). + let (local q, r) = unsigned_div_rem(n + BLOCK_SIZE - 1, BLOCK_SIZE) + _finalize_keccak_inner(keccak_ptr_start, n=q) + return () +end + +# Handles n blocks of BLOCK_SIZE keccak instances. +func _finalize_keccak_inner{range_check_ptr, bitwise_ptr : BitwiseBuiltin*}( + keccak_ptr : felt*, n : felt +): + if n == 0: + return () + end + + ap += SIZEOF_LOCALS + + local MAX_VALUE = 2 ** 64 - 1 + + let keccak_ptr_start = keccak_ptr + + let (local inputs_start : felt*) = alloc() + + # Handle inputs. + + tempvar inputs = inputs_start + tempvar keccak_ptr = keccak_ptr + tempvar range_check_ptr = range_check_ptr + tempvar m = 25 + + input_loop: + tempvar x0 = keccak_ptr[0] + assert [range_check_ptr] = x0 + assert [range_check_ptr + 1] = MAX_VALUE - x0 + tempvar x1 = keccak_ptr[50] + assert [range_check_ptr + 2] = x1 + assert [range_check_ptr + 3] = MAX_VALUE - x1 + tempvar x2 = keccak_ptr[100] + assert [range_check_ptr + 4] = x2 + assert [range_check_ptr + 5] = MAX_VALUE - x2 + assert inputs[0] = x0 + 2 ** 64 * x1 + 2 ** 128 * x2 + + tempvar inputs = inputs + 1 + tempvar keccak_ptr = keccak_ptr + 1 + tempvar range_check_ptr = range_check_ptr + 6 + tempvar m = m - 1 + jmp input_loop if m != 0 + + # Run keccak on the 3 instances. + + let (outputs) = packed_keccak_func(inputs_start) + local bitwise_ptr : BitwiseBuiltin* = bitwise_ptr + + # Handle outputs. + + tempvar outputs = outputs + tempvar keccak_ptr = keccak_ptr + tempvar range_check_ptr = range_check_ptr + tempvar m = 25 + + output_loop: + tempvar x0 = keccak_ptr[0] + assert [range_check_ptr] = x0 + assert [range_check_ptr + 1] = MAX_VALUE - x0 + tempvar x1 = keccak_ptr[50] + assert [range_check_ptr + 2] = x1 + assert [range_check_ptr + 3] = MAX_VALUE - x1 + tempvar x2 = keccak_ptr[100] + assert [range_check_ptr + 4] = x2 + assert [range_check_ptr + 5] = MAX_VALUE - x2 + assert outputs[0] = x0 + 2 ** 64 * x1 + 2 ** 128 * x2 + + tempvar outputs = outputs + 1 + tempvar keccak_ptr = keccak_ptr + 1 + tempvar range_check_ptr = range_check_ptr + 6 + tempvar m = m - 1 + jmp output_loop if m != 0 + + return _finalize_keccak_inner(keccak_ptr=keccak_ptr_start + 150, n=n - 1) +end diff --git a/src/starkware/cairo/common/cairo_keccak/packed_keccak.cairo b/src/starkware/cairo/common/cairo_keccak/packed_keccak.cairo new file mode 100644 index 00000000..ae6bd384 --- /dev/null +++ b/src/starkware/cairo/common/cairo_keccak/packed_keccak.cairo @@ -0,0 +1,556 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.registers import get_fp_and_pc + +const ALL_ONES = 2 ** 251 - 1 +const BLOCK_SIZE = 3 +const SHIFTS = 1 + 2 ** 64 + 2 ** 128 + +func keccak_round{bitwise_ptr : BitwiseBuiltin*}(values : felt*, rc : felt) -> (values_b : felt*): + ap += SIZEOF_LOCALS + + ############################################################################################## + # Compute: c[x] = a[0][x] ^ a[1][x] ^ a[2][x] ^ a[3][x] ^ a[4][x]. # + ############################################################################################## + + let values_start = values + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = values[5] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = values[10] + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = values[15] + assert bitwise_ptr[3].x = bitwise_ptr[2].x_xor_y + assert bitwise_ptr[3].y = values[20] + tempvar c0 = bitwise_ptr[3].x_xor_y + let values = values + 1 + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = values[5] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = values[10] + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = values[15] + assert bitwise_ptr[3].x = bitwise_ptr[2].x_xor_y + assert bitwise_ptr[3].y = values[20] + tempvar c1 = bitwise_ptr[3].x_xor_y + let values = values + 1 + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = values[5] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = values[10] + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = values[15] + assert bitwise_ptr[3].x = bitwise_ptr[2].x_xor_y + assert bitwise_ptr[3].y = values[20] + tempvar c2 = bitwise_ptr[3].x_xor_y + let values = values + 1 + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = values[5] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = values[10] + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = values[15] + assert bitwise_ptr[3].x = bitwise_ptr[2].x_xor_y + assert bitwise_ptr[3].y = values[20] + tempvar c3 = bitwise_ptr[3].x_xor_y + let values = values + 1 + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = values[5] + assert bitwise_ptr[1].x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].y = values[10] + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = values[15] + assert bitwise_ptr[3].x = bitwise_ptr[2].x_xor_y + assert bitwise_ptr[3].y = values[20] + tempvar c4 = bitwise_ptr[3].x_xor_y + let values = values + 1 + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + + ############################################################################################## + # Compute: d[x] = c[(x - 1) % 5] ^ rot_left(c[(x + 1) % 5], 1). # + ############################################################################################## + + # Saving constants as local variables is more efficient in some instructions. + local mask = 0x800000000000000080000000000000008000000000000000 + + let x = c1 + let y = c4 + assert bitwise_ptr[0].x = x + assert bitwise_ptr[0].y = mask + tempvar x0 = bitwise_ptr[0].x_and_y + let rotx = 2 * x + (1 / 2 ** 63 - 2) * x0 + assert bitwise_ptr[1].x = rotx + assert bitwise_ptr[1].y = y + tempvar d0 = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + let x = c2 + let y = c0 + assert bitwise_ptr[0].x = x + assert bitwise_ptr[0].y = mask + tempvar x0 = bitwise_ptr[0].x_and_y + let rotx = 2 * x + (1 / 2 ** 63 - 2) * x0 + assert bitwise_ptr[1].x = rotx + assert bitwise_ptr[1].y = y + tempvar d1 = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + let x = c3 + let y = c1 + assert bitwise_ptr[0].x = x + assert bitwise_ptr[0].y = mask + tempvar x0 = bitwise_ptr[0].x_and_y + let rotx = 2 * x + (1 / 2 ** 63 - 2) * x0 + assert bitwise_ptr[1].x = rotx + assert bitwise_ptr[1].y = y + tempvar d2 = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + let x = c4 + let y = c2 + assert bitwise_ptr[0].x = x + assert bitwise_ptr[0].y = mask + tempvar x0 = bitwise_ptr[0].x_and_y + let rotx = 2 * x + (1 / 2 ** 63 - 2) * x0 + assert bitwise_ptr[1].x = rotx + assert bitwise_ptr[1].y = y + tempvar d3 = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + let x = c0 + let y = c3 + assert bitwise_ptr[0].x = x + assert bitwise_ptr[0].y = mask + tempvar x0 = bitwise_ptr[0].x_and_y + let rotx = 2 * x + (1 / 2 ** 63 - 2) * x0 + assert bitwise_ptr[1].x = rotx + assert bitwise_ptr[1].y = y + tempvar d4 = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + ############################################################################################## + # Compute: b[(2 * x + 3 * y) % 5][y] = rot_left([a[y][x] ^ d[x], OFFSETS[x][y]) # + ############################################################################################## + + let values = values_start + + assert bitwise_ptr[0].x = values[0] + assert bitwise_ptr[0].y = d0 + tempvar b0 = bitwise_ptr[0].x_xor_y + let bitwise_ptr = bitwise_ptr + 1 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[5] + assert bitwise_ptr[0].y = d0 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffff0000000fffffffff0000000fffffffff0000000 + tempvar b16 = 2 ** 36 * x + (1 / 2 ** 28 - 2 ** 36) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[10] + assert bitwise_ptr[0].y = d0 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xe000000000000000e000000000000000e000000000000000 + tempvar b7 = 2 ** 3 * x + (1 / 2 ** 61 - 2 ** 3) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[15] + assert bitwise_ptr[0].y = d0 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffffffff800000ffffffffff800000ffffffffff800000 + tempvar b23 = 2 ** 41 * x + (1 / 2 ** 23 - 2 ** 41) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[20] + assert bitwise_ptr[0].y = d0 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffc00000000000ffffc00000000000ffffc00000000000 + tempvar b14 = 2 ** 18 * x + (1 / 2 ** 46 - 2 ** 18) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[1] + assert bitwise_ptr[0].y = d1 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0x800000000000000080000000000000008000000000000000 + tempvar b10 = 2 ** 1 * x + (1 / 2 ** 63 - 2 ** 1) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[6] + assert bitwise_ptr[0].y = d1 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffff00000fffffffffff00000fffffffffff00000 + tempvar b1 = 2 ** 44 * x + (1 / 2 ** 20 - 2 ** 44) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[11] + assert bitwise_ptr[0].y = d1 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffc0000000000000ffc0000000000000ffc0000000000000 + tempvar b17 = 2 ** 10 * x + (1 / 2 ** 54 - 2 ** 10) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[16] + assert bitwise_ptr[0].y = d1 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffff80000fffffffffff80000fffffffffff80000 + tempvar b8 = 2 ** 45 * x + (1 / 2 ** 19 - 2 ** 45) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[21] + assert bitwise_ptr[0].y = d1 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xc000000000000000c000000000000000c000000000000000 + tempvar b24 = 2 ** 2 * x + (1 / 2 ** 62 - 2 ** 2) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[2] + assert bitwise_ptr[0].y = d2 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffffffffcfffffffffffffffcfffffffffffffffc + tempvar b20 = 2 ** 62 * x + (1 / 2 ** 2 - 2 ** 62) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[7] + assert bitwise_ptr[0].y = d2 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfc00000000000000fc00000000000000fc00000000000000 + tempvar b11 = 2 ** 6 * x + (1 / 2 ** 58 - 2 ** 6) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[12] + assert bitwise_ptr[0].y = d2 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffffffffe00000ffffffffffe00000ffffffffffe00000 + tempvar b2 = 2 ** 43 * x + (1 / 2 ** 21 - 2 ** 43) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[17] + assert bitwise_ptr[0].y = d2 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffe000000000000fffe000000000000fffe000000000000 + tempvar b18 = 2 ** 15 * x + (1 / 2 ** 49 - 2 ** 15) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[22] + assert bitwise_ptr[0].y = d2 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffffffff8fffffffffffffff8fffffffffffffff8 + tempvar b9 = 2 ** 61 * x + (1 / 2 ** 3 - 2 ** 61) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[3] + assert bitwise_ptr[0].y = d3 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffff000000000fffffff000000000fffffff000000000 + tempvar b5 = 2 ** 28 * x + (1 / 2 ** 36 - 2 ** 28) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[8] + assert bitwise_ptr[0].y = d3 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffffffe00fffffffffffffe00fffffffffffffe00 + tempvar b21 = 2 ** 55 * x + (1 / 2 ** 9 - 2 ** 55) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[13] + assert bitwise_ptr[0].y = d3 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffff8000000000ffffff8000000000ffffff8000000000 + tempvar b12 = 2 ** 25 * x + (1 / 2 ** 39 - 2 ** 25) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[18] + assert bitwise_ptr[0].y = d3 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffff80000000000fffff80000000000fffff80000000000 + tempvar b3 = 2 ** 21 * x + (1 / 2 ** 43 - 2 ** 21) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[23] + assert bitwise_ptr[0].y = d3 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffffffffffff00ffffffffffffff00ffffffffffffff00 + tempvar b19 = 2 ** 56 * x + (1 / 2 ** 8 - 2 ** 56) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[4] + assert bitwise_ptr[0].y = d4 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xffffffe000000000ffffffe000000000ffffffe000000000 + tempvar b15 = 2 ** 27 * x + (1 / 2 ** 37 - 2 ** 27) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[9] + assert bitwise_ptr[0].y = d4 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffff00000000000fffff00000000000fffff00000000000 + tempvar b6 = 2 ** 20 * x + (1 / 2 ** 44 - 2 ** 20) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[14] + assert bitwise_ptr[0].y = d4 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffffffffe000000fffffffffe000000fffffffffe000000 + tempvar b22 = 2 ** 39 * x + (1 / 2 ** 25 - 2 ** 39) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[19] + assert bitwise_ptr[0].y = d4 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xff00000000000000ff00000000000000ff00000000000000 + tempvar b13 = 2 ** 8 * x + (1 / 2 ** 56 - 2 ** 8) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = values[24] + assert bitwise_ptr[0].y = d4 + tempvar x = bitwise_ptr[0].x_xor_y + assert bitwise_ptr[1].x = x + assert bitwise_ptr[1].y = 0xfffc000000000000fffc000000000000fffc000000000000 + tempvar b4 = 2 ** 14 * x + (1 / 2 ** 50 - 2 ** 14) * bitwise_ptr[1].x_and_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + ############################################################################################## + # Compute: a[y][x] = [[a[y][x] ^ ((~a[y][(x + 1) % 5]) & a[y][(x + 2) % 5]) # + ############################################################################################## + + let (local output : felt*) = alloc() + + assert bitwise_ptr[0].x = ALL_ONES - b1 + assert bitwise_ptr[0].y = b2 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b0 + assert bitwise_ptr[2].x = bitwise_ptr[1].x_xor_y + assert bitwise_ptr[2].y = rc + assert output[0] = bitwise_ptr[2].x_xor_y + let bitwise_ptr = bitwise_ptr + 3 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b6 + assert bitwise_ptr[0].y = b7 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b5 + assert output[5] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b11 + assert bitwise_ptr[0].y = b12 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b10 + assert output[10] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b16 + assert bitwise_ptr[0].y = b17 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b15 + assert output[15] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b21 + assert bitwise_ptr[0].y = b22 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b20 + assert output[20] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b2 + assert bitwise_ptr[0].y = b3 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b1 + assert output[1] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b7 + assert bitwise_ptr[0].y = b8 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b6 + assert output[6] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b12 + assert bitwise_ptr[0].y = b13 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b11 + assert output[11] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b17 + assert bitwise_ptr[0].y = b18 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b16 + assert output[16] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b22 + assert bitwise_ptr[0].y = b23 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b21 + assert output[21] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b3 + assert bitwise_ptr[0].y = b4 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b2 + assert output[2] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b8 + assert bitwise_ptr[0].y = b9 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b7 + assert output[7] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b13 + assert bitwise_ptr[0].y = b14 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b12 + assert output[12] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b18 + assert bitwise_ptr[0].y = b19 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b17 + assert output[17] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b23 + assert bitwise_ptr[0].y = b24 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b22 + assert output[22] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b4 + assert bitwise_ptr[0].y = b0 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b3 + assert output[3] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b9 + assert bitwise_ptr[0].y = b5 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b8 + assert output[8] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b14 + assert bitwise_ptr[0].y = b10 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b13 + assert output[13] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b19 + assert bitwise_ptr[0].y = b15 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b18 + assert output[18] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b24 + assert bitwise_ptr[0].y = b20 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b23 + assert output[23] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b0 + assert bitwise_ptr[0].y = b1 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b4 + assert output[4] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b5 + assert bitwise_ptr[0].y = b6 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b9 + assert output[9] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b10 + assert bitwise_ptr[0].y = b11 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b14 + assert output[14] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b15 + assert bitwise_ptr[0].y = b16 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b19 + assert output[19] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + assert bitwise_ptr[0].x = ALL_ONES - b20 + assert bitwise_ptr[0].y = b21 + assert bitwise_ptr[1].x = bitwise_ptr[0].x_and_y + assert bitwise_ptr[1].y = b24 + assert output[24] = bitwise_ptr[1].x_xor_y + let bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE + + return (output) +end + +func packed_keccak_func{bitwise_ptr : BitwiseBuiltin*}(values : felt*) -> (values : felt*): + let (values) = keccak_round(values, 0x0000000000000001 * SHIFTS) + let (values) = keccak_round(values, 0x0000000000008082 * SHIFTS) + let (values) = keccak_round(values, 0x800000000000808A * SHIFTS) + let (values) = keccak_round(values, 0x8000000080008000 * SHIFTS) + let (values) = keccak_round(values, 0x000000000000808B * SHIFTS) + let (values) = keccak_round(values, 0x0000000080000001 * SHIFTS) + let (values) = keccak_round(values, 0x8000000080008081 * SHIFTS) + let (values) = keccak_round(values, 0x8000000000008009 * SHIFTS) + let (values) = keccak_round(values, 0x000000000000008A * SHIFTS) + let (values) = keccak_round(values, 0x0000000000000088 * SHIFTS) + let (values) = keccak_round(values, 0x0000000080008009 * SHIFTS) + let (values) = keccak_round(values, 0x000000008000000A * SHIFTS) + let (values) = keccak_round(values, 0x000000008000808B * SHIFTS) + let (values) = keccak_round(values, 0x800000000000008B * SHIFTS) + let (values) = keccak_round(values, 0x8000000000008089 * SHIFTS) + let (values) = keccak_round(values, 0x8000000000008003 * SHIFTS) + let (values) = keccak_round(values, 0x8000000000008002 * SHIFTS) + let (values) = keccak_round(values, 0x8000000000000080 * SHIFTS) + let (values) = keccak_round(values, 0x000000000000800A * SHIFTS) + let (values) = keccak_round(values, 0x800000008000000A * SHIFTS) + let (values) = keccak_round(values, 0x8000000080008081 * SHIFTS) + let (values) = keccak_round(values, 0x8000000000008080 * SHIFTS) + let (values) = keccak_round(values, 0x0000000080000001 * SHIFTS) + let (values) = keccak_round(values, 0x8000000080008008 * SHIFTS) + + return (values) +end diff --git a/src/starkware/cairo/common/cairo_secp/bigint.cairo b/src/starkware/cairo/common/cairo_secp/bigint.cairo new file mode 100644 index 00000000..af0ac1a3 --- /dev/null +++ b/src/starkware/cairo/common/cairo_secp/bigint.cairo @@ -0,0 +1,125 @@ +from starkware.cairo.common.cairo_secp.constants import BASE +from starkware.cairo.common.math import assert_nn, assert_nn_le, unsigned_div_rem +from starkware.cairo.common.math_cmp import RC_BOUND +from starkware.cairo.common.uint256 import Uint256 + +# Represents a big integer defined by: +# d0 + BASE * d1 + BASE**2 * d2. +# Note that the limbs (d_i) are NOT restricted to the range [0, BASE) and in particular they +# can be negative. +# In most cases this is used to represent a secp256k1 field element. +struct UnreducedBigInt3: + member d0 : felt + member d1 : felt + member d2 : felt +end + +# Same as UnreducedBigInt3, except that d0, d1 and d2 must be in the range [0, 3 * BASE). +# In most cases this is used to represent a secp256k1 field element. +struct BigInt3: + member d0 : felt + member d1 : felt + member d2 : felt +end + +# Represents a big integer defined by: +# sum_i(BASE**i * d_i). +# Note that the limbs (d_i) are NOT restricted to the range [0, BASE) and in particular they +# can be negative. +struct UnreducedBigInt5: + member d0 : felt + member d1 : felt + member d2 : felt + member d3 : felt + member d4 : felt +end + +# Computes the multiplication of two big integers, given in BigInt3 representation. +# +# Arguments: +# x, y - the two BigInt3 to operate on. +# +# Returns: +# x * y in an UnreducedBigInt5 representation. +func bigint_mul(x : BigInt3, y : BigInt3) -> (res : UnreducedBigInt5): + return ( + UnreducedBigInt5( + d0=x.d0 * y.d0, + d1=x.d0 * y.d1 + x.d1 * y.d0, + d2=x.d0 * y.d2 + x.d1 * y.d1 + x.d2 * y.d0, + d3=x.d1 * y.d2 + x.d2 * y.d1, + d4=x.d2 * y.d2), + ) +end + +# Returns a BigInt3 instance whose value is controlled by a prover hint. +# +# Soundness guarantee: each limb is in the range [0, 3 * BASE). +# Completeness guarantee (honest prover): the value is in reduced form and in particular, +# each limb is in the range [0, BASE). +# +# Implicit arguments: +# range_check_ptr - range check builtin pointer. +# +# Hint arguments: value. +func nondet_bigint3{range_check_ptr}() -> (res : BigInt3): + # The result should be at the end of the stack after the function returns. + let res : BigInt3 = [cast(ap + 5, BigInt3*)] + %{ + from starkware.cairo.common.cairo_secp.secp_utils import split + + segments.write_arg(ids.res.address_, split(value)) + %} + # The maximal possible sum of the limbs, assuming each of them is in the range [0, BASE). + const MAX_SUM = 3 * (BASE - 1) + assert [range_check_ptr] = MAX_SUM - (res.d0 + res.d1 + res.d2) + + # Prepare the result at the end of the stack. + tempvar range_check_ptr = range_check_ptr + 4 + [range_check_ptr - 3] = res.d0; ap++ + [range_check_ptr - 2] = res.d1; ap++ + [range_check_ptr - 1] = res.d2; ap++ + static_assert &res + BigInt3.SIZE == ap + return (res=res) +end + +# Converts a BigInt3 instance into a Uint256. +# +# Assumptions: +# * The limbs of x are in the range [0, BASE * 3). +# * x is in the range [0, 2 ** 256). +# * PRIME is at least 174 bits. +# Implicit arguments: +# range_check_ptr - range check builtin pointer. +func bigint_to_uint256{range_check_ptr}(x : BigInt3) -> (res : Uint256): + let low = [range_check_ptr] + let high = [range_check_ptr + 1] + let range_check_ptr = range_check_ptr + 2 + %{ ids.low = (ids.x.d0 + ids.x.d1 * ids.BASE) & ((1 << 128) - 1) %} + # Because PRIME is at least 174 bits, the numerator doesn't overflow. + tempvar a = ((x.d0 + x.d1 * BASE) - low) / RC_BOUND + const D2_SHIFT = BASE * BASE / RC_BOUND + const A_BOUND = 4 * D2_SHIFT + # We'll check that the division in `a` doesn't cause an overflow. This means that the 128 LSB + # of (x.d0 + x.d1 * BASE) and low are identical, which ensures that low is correct. + assert_nn_le(a, A_BOUND - 1) + # high * RC_BOUND = a * RC_BOUND + x.d2 * BASE ** 2 = + # = x.d0 + x.d1 * BASE + x.d2 * BASE ** 2 - low = num - low. + with_attr error_message("x out of range"): + assert high = a + x.d2 * D2_SHIFT + end + + return (res=Uint256(low=low, high=high)) +end + +# Converts a Uint256 instance into a BigInt3. +# Assuming x is a valid Uint256 (its two limbs are below 2 ** 128), the resulting number will have +# limbs in the range [0, BASE). +func uint256_to_bigint{range_check_ptr}(x : Uint256) -> (res : BigInt3): + const D1_HIGH_BOUND = BASE ** 2 / RC_BOUND + const D1_LOW_BOUND = RC_BOUND / BASE + let (d1_low, d0) = unsigned_div_rem(x.low, BASE) + let (d2, d1_high) = unsigned_div_rem(x.high, D1_HIGH_BOUND) + let d1 = d1_high * D1_LOW_BOUND + d1_low + return (BigInt3(d0=d0, d1=d1, d2=d2)) +end diff --git a/src/starkware/cairo/common/cairo_secp/constants.cairo b/src/starkware/cairo/common/cairo_secp/constants.cairo new file mode 100644 index 00000000..113f9cc5 --- /dev/null +++ b/src/starkware/cairo/common/cairo_secp/constants.cairo @@ -0,0 +1,21 @@ +# Basic definitions for the secp256k1 elliptic curve. +# The curve is given by the equation: +# y^2 = x^3 + 7 +# over the field Z/p for +# p = secp256k1_prime = 2 ** 256 - (2 ** 32 + 2 ** 9 + 2 ** 8 + 2 ** 7 + 2 ** 6 + 2 ** 4 + 1). +# The size of the curve is +# n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 (prime). + +# SECP_REM is defined by the equation: +# secp256k1_prime = 2 ** 256 - SECP_REM. +const SECP_REM = 2 ** 32 + 2 ** 9 + 2 ** 8 + 2 ** 7 + 2 ** 6 + 2 ** 4 + 1 + +# The following constants represent the size of the secp256k1 curve: +# n = N0 + BASE * N1 + BASE**2 * N2. +const BASE = 2 ** 86 +const N0 = 0x8a03bbfd25e8cd0364141 +const N1 = 0x3ffffffffffaeabb739abd +const N2 = 0xfffffffffffffffffffff + +# BETA is the free term in the curve equation. +const BETA = 7 diff --git a/src/starkware/cairo/common/cairo_secp/ec.cairo b/src/starkware/cairo/common/cairo_secp/ec.cairo new file mode 100644 index 00000000..5c0a5c5e --- /dev/null +++ b/src/starkware/cairo/common/cairo_secp/ec.cairo @@ -0,0 +1,316 @@ +from starkware.cairo.common.cairo_secp.bigint import BigInt3, UnreducedBigInt3, nondet_bigint3 +from starkware.cairo.common.cairo_secp.field import ( + is_zero, + unreduced_mul, + unreduced_sqr, + verify_zero, +) + +# Represents a point on the secp256k1 elliptic curve. +# The zero point is represented as a point with x = 0 (there is no point on the curve with a zero +# x value). +struct EcPoint: + member x : BigInt3 + member y : BigInt3 +end + +# Computes the negation of a point on the elliptic curve, which is a point with the same x value and +# the negation of the y value. If the point is the zero point, returns the zero point. +# +# Arguments: +# point - The point to operate on. +# +# Returns: +# point - The negation of the given point. +func ec_negate{range_check_ptr}(point : EcPoint) -> (point : EcPoint): + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + y = pack(ids.point.y, PRIME) % SECP_P + # The modulo operation in python always returns a nonnegative number. + value = (-y) % SECP_P + %} + let (minus_y) = nondet_bigint3() + verify_zero( + UnreducedBigInt3( + d0=minus_y.d0 + point.y.d0, + d1=minus_y.d1 + point.y.d1, + d2=minus_y.d2 + point.y.d2), + ) + + return (point=EcPoint(x=point.x, y=minus_y)) +end + +# Computes the slope of the elliptic curve at a given point. +# The slope is used to compute point + point. +# +# Arguments: +# point - the point to operate on. +# +# Returns: +# slope - the slope of the curve at point, in BigInt3 representation. +# +# Assumption: point != 0. +func compute_doubling_slope{range_check_ptr}(point : EcPoint) -> (slope : BigInt3): + # Note that y cannot be zero: assume that it is, then point = -point, so 2 * point = 0, which + # contradicts the fact that the size of the curve is odd. + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + from starkware.python.math_utils import ec_double_slope + + # Compute the slope. + x = pack(ids.point.x, PRIME) + y = pack(ids.point.y, PRIME) + value = slope = ec_double_slope(point=(x, y), alpha=0, p=SECP_P) + %} + let (slope : BigInt3) = nondet_bigint3() + + let (x_sqr : UnreducedBigInt3) = unreduced_sqr(point.x) + let (slope_y : UnreducedBigInt3) = unreduced_mul(slope, point.y) + + verify_zero( + UnreducedBigInt3( + d0=3 * x_sqr.d0 - 2 * slope_y.d0, + d1=3 * x_sqr.d1 - 2 * slope_y.d1, + d2=3 * x_sqr.d2 - 2 * slope_y.d2), + ) + + return (slope=slope) +end + +# Computes the slope of the line connecting the two given points. +# The slope is used to compute point0 + point1. +# +# Arguments: +# point0, point1 - the points to operate on. +# +# Returns: +# slope - the slope of the line connecting point0 and point1, in BigInt3 representation. +# +# Assumptions: +# * point0.x != point1.x (mod secp256k1_prime). +# * point0, point1 != 0. +func compute_slope{range_check_ptr}(point0 : EcPoint, point1 : EcPoint) -> (slope : BigInt3): + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + from starkware.python.math_utils import line_slope + + # Compute the slope. + x0 = pack(ids.point0.x, PRIME) + y0 = pack(ids.point0.y, PRIME) + x1 = pack(ids.point1.x, PRIME) + y1 = pack(ids.point1.y, PRIME) + value = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P) + %} + let (slope) = nondet_bigint3() + + let x_diff = BigInt3( + d0=point0.x.d0 - point1.x.d0, d1=point0.x.d1 - point1.x.d1, d2=point0.x.d2 - point1.x.d2 + ) + let (x_diff_slope : UnreducedBigInt3) = unreduced_mul(x_diff, slope) + + verify_zero( + UnreducedBigInt3( + d0=x_diff_slope.d0 - point0.y.d0 + point1.y.d0, + d1=x_diff_slope.d1 - point0.y.d1 + point1.y.d1, + d2=x_diff_slope.d2 - point0.y.d2 + point1.y.d2), + ) + + return (slope) +end + +# Computes the addition of a given point to itself. +# +# Arguments: +# point - the point to operate on. +# +# Returns: +# res - a point representing point + point. +func ec_double{range_check_ptr}(point : EcPoint) -> (res : EcPoint): + # The zero point. + if point.x.d0 == 0: + if point.x.d1 == 0: + if point.x.d2 == 0: + return (point) + end + end + end + + let (slope : BigInt3) = compute_doubling_slope(point) + let (slope_sqr : UnreducedBigInt3) = unreduced_sqr(slope) + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + slope = pack(ids.slope, PRIME) + x = pack(ids.point.x, PRIME) + y = pack(ids.point.y, PRIME) + + value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P + %} + let (new_x : BigInt3) = nondet_bigint3() + + %{ value = new_y = (slope * (x - new_x) - y) % SECP_P %} + let (new_y : BigInt3) = nondet_bigint3() + + verify_zero( + UnreducedBigInt3( + d0=slope_sqr.d0 - new_x.d0 - 2 * point.x.d0, + d1=slope_sqr.d1 - new_x.d1 - 2 * point.x.d1, + d2=slope_sqr.d2 - new_x.d2 - 2 * point.x.d2), + ) + + let (x_diff_slope : UnreducedBigInt3) = unreduced_mul( + BigInt3(d0=point.x.d0 - new_x.d0, d1=point.x.d1 - new_x.d1, d2=point.x.d2 - new_x.d2), slope + ) + + verify_zero( + UnreducedBigInt3( + d0=x_diff_slope.d0 - point.y.d0 - new_y.d0, + d1=x_diff_slope.d1 - point.y.d1 - new_y.d1, + d2=x_diff_slope.d2 - point.y.d2 - new_y.d2), + ) + + return (res=EcPoint(new_x, new_y)) +end + +# Computes the addition of two given points. +# +# Arguments: +# point0, point1 - the points to operate on. +# +# Returns: +# res - the sum of the two points (point0 + point1). +# +# Assumption: point0.x != point1.x (however, point0 = point1 = 0 is allowed). +# Note that this means that the function cannot be used if point0 = point1 != 0 +# (use ec_double() in this case) or point0 = -point1 != 0 (the result is 0 in this case). +func fast_ec_add{range_check_ptr}(point0 : EcPoint, point1 : EcPoint) -> (res : EcPoint): + # Check whether point0 is the zero point. + if point0.x.d0 == 0: + if point0.x.d1 == 0: + if point0.x.d2 == 0: + return (point1) + end + end + end + + # Check whether point1 is the zero point. + if point1.x.d0 == 0: + if point1.x.d1 == 0: + if point1.x.d2 == 0: + return (point0) + end + end + end + + let (slope : BigInt3) = compute_slope(point0, point1) + let (slope_sqr : UnreducedBigInt3) = unreduced_sqr(slope) + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + slope = pack(ids.slope, PRIME) + x0 = pack(ids.point0.x, PRIME) + x1 = pack(ids.point1.x, PRIME) + y0 = pack(ids.point0.y, PRIME) + + value = new_x = (pow(slope, 2, SECP_P) - x0 - x1) % SECP_P + %} + let (new_x : BigInt3) = nondet_bigint3() + + %{ value = new_y = (slope * (x0 - new_x) - y0) % SECP_P %} + let (new_y : BigInt3) = nondet_bigint3() + + verify_zero( + UnreducedBigInt3( + d0=slope_sqr.d0 - new_x.d0 - point0.x.d0 - point1.x.d0, + d1=slope_sqr.d1 - new_x.d1 - point0.x.d1 - point1.x.d1, + d2=slope_sqr.d2 - new_x.d2 - point0.x.d2 - point1.x.d2), + ) + + let (x_diff_slope : UnreducedBigInt3) = unreduced_mul( + BigInt3(d0=point0.x.d0 - new_x.d0, d1=point0.x.d1 - new_x.d1, d2=point0.x.d2 - new_x.d2), + slope, + ) + + verify_zero( + UnreducedBigInt3( + d0=x_diff_slope.d0 - point0.y.d0 - new_y.d0, + d1=x_diff_slope.d1 - point0.y.d1 - new_y.d1, + d2=x_diff_slope.d2 - point0.y.d2 - new_y.d2), + ) + + return (EcPoint(new_x, new_y)) +end + +# Same as fast_ec_add, except that the cases point0 = +/-point1 are supported. +func ec_add{range_check_ptr}(point0 : EcPoint, point1 : EcPoint) -> (res : EcPoint): + let x_diff = BigInt3( + d0=point0.x.d0 - point1.x.d0, d1=point0.x.d1 - point1.x.d1, d2=point0.x.d2 - point1.x.d2 + ) + let (same_x : felt) = is_zero(x_diff) + if same_x == 0: + # point0.x != point1.x so we can use fast_ec_add. + return fast_ec_add(point0, point1) + end + + # We have point0.x = point1.x. This implies point0.y = +/-point1.y. + # Check whether point0.y = -point1.y. + let y_sum = BigInt3( + d0=point0.y.d0 + point1.y.d0, d1=point0.y.d1 + point1.y.d1, d2=point0.y.d2 + point1.y.d2 + ) + let (opposite_y : felt) = is_zero(y_sum) + if opposite_y != 0: + # point0.y = -point1.y. + # Note that the case point0 = point1 = 0 falls into this branch as well. + let ZERO_POINT = EcPoint(BigInt3(0, 0, 0), BigInt3(0, 0, 0)) + return (ZERO_POINT) + else: + # point0.y = point1.y. + return ec_double(point0) + end +end + +# Given a scalar, an integer m in the range [0, 250), and a point on the elliptic curve, point, +# verifies that 0 <= scalar < 2**m and returns (2**m * point, scalar * point). +func ec_mul_inner{range_check_ptr}(point : EcPoint, scalar : felt, m : felt) -> ( + pow2 : EcPoint, res : EcPoint +): + if m == 0: + with_attr error_message("Too large scalar"): + scalar = 0 + end + let ZERO_POINT = EcPoint(BigInt3(0, 0, 0), BigInt3(0, 0, 0)) + return (pow2=point, res=ZERO_POINT) + end + + alloc_locals + let (double_point : EcPoint) = ec_double(point) + %{ memory[ap] = (ids.scalar % PRIME) % 2 %} + jmp odd if [ap] != 0; ap++ + return ec_mul_inner(point=double_point, scalar=scalar / 2, m=m - 1) + + odd: + let (local inner_pow2 : EcPoint, inner_res : EcPoint) = ec_mul_inner( + point=double_point, scalar=(scalar - 1) / 2, m=m - 1 + ) + # Here inner_res = (scalar - 1) / 2 * double_point = (scalar - 1) * point. + # Assume point != 0 and that inner_res = +/-point. We obtain (scalar - 1) * point = +/-point => + # scalar - 1 = +/-1 (mod N) => scalar = 0 or 2 (mod N). + # By induction, we know that (scalar - 1) / 2 must be in the range [0, 2**(m-1)), + # so scalar is an odd number in the range [0, 2**m), and we get a contradiction. + let (res : EcPoint) = fast_ec_add(point0=point, point1=inner_res) + return (pow2=inner_pow2, res=res) +end + +# Given a point and a 256-bit scalar, returns scalar * point. +func ec_mul{range_check_ptr}(point : EcPoint, scalar : BigInt3) -> (res : EcPoint): + alloc_locals + let (pow2_0 : EcPoint, local res0 : EcPoint) = ec_mul_inner(point, scalar.d0, 86) + let (pow2_1 : EcPoint, local res1 : EcPoint) = ec_mul_inner(pow2_0, scalar.d1, 86) + let (_, local res2 : EcPoint) = ec_mul_inner(pow2_1, scalar.d2, 84) + let (res : EcPoint) = ec_add(res0, res1) + let (res : EcPoint) = ec_add(res, res2) + return (res) +end diff --git a/src/starkware/cairo/common/cairo_secp/field.cairo b/src/starkware/cairo/common/cairo_secp/field.cairo new file mode 100644 index 00000000..4b9cb956 --- /dev/null +++ b/src/starkware/cairo/common/cairo_secp/field.cairo @@ -0,0 +1,145 @@ +from starkware.cairo.common.cairo_secp.bigint import BigInt3, UnreducedBigInt3, nondet_bigint3 +from starkware.cairo.common.cairo_secp.constants import BASE, SECP_REM + +# Computes the multiplication of two big integers, given in BigInt3 representation, modulo the +# secp256k1 prime. +# +# Arguments: +# x, y - the two BigInt3 to operate on. +# +# Returns: +# x * y in an UnreducedBigInt3 representation (the returned limbs may be above 3 * BASE). +# +# If each of the input limbs is in the range (-x, x), the result's limbs are guaranteed to be +# in the range (-x**2 * (2 ** 35.01), x**2 * (2 ** 35.01)) since log(8 * SECP_REM + 1) < 35.01. +# +# This means that if unreduced_mul is called on the result of nondet_bigint3, or the difference +# between two such results, we have: +# Soundness guarantee: the limbs are in the range (-2**210.18, 2**210.18). +# Completeness guarantee: the limbs are in the range (-2**207.01, 2**207.01). +func unreduced_mul(a : BigInt3, b : BigInt3) -> (res_low : UnreducedBigInt3): + # The result of the product is: + # sum_{i, j} a.d_i * b.d_j * BASE**(i + j) + # Since we are computing it mod secp256k1_prime, we replace the term + # a.d_i * b.d_j * BASE**(i + j) + # where i + j >= 3 with + # a.d_i * b.d_j * BASE**(i + j - 3) * 4 * SECP_REM + # since BASE ** 3 = 4 * SECP_REM (mod secp256k1_prime). + return ( + UnreducedBigInt3( + d0=a.d0 * b.d0 + (a.d1 * b.d2 + a.d2 * b.d1) * (4 * SECP_REM), + d1=a.d0 * b.d1 + a.d1 * b.d0 + (a.d2 * b.d2) * (4 * SECP_REM), + d2=a.d0 * b.d2 + a.d1 * b.d1 + a.d2 * b.d0), + ) +end + +# Computes the square of a big integer, given in BigInt3 representation, modulo the +# secp256k1 prime. +# +# Has the same guarantees as in unreduced_mul(a, a). +func unreduced_sqr(a : BigInt3) -> (res_low : UnreducedBigInt3): + tempvar twice_d0 = a.d0 * 2 + return ( + UnreducedBigInt3( + d0=a.d0 * a.d0 + (a.d1 * a.d2) * (2 * 4 * SECP_REM), + d1=twice_d0 * a.d1 + (a.d2 * a.d2) * (4 * SECP_REM), + d2=twice_d0 * a.d2 + a.d1 * a.d1), + ) +end + +# Verifies that the given unreduced value is equal to zero modulo the secp256k1 prime. +# +# Completeness assumption: val's limbs are in the range (-2**210.99, 2**210.99). +# Soundness assumption: val's limbs are in the range (-2**250, 2**250). +func verify_zero{range_check_ptr}(val : UnreducedBigInt3): + let q = [ap] + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + q, r = divmod(pack(ids.val, PRIME), SECP_P) + assert r == 0, f"verify_zero: Invalid input {ids.val.d0, ids.val.d1, ids.val.d2}." + ids.q = q % PRIME + %} + let q_biased = [ap + 1] + q_biased = q + 2 ** 127; ap++ + [range_check_ptr] = q_biased; ap++ + # This implies that q is in the range [-2**127, 2**127). + + tempvar r1 = (val.d0 + q * SECP_REM) / BASE + assert [range_check_ptr + 1] = r1 + 2 ** 127 + # This implies that r1 is in the range [-2**127, 2**127). + # Therefore, r1 * BASE is in the range [-2**213, 2**213). + # By the soundness assumption, val.d0 is in the range (-2**250, 2**250). + # This implies that r1 * BASE = val.d0 + q * SECP_REM (as integers). + + tempvar r2 = (val.d1 + r1) / BASE + assert [range_check_ptr + 2] = r2 + 2 ** 127 + # Similarly, this implies that r2 * BASE = val.d1 + r1 (as integers). + # Therefore, r2 * BASE**2 = val.d1 * BASE + r1 * BASE. + + assert val.d2 = q * (BASE / 4) - r2 + # Similarly, this implies that q * BASE / 4 = val.d2 + r2 (as integers). + # Therefore, + # q * BASE**3 / 4 = val.d2 * BASE**2 + r2 * BASE ** 2 = + # val.d2 * BASE**2 + val.d1 * BASE + r1 * BASE = + # val.d2 * BASE**2 + val.d1 * BASE + val.d0 + q * SECP_REM = + # val + q * SECP_REM. + # Hence, val = q * (BASE**3 / 4 - SECP_REM) = q * (2**256 - SECP_REM) = q * secp256k1_prime. + + let range_check_ptr = range_check_ptr + 3 + return () +end + +# Returns 1 if x == 0 (mod secp256k1_prime), and 0 otherwise. +# +# Completeness assumption: x's limbs are in the range (-BASE, 2*BASE). +# Soundness assumption: x's limbs are in the range (-2**107.49, 2**107.49). +func is_zero{range_check_ptr}(x : BigInt3) -> (res : felt): + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + x = pack(ids.x, PRIME) % SECP_P + %} + if nondet %{ x == 0 %} != 0: + verify_zero(UnreducedBigInt3(d0=x.d0, d1=x.d1, d2=x.d2)) + return (res=1) + end + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P + from starkware.python.math_utils import div_mod + + value = x_inv = div_mod(1, x, SECP_P) + %} + let (x_inv) = nondet_bigint3() + let (x_x_inv) = unreduced_mul(x, x_inv) + + # Check that x * x_inv = 1 to verify that x != 0. + verify_zero(UnreducedBigInt3( + d0=x_x_inv.d0 - 1, + d1=x_x_inv.d1, + d2=x_x_inv.d2)) + return (res=0) +end + +# Receives an unreduced number, and returns a number that is equal to the original number mod SECP_P +# and in reduced form (meaning every limb is in the range [0, BASE)). +# +# Completeness assumption: x's limbs are in the range (-2**210.99, 2**210.99). +# Soundness assumption: x's limbs are in the range (-2**249.99, 2**249.99). +func reduce{range_check_ptr}(x : UnreducedBigInt3) -> (reduced_x : BigInt3): + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + value = pack(ids.x, PRIME) % SECP_P + %} + let (reduced_x : BigInt3) = nondet_bigint3() + + verify_zero( + UnreducedBigInt3( + d0=x.d0 - reduced_x.d0, + d1=x.d1 - reduced_x.d1, + d2=x.d2 - reduced_x.d2), + ) + return (reduced_x=reduced_x) +end diff --git a/src/starkware/cairo/common/cairo_secp/signature.cairo b/src/starkware/cairo/common/cairo_secp/signature.cairo new file mode 100644 index 00000000..3ee37493 --- /dev/null +++ b/src/starkware/cairo/common/cairo_secp/signature.cairo @@ -0,0 +1,239 @@ +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.cairo_keccak.keccak import finalize_keccak, keccak_uint256s_bigend +from starkware.cairo.common.cairo_secp.bigint import ( + BASE, + BigInt3, + UnreducedBigInt3, + bigint_mul, + bigint_to_uint256, + nondet_bigint3, + uint256_to_bigint, +) +from starkware.cairo.common.cairo_secp.constants import BETA, N0, N1, N2 +from starkware.cairo.common.cairo_secp.ec import EcPoint, ec_add, ec_mul, ec_negate +from starkware.cairo.common.cairo_secp.field import ( + reduce, + unreduced_mul, + unreduced_sqr, + verify_zero, +) +from starkware.cairo.common.math import assert_nn, assert_nn_le, assert_not_zero, unsigned_div_rem +from starkware.cairo.common.math_cmp import RC_BOUND +from starkware.cairo.common.uint256 import Uint256 + +@known_ap_change +func get_generator_point() -> (point : EcPoint): + # generator_point = ( + # 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, + # 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8 + # ) + return ( + point=EcPoint( + BigInt3(0xe28d959f2815b16f81798, 0xa573a1c2c1c0a6ff36cb7, 0x79be667ef9dcbbac55a06), + BigInt3(0x554199c47d08ffb10d4b8, 0x2ff0384422a3f45ed1229a, 0x483ada7726a3c4655da4f)), + ) +end + +# Computes a * b^(-1) modulo the size of the elliptic curve (N). +# +# Prover assumptions: +# * All the limbs of a are in the range (-2 ** 210.99, 2 ** 210.99). +# * All the limbs of b are in the range (-2 ** 124.99, 2 ** 124.99). +# * b is in the range [0, 2 ** 256). +func div_mod_n{range_check_ptr}(a : BigInt3, b : BigInt3) -> (res : BigInt3): + %{ + from starkware.cairo.common.cairo_secp.secp_utils import N, pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + let (res) = nondet_bigint3() + + %{ value = k = safe_div(res * b - a, N) %} + let (k) = nondet_bigint3() + + let (res_b) = bigint_mul(res, b) + let n = BigInt3(N0, N1, N2) + let (k_n) = bigint_mul(k, n) + + # We should now have res_b = k_n + a. Since the numbers are in unreduced form, + # we should handle the carry. + + tempvar carry1 = (res_b.d0 - k_n.d0 - a.d0) / BASE + assert [range_check_ptr + 0] = carry1 + 2 ** 127 + + tempvar carry2 = (res_b.d1 - k_n.d1 - a.d1 + carry1) / BASE + assert [range_check_ptr + 1] = carry2 + 2 ** 127 + + tempvar carry3 = (res_b.d2 - k_n.d2 - a.d2 + carry2) / BASE + assert [range_check_ptr + 2] = carry3 + 2 ** 127 + + tempvar carry4 = (res_b.d3 - k_n.d3 + carry3) / BASE + assert [range_check_ptr + 3] = carry4 + 2 ** 127 + + assert res_b.d4 - k_n.d4 + carry4 = 0 + + let range_check_ptr = range_check_ptr + 4 + + return (res=res) +end + +# Verifies that val is in the range [1, N) and that the limbs of val are in the range [0, BASE). +func validate_signature_entry{range_check_ptr}(val : BigInt3): + assert_nn_le(val.d2, N2) + assert_nn_le(val.d1, BASE - 1) + assert_nn_le(val.d0, BASE - 1) + + if val.d2 == N2: + if val.d1 == N1: + assert_nn_le(val.d0, N0 - 1) + return () + end + assert_nn_le(val.d1, N1 - 1) + return () + end + + # Check that val > 0. + if val.d2 == 0: + if val.d1 == 0: + assert_not_zero(val.d0) + return () + end + end + return () +end + +# Converts a public key point to the corresponding Ethereum address. +func public_key_point_to_eth_address{ + range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt* +}(public_key_point : EcPoint) -> (eth_address : felt): + alloc_locals + let (local elements : Uint256*) = alloc() + let (x_uint256 : Uint256) = bigint_to_uint256(public_key_point.x) + assert elements[0] = x_uint256 + let (y_uint256 : Uint256) = bigint_to_uint256(public_key_point.y) + assert elements[1] = y_uint256 + let (point_hash : Uint256) = keccak_uint256s_bigend(n_elements=2, elements=elements) + + # The Ethereum address is the 20 least significant bytes of the keccak of the public key. + let (high_high, high_low) = unsigned_div_rem(point_hash.high, 2 ** 32) + return (eth_address=point_hash.low + RC_BOUND * high_low) +end + +# Returns a point on the secp256k1 curve with the given x coordinate. Chooses the y that has the +# same parity as v (there are two y values that correspond to x, with different parities). +# Also verifies that v is in the range [0, 2 ** 128). +# Prover assumption: +# x is the x coordinate of some nonzero point on the curve. +func get_point_from_x{range_check_ptr}(x : BigInt3, v : felt) -> (point : EcPoint): + with_attr error_message("Out of range v {v}."): + assert_nn(v) + end + let (x_square : UnreducedBigInt3) = unreduced_sqr(x) + let (x_square_reduced : BigInt3) = reduce(x_square) + let (x_cube : UnreducedBigInt3) = unreduced_mul(x, x_square_reduced) + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + + x_cube_int = pack(ids.x_cube, PRIME) % SECP_P + y_square_int = (x_cube_int + ids.BETA) % SECP_P + y = pow(y_square_int, (SECP_P + 1) // 4, SECP_P) + + # We need to decide whether to take y or SECP_P - y. + if ids.v % 2 == y % 2: + value = y + else: + value = (-y) % SECP_P + %} + let (y : BigInt3) = nondet_bigint3() + + # Check that y has same parity as v. + assert_nn((y.d0 + v) / 2) + + let (y_square : UnreducedBigInt3) = unreduced_sqr(y) + # Check that y_square = x_cube + BETA. + verify_zero( + UnreducedBigInt3( + d0=x_cube.d0 + BETA - y_square.d0, + d1=x_cube.d1 - y_square.d1, + d2=x_cube.d2 - y_square.d2, + ), + ) + + return (point=EcPoint(x, y)) +end + +# Receives a signature and the signed message hash. +# Returns the public key associated with the signer, represented as a point on the curve. +# Note: +# Some places use the values 27 and 28 instead of 0 and 1 for v. In that case, a subtraction by 27 +# returns a v that can be used by this function. +# Prover assumptions: +# * r is the x coordinate of some nonzero point on the curve. +# * All the limbs of s and msg_hash are in the range (-2 ** 210.99, 2 ** 210.99). +# * All the limbs of r are in the range (-2 ** 124.99, 2 ** 124.99). +func recover_public_key{range_check_ptr}( + msg_hash : BigInt3, r : BigInt3, s : BigInt3, v : felt +) -> (public_key_point : EcPoint): + alloc_locals + let (local r_point : EcPoint) = get_point_from_x(x=r, v=v) + let (generator_point : EcPoint) = get_generator_point() + # The result is given by + # -(msg_hash / r) * gen + (s / r) * r_point + # where the division by r is modulo N. + + let (u1 : BigInt3) = div_mod_n(msg_hash, r) + let (u2 : BigInt3) = div_mod_n(s, r) + + let (point1) = ec_mul(generator_point, u1) + # We prefer negating the point over negating the scalar because negating mod SECP_P is + # computationally easier than mod N. + let (minus_point1) = ec_negate(point1) + + let (point2) = ec_mul(r_point, u2) + + return ec_add(minus_point1, point2) +end + +# Verifies a Secp256k1 ECDSA signature. +# Also verifies that r and s are in the range (0, N), that their limbs are in the range +# [0, BASE), and that v is in the range [0, 2 ** 128). +# Receives a keccak_ptr for computing keccak. finalize_keccak should be called after all the keccak +# calculations are done. +# Assumptions: +# * All the limbs of msg_hash are in the range [0, 3 * BASE). +func verify_eth_signature{range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt*}( + msg_hash : BigInt3, r : BigInt3, s : BigInt3, v : felt, eth_address : felt +): + alloc_locals + + with_attr error_message("Signature out of range."): + validate_signature_entry(r) + validate_signature_entry(s) + end + + with_attr error_message("Invalid signature."): + let (public_key_point : EcPoint) = recover_public_key(msg_hash=msg_hash, r=r, s=s, v=v) + let (calculated_eth_address) = public_key_point_to_eth_address( + public_key_point=public_key_point + ) + assert eth_address = calculated_eth_address + end + return () +end + +# Same as verify_eth_signature, except that msg_hash, r and s are Uint256. +func verify_eth_signature_uint256{ + range_check_ptr, bitwise_ptr : BitwiseBuiltin*, keccak_ptr : felt* +}(msg_hash : Uint256, r : Uint256, s : Uint256, v : felt, eth_address : felt): + let (msg_hash_bigint : BigInt3) = uint256_to_bigint(msg_hash) + let (r_bigint : BigInt3) = uint256_to_bigint(r) + let (s_bigint : BigInt3) = uint256_to_bigint(s) + return verify_eth_signature( + msg_hash=msg_hash_bigint, r=r_bigint, s=s_bigint, v=v, eth_address=eth_address + ) +end diff --git a/src/starkware/cairo/common/uint256.cairo b/src/starkware/cairo/common/uint256.cairo index c967cdde..6cf8aa6c 100644 --- a/src/starkware/cairo/common/uint256.cairo +++ b/src/starkware/cairo/common/uint256.cairo @@ -362,3 +362,59 @@ func uint256_shr{range_check_ptr}(a : Uint256, b : Uint256) -> (res : Uint256): let (res, _) = uint256_unsigned_div_rem(a, c) return (res) end + +# Reverses byte endianness of a 128-bit word. +# +# The algorithm works in steps. Generally speaking, on the i-th step, +# we switch between every two consecutive sequences of 2 ** i bytes. +# To illustrate how it works, here are the steps when running +# on a 64-bit word = [b0, b1, b2, b3, b4, b5, b6, b7] (3 steps instead of 4): +# +# step 1: +# [b0, b1, b2, b3, b4, b5, b6, b7] - +# [b0, 0,  b2, 0,  b4, 0,  b6, 0 ] + +# [0,  0,  b0, 0,  b2, 0,  b4, 0, b6] = +# [0,  b1, b0, b3, b2, b5, b4, b7, b6] +# +# step 2: +# [0, b1, b0, b3, b2, b5, b4, b7, b6] - +# [0, b1, b0, 0,  0,  b5, b4, 0,  0 ] + +# [0, 0,  0,  0,  0,  b1, b0, 0,  0,  b5, b4] = +# [0, 0,  0,  b3, b2, b1, b0, b7, b6, b5, b4] +# +# step 3: +# [0, 0, 0, b3, b2, b1, b0, b7, b6, b5, b4] - +# [0, 0, 0, b3, b2, b1, b0, 0,  0,  0,  0 ] + +# [0, 0, 0, 0,  0,  0,  0,  0,  0,  0,  0, b3, b2, b1, b0] = +# [0, 0, 0, 0,  0,  0,  0,  b7, b6, b5, b4, b3, b2, b1, b0] +# +# Next, we divide by 2 ** (8 + 16 + 32) and get [b7, b6, b5, b4, b3, b2, b1, b0]. +func word_reverse_endian{bitwise_ptr : BitwiseBuiltin*}(word : felt) -> (res : felt): + # Step 1. + assert bitwise_ptr[0].x = word + assert bitwise_ptr[0].y = 0x00ff00ff00ff00ff00ff00ff00ff00ff + tempvar word = word + (2 ** 16 - 1) * bitwise_ptr[0].x_and_y + # Step 2. + assert bitwise_ptr[1].x = word + assert bitwise_ptr[1].y = 0x00ffff0000ffff0000ffff0000ffff00 + tempvar word = word + (2 ** 32 - 1) * bitwise_ptr[1].x_and_y + # Step 3. + assert bitwise_ptr[2].x = word + assert bitwise_ptr[2].y = 0x00ffffffff00000000ffffffff000000 + tempvar word = word + (2 ** 64 - 1) * bitwise_ptr[2].x_and_y + # Step 4. + assert bitwise_ptr[3].x = word + assert bitwise_ptr[3].y = 0x00ffffffffffffffff00000000000000 + tempvar word = word + (2 ** 128 - 1) * bitwise_ptr[3].x_and_y + + let bitwise_ptr = bitwise_ptr + 4 * BitwiseBuiltin.SIZE + return (res=word / 2 ** (8 + 16 + 32 + 64)) +end + +# Reverses byte endianness of a uint256 integer. +func uint256_reverse_endian{bitwise_ptr : BitwiseBuiltin*}(num : Uint256) -> (res : Uint256): + let (high) = word_reverse_endian(num.high) + let (low) = word_reverse_endian(num.low) + + return (res=Uint256(low=high, high=low)) +end diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 6f4eebdf..100435be 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.8.1 +0.8.2 diff --git a/src/starkware/cairo/lang/compiler/ast/arguments.py b/src/starkware/cairo/lang/compiler/ast/arguments.py index edf69891..63a2449e 100644 --- a/src/starkware/cairo/lang/compiler/ast/arguments.py +++ b/src/starkware/cairo/lang/compiler/ast/arguments.py @@ -17,7 +17,7 @@ class IdentifierList(AstNode): def get_particles(self): for note in self.notes: note.assert_no_comments() - return [x.format() for x in self.identifiers] + return [x.to_particle() for x in self.identifiers] def get_children(self) -> Sequence[Optional[AstNode]]: return self.identifiers diff --git a/src/starkware/cairo/lang/compiler/ast/cairo_types.py b/src/starkware/cairo/lang/compiler/ast/cairo_types.py index af0dd9d1..bb930584 100644 --- a/src/starkware/cairo/lang/compiler/ast/cairo_types.py +++ b/src/starkware/cairo/lang/compiler/ast/cairo_types.py @@ -3,7 +3,12 @@ from enum import Enum, auto from typing import List, Optional, Sequence -from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + LocationField, + Particle, + SeparatedParticleList, + SingleParticle, +) from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.ast.notes import Notes from starkware.cairo.lang.compiler.error_handling import Location @@ -14,10 +19,16 @@ class CairoType(AstNode): location: Optional[Location] @abstractmethod + def to_particle(self) -> Particle: + """ + Returns a representation of the type as a Particle. + """ + def format(self) -> str: """ Returns a representation of the type as a string. """ + return str(self.to_particle()) def get_pointer_type(self) -> "CairoType": """ @@ -30,8 +41,8 @@ def get_pointer_type(self) -> "CairoType": class TypeFelt(CairoType): location: Optional[Location] = LocationField - def format(self): - return "felt" + def to_particle(self) -> Particle: + return SingleParticle(text="felt") def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -41,8 +52,8 @@ def get_children(self) -> Sequence[Optional[AstNode]]: class TypeCodeoffset(CairoType): location: Optional[Location] = LocationField - def format(self): - return "codeoffset" + def to_particle(self) -> Particle: + return SingleParticle(text="codeoffset") def get_children(self) -> Sequence[Optional[AstNode]]: return [] @@ -53,8 +64,8 @@ class TypePointer(CairoType): pointee: CairoType location: Optional[Location] = LocationField - def format(self): - return f"{self.pointee.format()}*" + def to_particle(self) -> Particle: + return SingleParticle(text=f"{self.pointee.format()}*") def get_children(self) -> Sequence[Optional[AstNode]]: return [self.pointee] @@ -67,8 +78,8 @@ class TypeStruct(CairoType): is_fully_resolved: bool location: Optional[Location] = LocationField - def format(self): - return str(self.scope) + def to_particle(self) -> Particle: + return SingleParticle(text=str(self.scope)) @property def resolved_scope(self): @@ -100,10 +111,11 @@ class Item(AstNode): typ: CairoType location: Optional[Location] = LocationField - def format(self): - if self.name is None: - return self.typ.format() - return f"{self.name} : {self.typ.format()}" + def to_particle(self) -> Particle: + particle = self.typ.to_particle() + if self.name is not None: + particle.add_prefix(f"{self.name} : ") + return particle def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typ] @@ -120,10 +132,10 @@ def assert_no_comments(self): for note in self.notes: note.assert_no_comments() - def format(self): + def to_particle(self) -> Particle: self.assert_no_comments() - member_formats = [member.format() for member in self.members] - return f"({', '.join(member_formats)})" + member_particles = [member.to_particle() for member in self.members] + return SeparatedParticleList(elements=member_particles, start="(", end=")") def get_children(self) -> Sequence[Optional[AstNode]]: return self.members diff --git a/src/starkware/cairo/lang/compiler/ast/code_elements.py b/src/starkware/cairo/lang/compiler/ast/code_elements.py index 87674cea..e90f1d84 100644 --- a/src/starkware/cairo/lang/compiler/ast/code_elements.py +++ b/src/starkware/cairo/lang/compiler/ast/code_elements.py @@ -70,7 +70,14 @@ class CodeElementMember(CodeElement): typed_identifier: TypedIdentifier def format(self, allowed_line_length): - return f"member {self.typed_identifier.format()}" + particle = self.typed_identifier.to_particle() + particle.add_prefix("member ") + return particles_in_lines( + particles=particle, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, line_indent=INDENTATION + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier] @@ -82,7 +89,15 @@ class CodeElementReference(CodeElement): expr: Expression def format(self, allowed_line_length): - return f"let {self.typed_identifier.format()} = {self.expr.format()}" + particle = self.typed_identifier.to_particle() + particle.add_prefix("let ") + particle.add_suffix(f" = {self.expr.format()}") + return particles_in_lines( + particles=particle, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, line_indent=INDENTATION + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -102,8 +117,16 @@ class CodeElementLocalVariable(CodeElement): location: Optional[Location] = LocationField def format(self, allowed_line_length): - assignment = "" if self.expr is None else f" = {self.expr.format()}" - return f"local {self.typed_identifier.format()}{assignment}" + particle = self.typed_identifier.to_particle() + particle.add_prefix("local ") + if self.expr is not None: + particle.add_suffix(f" = {self.expr.format()}") + return particles_in_lines( + particles=particle, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, line_indent=INDENTATION + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -121,8 +144,16 @@ class CodeElementTemporaryVariable(CodeElement): location: Optional[Location] = LocationField def format(self, allowed_line_length): - assignment = "" if self.expr is None else f" = {self.expr.format()}" - return f"tempvar {self.typed_identifier.format()}{assignment}" + particle = self.typed_identifier.to_particle() + particle.add_prefix("tempvar ") + if self.expr is not None: + particle.add_suffix(f" = {self.expr.format()}") + return particles_in_lines( + particles=particle, + config=ParticleFormattingConfig( + allowed_line_length=allowed_line_length, line_indent=INDENTATION + ), + ) def get_children(self) -> Sequence[Optional[AstNode]]: return [self.typed_identifier, self.expr] @@ -251,7 +282,9 @@ class CodeElementReturnValueReference(CodeElement): def format(self, allowed_line_length): call_particles = self.func_call.get_particles() - first_particle = f"let {self.typed_identifier.format()} = " + call_particles[0] + first_particle = self.typed_identifier.to_particle() + first_particle.add_prefix("let ") + first_particle.add_suffix(f" = {call_particles[0]}") return particles_in_lines( particles=ParticleList(elements=[first_particle] + call_particles[1:]), @@ -284,7 +317,7 @@ def format(self, allowed_line_length): unpacking_list_particles = SeparatedParticleList( elements=self.unpacking_list.get_particles(), end=end_particle ) - particles = ["let ("] + unpacking_list_particles.to_strings() + particles[1:] + particles = ["let (", unpacking_list_particles] + particles[1:] return particles_in_lines( particles=ParticleList(elements=particles), diff --git a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py index f14ffd5b..727a7052 100644 --- a/src/starkware/cairo/lang/compiler/ast/formatting_utils.py +++ b/src/starkware/cairo/lang/compiler/ast/formatting_utils.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import field -from typing import List, Union +from typing import List, Sequence, Union import marshmallow @@ -163,6 +163,20 @@ def is_splitable(self) -> bool: Returns True if and only if the particle can be split into multiple lines. """ + @abstractmethod + def add_prefix(self, prefix: str): + """ + Adds a prefix to the beginning of the particle. + The prefix is glued to the particle (not splittable). + """ + + @abstractmethod + def add_suffix(self, suffix: str): + """ + Appends a suffix to the end of the particle. + The suffix is glued to the particle (not splittable). + """ + @abstractmethod def add_to_builder(self, builder: ParticleLineBuilder, suffix: str = ""): """ @@ -185,6 +199,12 @@ def __str__(self): def is_splitable(self) -> bool: return False + def add_prefix(self, prefix: str): + self.text = prefix + self.text + + def add_suffix(self, suffix: str): + self.text += suffix + def add_to_builder(self, builder: ParticleLineBuilder, suffix: str = ""): builder.add_to_line(f"{self.text}{suffix}") @@ -197,7 +217,7 @@ class ParticleList(Particle): def __init__( self, - elements: List[Union[Particle, str]], + elements: Sequence[Union[Particle, str]], ): self.elements = [] for elm in elements: @@ -209,6 +229,14 @@ def __str__(self): def is_splitable(self) -> bool: return len(self.elements) > 0 + def add_prefix(self, prefix: str): + assert len(self.elements) > 0 + self.elements[0].add_prefix(prefix) + + def add_suffix(self, suffix: str): + assert len(self.elements) > 0 + self.elements[-1].add_suffix(suffix) + def add_to_builder(self, builder: ParticleLineBuilder, suffix: str = ""): for i, particle in enumerate(self.elements): particle.add_to_builder( @@ -224,7 +252,7 @@ class SeparatedParticleList(Particle): def __init__( self, - elements: List[Union[Particle, str]], + elements: Sequence[Union[Particle, str]], separator: str = ", ", start: str = "", end: str = "", @@ -242,14 +270,11 @@ def __str__(self): def is_splitable(self) -> bool: return len(self.elements) > 0 - def to_strings(self) -> List[str]: - if len(self.elements) == 0: - # If the list is empty, return the single element 'end'. - return [self.end] - # Concatenate the 'separator' to all elements and 'end' to the last one. - return [str(elm) + self.separator for elm in self.elements[:-1]] + [ - str(self.elements[-1]) + self.end - ] + def add_prefix(self, prefix: str): + self.start = prefix + self.start + + def add_suffix(self, suffix: str): + self.end += suffix def elements_to_string(self) -> str: """ diff --git a/src/starkware/cairo/lang/compiler/ast/types.py b/src/starkware/cairo/lang/compiler/ast/types.py index bc10a617..b9a1cb73 100644 --- a/src/starkware/cairo/lang/compiler/ast/types.py +++ b/src/starkware/cairo/lang/compiler/ast/types.py @@ -3,7 +3,11 @@ from starkware.cairo.lang.compiler.ast.cairo_types import CairoType, TypeFelt from starkware.cairo.lang.compiler.ast.expr import ExprIdentifier -from starkware.cairo.lang.compiler.ast.formatting_utils import LocationField +from starkware.cairo.lang.compiler.ast.formatting_utils import ( + LocationField, + Particle, + SingleParticle, +) from starkware.cairo.lang.compiler.ast.node import AstNode from starkware.cairo.lang.compiler.error_handling import Location @@ -31,10 +35,17 @@ class TypedIdentifier(AstNode): location: Optional[Location] = LocationField modifier: Optional[Modifier] = None - def format(self): + def to_particle(self) -> Particle: modifier_str = "" if self.modifier is None else self.modifier.format() + " " - type_str = "" if self.expr_type is None else f" : {self.expr_type.format()}" - return modifier_str + self.identifier.format() + type_str + if self.expr_type is None: + return SingleParticle(text=modifier_str + self.identifier.format()) + else: + particle = self.expr_type.to_particle() + particle.add_prefix(modifier_str + self.identifier.format() + " : ") + return particle + + def format(self): + return str(self.to_particle()) def override_type(self, expr_type): return dataclasses.replace(self, expr_type=expr_type) diff --git a/src/starkware/cairo/lang/compiler/ast_objects_test.py b/src/starkware/cairo/lang/compiler/ast_objects_test.py index 73381beb..0dc0b4fa 100644 --- a/src/starkware/cairo/lang/compiler/ast_objects_test.py +++ b/src/starkware/cairo/lang/compiler/ast_objects_test.py @@ -741,3 +741,39 @@ def test_100_chars_long_import(): """ with set_one_item_per_line(False): assert parse_file(code).format() == code + + +def test_tuples(): + code = """\ +local x : ( + a : felt, + b : (c : felt, + d : (felt, (felt, felt)), + e : (f : felt, g : felt), + h : (felt, felt, felt))) +""" + with set_one_item_per_line(True): + assert ( + parse_file(code).format() + == """\ +local x : ( + a : felt, + b : (c : felt, d : (felt, (felt, felt)), e : (f : felt, g : felt), h : (felt, felt, felt)), +) +""" + ) + + assert ( + parse_file(code).format(allowed_line_length=50) + == """\ +local x : ( + a : felt, + b : ( + c : felt, + d : (felt, (felt, felt)), + e : (f : felt, g : felt), + h : (felt, felt, felt), + ), +) +""" + ) diff --git a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py index fe2b627e..457a5e3d 100644 --- a/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py +++ b/src/starkware/cairo/lang/compiler/preprocessor/preprocessor_test.py @@ -3889,6 +3889,20 @@ def test_struct_constructor_failures(): ) verify_exception( """ +using A = (felt, felt) +using B = A + +assert B(1, 2) = 0 +""", + """ +file:?:?: Struct constructor cannot be used for type '(felt, felt)'. +assert B(1, 2) = 0 + ^*****^ +""", + exc_type=CairoTypeError, + ) + verify_exception( + """ struct A: member next: A* end diff --git a/src/starkware/cairo/lang/compiler/substitute_identifiers.py b/src/starkware/cairo/lang/compiler/substitute_identifiers.py index 5f6da9fa..0113f36a 100644 --- a/src/starkware/cairo/lang/compiler/substitute_identifiers.py +++ b/src/starkware/cairo/lang/compiler/substitute_identifiers.py @@ -90,9 +90,14 @@ def visit_ExprFuncCall(self, expr: ExprFuncCall): ) ) + if not isinstance(struct_type, TypeStruct): + raise CairoTypeError( + f"Struct constructor cannot be used for type '{struct_type.format()}'.", + location=expr.location, + ) + # Verify named arguments in struct constructor. if self.get_struct_members_callback is not None: - assert isinstance(struct_type, TypeStruct) struct_members = self.get_struct_members_callback(struct_type) # Note that it's OK if len(struct_members) != len(rvalue.arguments.args) as # length compatibility of cast is checked later on. diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index c8d092b5..9cfa8eec 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "Cairo", "description": "Support Cairo syntax", - "version": "0.8.1", + "version": "0.8.2", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/cairo/lang/vm/virtual_machine_base.py b/src/starkware/cairo/lang/vm/virtual_machine_base.py index fe27d26d..e978912e 100644 --- a/src/starkware/cairo/lang/vm/virtual_machine_base.py +++ b/src/starkware/cairo/lang/vm/virtual_machine_base.py @@ -204,11 +204,12 @@ def load_hints(self, program: Program, program_base: MaybeRelocatable): compiled_hints = [] for hint_index, hint in enumerate(hints): hint_id = len(self.hint_pc_and_index) - self.hint_pc_and_index[hint_id] = (pc + program_base, hint_index) + relocated_pc = pc + program_base + self.hint_pc_and_index[hint_id] = (relocated_pc, hint_index) compiled_hints.append( CompiledHint( compiled=self.compile_hint( - hint.code, f"", hint_index=hint_index + hint.code, f"", hint_index=hint_index, pc=relocated_pc ), # Use hint=hint in the lambda's arguments to capture this value (otherwise, # it will use the same hint object for all iterations). @@ -269,7 +270,7 @@ def exit_scope(self): assert len(self.exec_scopes) > 1, "Cannot exit main scope." self.exec_scopes.pop() - def compile_hint(self, source, filename, hint_index: int): + def compile_hint(self, source, filename, hint_index: int, pc: MaybeRelocatable): """ Compiles the given python source code. This function can be overridden by subclasses. @@ -278,8 +279,15 @@ def compile_hint(self, source, filename, hint_index: int): return compile(source, filename, mode="exec") except (IndentationError, SyntaxError): hint_exception = HintException(self, *sys.exc_info()) - raise self.as_vm_exception( - hint_exception, notes=[hint_exception.exception_str], hint_index=hint_index + + raise VmException( + pc=pc, + inst_location=self.get_location(pc=pc), + inner_exc=hint_exception, + error_attr_value=None, + traceback=None, + notes=[hint_exception.exception_str], + hint_index=hint_index, ) from None def exec_hint(self, code, globals_, hint_index): diff --git a/src/starkware/cairo/lang/vm/vm_exceptions.py b/src/starkware/cairo/lang/vm/vm_exceptions.py index 4b55792a..dab4f6f6 100644 --- a/src/starkware/cairo/lang/vm/vm_exceptions.py +++ b/src/starkware/cairo/lang/vm/vm_exceptions.py @@ -76,6 +76,11 @@ def __init__(self, vm, exc_type, exc_value, exc_tb): exc_value.msg, (filename, line_num, exc_value.offset, exc_value.text) ) + exc_string = "Got an exception while compiling a hint." + else: + exc_string = "Got an exception while executing a hint." + super().__init__(exc_string) + tb_exception = traceback.TracebackException(exc_type, exc_value, exc_tb) # First item in the traceback is the call to exec, remove it. assert tb_exception.stack[0].filename.endswith("virtual_machine_base.py") @@ -93,7 +98,7 @@ def replace_stack_item(item: traceback.FrameSummary) -> traceback.FrameSummary: tb_exception.stack = traceback.StackSummary.from_list( map(replace_stack_item, tb_exception.stack) # type: ignore ) - super().__init__(f"Got an exception while executing a hint.") + self.exception_str = "".join(tb_exception.format()) self.inner_exc = exc_value diff --git a/src/starkware/cairo/lang/vm/vm_test.py b/src/starkware/cairo/lang/vm/vm_test.py index 5749de84..0663f088 100644 --- a/src/starkware/cairo/lang/vm/vm_test.py +++ b/src/starkware/cairo/lang/vm/vm_test.py @@ -375,7 +375,7 @@ def f(): VirtualMachine(program, context, {}) expected_error = f"""\ {cairo_file.name}:4:1: Error at pc=10: -Got an exception while executing a hint. +Got an exception while compiling a hint. %{{ ^^ Traceback (most recent call last): @@ -389,7 +389,8 @@ def f(): def test_hint_syntax_error(): code = """ -# Some comment. +# Make sure the hint is not located at the start of the program. +[ap] = 1 %{ def f(): @@ -422,12 +423,12 @@ def f(): with pytest.raises(VmException) as excinfo: VirtualMachine(program, context, {}) expected_error = f"""\ -{cairo_file.name}:4:1: Error at pc=10: -Got an exception while executing a hint. +{cairo_file.name}:5:1: Error at pc=12: +Got an exception while compiling a hint. %{{ ^^ Traceback (most recent call last): - File "{cairo_file.name}", line 6 + File "{cairo_file.name}", line 7 b = # Wrong syntax. ^ SyntaxError: invalid syntax\ diff --git a/src/starkware/python/math_utils.py b/src/starkware/python/math_utils.py index 39b67ba7..f7bb1985 100644 --- a/src/starkware/python/math_utils.py +++ b/src/starkware/python/math_utils.py @@ -176,12 +176,16 @@ def ec_mult(m, point, alpha, p): return ec_add(ec_mult(m - 1, point, alpha, p), point, p) -def ec_safe_mult(m: int, point: Tuple[int, int], alpha: int, p: int) -> Union[Tuple[int, int], str]: +def ec_safe_mult( + m: int, point: Tuple[int, int], alpha: int, p: int +) -> Union[Tuple[int, int], EcInfinity]: """ Multiplies by m a point on the elliptic curve with equation y^2 = x^3 + alpha*x + beta mod p. Assumes the point is given in affine form (x, y). Safe to use always. May get or return the point at infinity, represented as EC_INFINITY. """ + if m == 0: + return EC_INFINITY if m == 1: return point if m % 2 == 0: diff --git a/src/starkware/starknet/business_logic/CMakeLists.txt b/src/starkware/starknet/business_logic/CMakeLists.txt index d9627efb..bd45091b 100644 --- a/src/starkware/starknet/business_logic/CMakeLists.txt +++ b/src/starkware/starknet/business_logic/CMakeLists.txt @@ -14,6 +14,7 @@ python_lib(starknet_business_logic_utils_lib starknet_contract_definition_lib starknet_definitions_lib starknet_execution_usage_lib + starknet_general_config_lib starknet_transaction_execution_objects_lib starknet_transaction_hash_lib starkware_error_handling_lib @@ -84,6 +85,7 @@ python_lib(starknet_transaction_fee_lib LIBS starknet_abi_lib starknet_business_logic_state_lib + starknet_business_logic_utils_lib starknet_contract_definition_lib starknet_definitions_lib starknet_execute_entry_point_lib diff --git a/src/starkware/starknet/business_logic/internal_transaction.py b/src/starkware/starknet/business_logic/internal_transaction.py index cff084b3..a516fa6e 100644 --- a/src/starkware/starknet/business_logic/internal_transaction.py +++ b/src/starkware/starknet/business_logic/internal_transaction.py @@ -28,14 +28,8 @@ ContractState, ) from starkware.starknet.business_logic.state.state import BlockInfo, CarriedState, StateSelector -from starkware.starknet.business_logic.transaction_fee import ( - calculate_tx_fee_by_cairo_usage, - charge_fee, -) -from starkware.starknet.business_logic.utils import ( - get_invoke_tx_total_resources, - preprocess_invoke_function_fields, -) +from starkware.starknet.business_logic.transaction_fee import calculate_tx_fee, charge_fee +from starkware.starknet.business_logic.utils import preprocess_invoke_function_fields from starkware.starknet.core.os.contract_hash import compute_contract_hash from starkware.starknet.core.os.transaction_hash.transaction_hash import ( calculate_deploy_transaction_hash, @@ -400,9 +394,13 @@ async def invoke_constructor( caller_address=0, ) - tx_execution_context = TransactionExecutionContext.create_for_call( + tx_execution_context = TransactionExecutionContext.create( account_contract_address=0, + transaction_hash=self.hash_value, + signature=[], + max_fee=0, n_steps=general_config.invoke_tx_max_n_steps, + version=constants.TRANSACTION_VERSION, ) call_info = await call.execute( state=state, general_config=general_config, tx_execution_context=tx_execution_context @@ -603,16 +601,11 @@ async def _apply_specific_state_updates( if self.max_fee > 0: # Should always pass on regular flows (verified in the create() method). assert self.entry_point_selector == starknet_abi.EXECUTE_ENTRY_POINT_SELECTOR - l1_gas_usage, cairo_resource_usage = get_invoke_tx_total_resources( + assert self.entry_point_type is EntryPointType.EXTERNAL + actual_fee = calculate_tx_fee( state=state, call_info=call_info, - l1_handler_payload_size=self.get_l1_handler_payload_size(), - ) - actual_fee = calculate_tx_fee_by_cairo_usage( general_config=general_config, - cairo_resource_usage=cairo_resource_usage, - l1_gas_usage=l1_gas_usage, - gas_price=state.block_info.gas_price, ) fee_transfer_info = await charge_fee( general_config=general_config, diff --git a/src/starkware/starknet/business_logic/state/state.py b/src/starkware/starknet/business_logic/state/state.py index 9231f979..548c1d38 100644 --- a/src/starkware/starknet/business_logic/state/state.py +++ b/src/starkware/starknet/business_logic/state/state.py @@ -18,7 +18,11 @@ from starkware.starknet.business_logic.state.objects import ContractCarriedState, ContractState from starkware.starknet.definitions import fields from starkware.starknet.definitions.error_codes import StarknetErrorCode -from starkware.starknet.definitions.general_config import DEFAULT_GAS_PRICE, StarknetGeneralConfig +from starkware.starknet.definitions.general_config import ( + DEFAULT_GAS_PRICE, + DEFAULT_SEQUENCER_ADDRESS, + StarknetGeneralConfig, +) from starkware.starknet.services.api.contract_definition import ContractDefinition from starkware.starknet.storage.starknet_storage import StorageLeaf from starkware.starkware_utils.commitment_tree.binary_fact_tree import BinaryFactDict @@ -46,12 +50,17 @@ class BlockInfo(ValidatedMarshmallowDataclass): # L1 gas price (in Wei) measured at the beginning of the last block creation attempt. gas_price: int = field(metadata=fields.gas_price_metadata) + # The sequencer address of this block. + sequencer_address: Optional[int] = field(metadata=fields.optional_sequencer_address_metadata) + @classmethod - def empty(cls) -> "BlockInfo": + def empty(cls, sequencer_address: Optional[int]) -> "BlockInfo": """ Returns an empty BlockInfo object; i.e., the one before the first in the chain. """ - return cls(block_number=-1, block_timestamp=0, gas_price=0) + return cls( + block_number=-1, block_timestamp=0, gas_price=0, sequencer_address=sequencer_address + ) @classmethod def create_for_testing(cls, block_number: int, block_timestamp: int) -> "BlockInfo": @@ -62,6 +71,7 @@ def create_for_testing(cls, block_number: int, block_timestamp: int) -> "BlockIn block_number=block_number, block_timestamp=block_timestamp, gas_price=DEFAULT_GAS_PRICE, + sequencer_address=DEFAULT_SEQUENCER_ADDRESS, ) def validate_legal_progress(self, next_block_info: "BlockInfo"): @@ -370,7 +380,7 @@ async def empty(cls, ffc: FactFetchingContext, general_config: Config) -> "Share return cls( contract_states=empty_contract_states, - block_info=BlockInfo.empty(), + block_info=BlockInfo.empty(sequencer_address=general_config.sequencer_address), ) def to_carried_state(self, ffc: FactFetchingContext) -> CarriedState: diff --git a/src/starkware/starknet/business_logic/transaction_fee.py b/src/starkware/starknet/business_logic/transaction_fee.py index 52ad11a9..2dd7a082 100644 --- a/src/starkware/starknet/business_logic/transaction_fee.py +++ b/src/starkware/starknet/business_logic/transaction_fee.py @@ -7,6 +7,7 @@ TransactionExecutionContext, ) from starkware.starknet.business_logic.state.state import CarriedState +from starkware.starknet.business_logic.utils import get_invoke_tx_total_resources from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.general_config import StarknetGeneralConfig from starkware.starknet.public import abi as starknet_abi @@ -82,3 +83,24 @@ def calculate_tx_fee_by_cairo_usage( total_l1_gas_usage = cairo_l1_gas_usage + l1_gas_usage return math.ceil(total_l1_gas_usage * gas_price) + + +def calculate_tx_fee( + state: CarriedState, + call_info: CallInfo, + general_config: StarknetGeneralConfig, +) -> int: + """ + Calculates the fee of the most recent + InvokeFunction transaction (recent w.r.t. application on the given state). + Assumes entry point of type EXTERNAL, since only those may be charged. + """ + l1_gas_usage, cairo_resource_usage = get_invoke_tx_total_resources( + state=state, call_info=call_info + ) + return calculate_tx_fee_by_cairo_usage( + general_config=general_config, + cairo_resource_usage=cairo_resource_usage, + l1_gas_usage=l1_gas_usage, + gas_price=state.block_info.gas_price, + ) diff --git a/src/starkware/starknet/business_logic/utils.py b/src/starkware/starknet/business_logic/utils.py index 543e00c1..52600fc5 100644 --- a/src/starkware/starknet/business_logic/utils.py +++ b/src/starkware/starknet/business_logic/utils.py @@ -98,7 +98,7 @@ def preprocess_invoke_function_fields( def get_invoke_tx_total_resources( - state: CarriedState, call_info: CallInfo, l1_handler_payload_size: Optional[int] + state: CarriedState, call_info: CallInfo ) -> Tuple[int, Mapping[str, int]]: """ Returns the total resources needed to include the most recent InvokeFunction transaction in @@ -115,7 +115,8 @@ def get_invoke_tx_total_resources( l2_to_l1_messages=call_info.get_sorted_l2_to_l1_messages(), n_modified_contracts=n_modified_contracts_by_tx, n_storage_writes=tx_syscall_counter.get("storage_write", 0), - l1_handler_payload_size=l1_handler_payload_size, + # L1 handlers cannot be called. + l1_handler_payload_size=None, constructor_calldata_length=None, # Not relevant for InvokeFunction transaction. ) diff --git a/src/starkware/starknet/cli/starknet_cli.py b/src/starkware/starknet/cli/starknet_cli.py index 09f65dbc..77689618 100755 --- a/src/starkware/starknet/cli/starknet_cli.py +++ b/src/starkware/starknet/cli/starknet_cli.py @@ -489,11 +489,15 @@ async def invoke_or_call(args: argparse.Namespace, command_args: List[str], call address = invoke_tx_args.address has_wallet = get_wallet_provider(args=args) is not None + is_account_contract_invocation = has_wallet and not call max_fee = args.max_fee if max_fee is None: - if has_wallet: + if is_account_contract_invocation: fee_info = await estimate_fee_inner( - args=args, invoke_tx_args=invoke_tx_args, has_wallet=has_wallet, has_block_info=call + args=args, + invoke_tx_args=invoke_tx_args, + has_wallet=has_wallet, + has_block_info=False, ) max_fee = math.ceil(fee_info["amount"] * FEE_MARGIN_OF_ESTIMATION) max_fee_eth = float(Web3.fromWei(max_fee, "ether")) @@ -646,8 +650,8 @@ async def get_transaction_receipt(args, command_args): async def get_block(args, command_args): parser = argparse.ArgumentParser( description=( - "Outputs the block corresponding to the given ID. " - "In case no ID is given, outputs the latest block." + "Outputs the block corresponding to the given identifier (hash or number). " + "In case no identifer is given, outputs the pending block." ) ) add_block_identifier_arguments( @@ -681,7 +685,7 @@ async def get_code(args, command_args): parser = argparse.ArgumentParser( description=( "Outputs the bytecode of the contract at the given address with respect to " - "a specific block. In case no block ID is given, uses the latest block." + "a specific block. In case no block identifier is given, uses the pending block." ) ) parser.add_argument( @@ -705,7 +709,7 @@ async def get_full_contract(args, command_args): parser = argparse.ArgumentParser( description=( "Outputs the contract definition of the contract at the given address with respect to " - "a specific block. In case no block ID is given, uses the latest block." + "a specific block. In case no block identifier is given, uses the pending block." ) ) parser.add_argument( @@ -737,7 +741,7 @@ async def get_storage_at(args, command_args): parser = argparse.ArgumentParser( description=( "Outputs the storage value of a contract in a specific key with respect to " - "a specific block. In case no block ID is given, uses the latest block." + "a specific block. In case no block identifier is given, uses the pending block." ) ) parser.add_argument( @@ -813,15 +817,15 @@ def add_block_identifier_arguments( type=str, help=( f"The hash of the block to {block_role_description}. " - "In case this argument and block_number are not given, uses the latest block." + "In case this argument and block_number are not given, uses the pending block." ), ) parser.add_argument( f"--{identifier_prefix}number", help=( f"The number of the block to {block_role_description}; " - "Additional supported keywords: 'pending';" - "In case this argument and block_hash are not given, uses the latest block." + "Additional supported keywords: 'pending', 'latest';" + "In case this argument and block_hash are not given, uses the pending block." ), ) diff --git a/src/starkware/starknet/core/os/os_config/os_config_hash_test.py b/src/starkware/starknet/core/os/os_config/os_config_hash_test.py index 78e4c5ce..ba895f8b 100644 --- a/src/starkware/starknet/core/os/os_config/os_config_hash_test.py +++ b/src/starkware/starknet/core/os/os_config/os_config_hash_test.py @@ -14,7 +14,7 @@ HASH_PATH = get_source_dir_path("src/starkware/starknet/core/os/os_config/os_config_hash.json") FEE_TOKEN_ADDRESS = 0x49D36570D4E46F48E99674BD3FCC84644DDD6B96F7C741B1562B82F9E004DC7 -FIX_COMMAND = "fix_starknet_os_config_hash" +FIX_COMMAND = "starknet_os_config_hash_fix" @random_test() diff --git a/src/starkware/starknet/core/os/state.cairo b/src/starkware/starknet/core/os/state.cairo index ba60c283..5ad02892 100644 --- a/src/starkware/starknet/core/os/state.cairo +++ b/src/starkware/starknet/core/os/state.cairo @@ -69,8 +69,8 @@ func state_update{hash_ptr : HashBuiltin*, range_check_ptr, storage_updates_ptr let n_actual_state_changes = 0 # Creates PatriciaUpdateConstants struct for patricia update. let ( - local patricia_update_constants : PatriciaUpdateConstants*) = patricia_update_constants_new( - ) + local patricia_update_constants : PatriciaUpdateConstants* + ) = patricia_update_constants_new() with n_actual_state_changes: hash_state_changes( diff --git a/src/starkware/starknet/core/os/syscall_utils.py b/src/starkware/starknet/core/os/syscall_utils.py index cb9b7125..7519953d 100644 --- a/src/starkware/starknet/core/os/syscall_utils.py +++ b/src/starkware/starknet/core/os/syscall_utils.py @@ -70,11 +70,11 @@ class SysCallHandlerBase(ABC): base class for execution of system calls in the StarkNet OS. """ - def __init__(self, general_config: StarknetGeneralConfig): + def __init__(self, block_info: BlockInfo): os_program = get_os_program() - # StarkNet general configuration. - self.general_config = general_config + assert block_info.sequencer_address is not None + self.block_info = block_info self.structs = CairoStructFactory.from_program( program=os_program, @@ -270,7 +270,7 @@ def get_sequencer_address(self, segments: MemorySegmentManager, syscall_ptr: Rel ) response = self.structs.GetSequencerAddressResponse( - sequencer_address=self.general_config.sequencer_address + sequencer_address=self.block_info.sequencer_address ) self._write_syscall_response( syscall_name="GetSequencerAddress", @@ -523,15 +523,21 @@ def __init__( general_config: StarknetGeneralConfig, initial_syscall_ptr: RelocatableValue, ): - super().__init__(general_config=general_config) + super().__init__(block_info=state.block_info) + + # Configuration objects. + self.general_config = general_config + + # Storage-related members. + self.starknet_storage = starknet_storage + self.loop = starknet_storage.loop + # Execution-related objects. self.execute_entry_point_cls = execute_entry_point_cls self.tx_execution_context = tx_execution_context self.state = state self.caller_address = caller_address self.contract_address = contract_address - self.starknet_storage = starknet_storage - self.loop = starknet_storage.loop # Internal calls executed by the current contract call. self.internal_calls: List[CallInfo] = [] @@ -852,11 +858,10 @@ class OsSysCallHandler(SysCallHandlerBase): def __init__( self, tx_execution_infos: List[TransactionExecutionInfo], - general_config: StarknetGeneralConfig, starknet_storage_by_address: Mapping[int, StarknetStorageInterface], block_info: BlockInfo, ): - super().__init__(general_config=general_config) + super().__init__(block_info=block_info) self.tx_execution_info_iterator: Iterator[TransactionExecutionInfo] = iter( tx_execution_infos @@ -880,8 +885,6 @@ def __init__( # StarkNet storage members. self.starknet_storage_by_address = starknet_storage_by_address - self.block_info = block_info - # A pointer to the Cairo TxInfo struct. # This pointer needs to match the TxInfo pointer that is going to be used during the system # call validation by the StarkNet OS. diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index d6bdf1fd..ec525c46 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -16,6 +16,7 @@ ) from starkware.starkware_utils.marshmallow_dataclass_fields import ( BytesAsHex, + IntAsHex, IntAsStr, StrictRequiredInteger, ) @@ -31,10 +32,12 @@ # Common. felt_as_hex_list_metadata = dict( + marshmallow_field=mfields.List(IntAsHex(validate=everest_fields.FeltField.validate)) +) + +felt_as_hex_or_str_list_metadata = dict( marshmallow_field=mfields.List( - everest_fields.FeltField.get_marshmallow_field( - required=True, load_default=marshmallow.utils.missing - ) + IntAsHex(support_decimal_loading=True, validate=everest_fields.FeltField.validate) ) ) @@ -81,6 +84,16 @@ def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any] name="Sequencer address", error_code=StarknetErrorCode.OUT_OF_RANGE_SEQUENCER_ADDRESS ) +OptionalSequencerAddressField = OptionalField( + field=dataclasses.replace( + AddressField, + name="Sequencer address", + error_code=StarknetErrorCode.OUT_OF_RANGE_SEQUENCER_ADDRESS, + ), + none_probability=0, +) +optional_sequencer_address_metadata = OptionalSequencerAddressField.metadata() + caller_address_metadata = address_metadata( name="Caller address", error_code=StarknetErrorCode.OUT_OF_RANGE_CALLER_ADDRESS ) @@ -133,7 +146,7 @@ def address_metadata(name: str, error_code: StarknetErrorCode) -> Dict[str, Any] call_data_metadata = felt_list_metadata call_data_as_hex_metadata = felt_as_hex_list_metadata -signature_as_hex_metadata = felt_as_hex_list_metadata +signature_as_hex_metadata = felt_as_hex_or_str_list_metadata signature_metadata = felt_list_metadata retdata_as_hex_metadata = felt_as_hex_list_metadata diff --git a/src/starkware/starknet/definitions/general_config.py b/src/starkware/starknet/definitions/general_config.py index 32a34b6a..35ebf372 100644 --- a/src/starkware/starknet/definitions/general_config.py +++ b/src/starkware/starknet/definitions/general_config.py @@ -132,7 +132,7 @@ class StarknetGeneralConfig(EverestGeneralConfig): tx_version: int = field( metadata=dict( marshmallow_field=StrictRequiredInteger( - validate=validate_non_negative("Trasaction version."), + validate=validate_non_negative("Transaction version."), ), description=( "Current transaction version - " diff --git a/src/starkware/starknet/security/CMakeLists.txt b/src/starkware/starknet/security/CMakeLists.txt index aa6321e7..805ed20e 100644 --- a/src/starkware/starknet/security/CMakeLists.txt +++ b/src/starkware/starknet/security/CMakeLists.txt @@ -46,6 +46,7 @@ python_lib(starknet_hints_whitelist_lib whitelists/cairo_keccak.json whitelists/cairo_secp.json whitelists/cairo_sha256.json + whitelists/ec_bigint.json whitelists/ec_recover.json whitelists/latest.json diff --git a/src/starkware/starknet/security/starknet_common.cairo b/src/starkware/starknet/security/starknet_common.cairo index 3651ba85..994f9037 100644 --- a/src/starkware/starknet/security/starknet_common.cairo +++ b/src/starkware/starknet/security/starknet_common.cairo @@ -1,5 +1,27 @@ from starkware.cairo.common.alloc import alloc from starkware.cairo.common.bitwise import bitwise_and, bitwise_operations, bitwise_or, bitwise_xor +from starkware.cairo.common.cairo_keccak.keccak import ( + finalize_keccak, + keccak_add_uint256, + keccak_as_words, +) +from starkware.cairo.common.cairo_secp.bigint import bigint_to_uint256 +from starkware.cairo.common.cairo_secp.ec import ( + compute_doubling_slope, + compute_slope, + ec_add, + ec_double, + ec_mul, + ec_negate, +) +from starkware.cairo.common.cairo_secp.field import is_zero, reduce, verify_zero +from starkware.cairo.common.cairo_secp.signature import ( + div_mod_n, + get_point_from_x, + public_key_point_to_eth_address, + recover_public_key, + verify_eth_signature, +) from starkware.cairo.common.default_dict import default_dict_finalize, default_dict_new from starkware.cairo.common.dict import dict_read, dict_squash, dict_update, dict_write from starkware.cairo.common.find_element import find_element, search_sorted, search_sorted_lower diff --git a/src/starkware/starknet/security/whitelists/ec_bigint.json b/src/starkware/starknet/security/whitelists/ec_bigint.json new file mode 100644 index 00000000..bd479d53 --- /dev/null +++ b/src/starkware/starknet/security/whitelists/ec_bigint.json @@ -0,0 +1,33 @@ +{ + "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import pack", + "from starkware.cairo.common.math_utils import as_int", + "from starkware.python.math_utils import div_mod, safe_div", + "", + "p = pack(ids.P, PRIME)", + "x = pack(ids.x, PRIME) + as_int(ids.x.d3, PRIME) * ids.BASE ** 3 + as_int(ids.x.d4, PRIME) * ids.BASE ** 4", + "y = pack(ids.y, PRIME)", + "", + "value = res = div_mod(x, y, p)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import split", + "segments.write_arg(ids.res.address_, split(value))" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "k = safe_div(res * y - x, p)", + "value = k if k > 0 else 0 - k", + "ids.flag = 1 if k > 0 else 0" + ] + } + ] +} diff --git a/src/starkware/starknet/security/whitelists/latest.json b/src/starkware/starknet/security/whitelists/latest.json index a90f2f60..59dbb70c 100644 --- a/src/starkware/starknet/security/whitelists/latest.json +++ b/src/starkware/starknet/security/whitelists/latest.json @@ -1,5 +1,18 @@ { "allowed_reference_expressions_for_hint": [ + { + "allowed_expressions": [], + "hint_lines": [ + "# Add dummy pairs of input and output.", + "_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)", + "_block_size = int(ids.BLOCK_SIZE)", + "assert 0 <= _keccak_state_size_felts < 100", + "assert 0 <= _block_size < 10", + "inp = [0] * _keccak_state_size_felts", + "padding = (inp + keccak_func(inp)) * _block_size", + "segments.write_arg(ids.keccak_ptr_end, padding)" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -265,6 +278,149 @@ "ids.low = int.from_bytes(hashed[16:32], 'big')" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_keccak.keccak_utils import keccak_func", + "_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)", + "assert 0 <= _keccak_state_size_felts < 100", + "", + "output_values = keccak_func(memory.get_range(", + " ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))", + "segments.write_arg(ids.keccak_ptr, output_values)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import N, pack", + "from starkware.python.math_utils import div_mod, safe_div", + "", + "a = pack(ids.a, PRIME)", + "b = pack(ids.b, PRIME)", + "value = res = div_mod(a, b, N)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P", + "from starkware.python.math_utils import div_mod", + "", + "value = x_inv = div_mod(1, x, SECP_P)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "q, r = divmod(pack(ids.val, PRIME), SECP_P)", + "assert r == 0, f\"verify_zero: Invalid input {ids.val.d0, ids.val.d1, ids.val.d2}.\"", + "ids.q = q % PRIME" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "slope = pack(ids.slope, PRIME)", + "x = pack(ids.point.x, PRIME)", + "y = pack(ids.point.y, PRIME)", + "", + "value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "slope = pack(ids.slope, PRIME)", + "x0 = pack(ids.point0.x, PRIME)", + "x1 = pack(ids.point1.x, PRIME)", + "y0 = pack(ids.point0.y, PRIME)", + "", + "value = new_x = (pow(slope, 2, SECP_P) - x0 - x1) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "value = pack(ids.x, PRIME) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "x = pack(ids.x, PRIME) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "x_cube_int = pack(ids.x_cube, PRIME) % SECP_P", + "y_square_int = (x_cube_int + ids.BETA) % SECP_P", + "y = pow(y_square_int, (SECP_P + 1) // 4, SECP_P)", + "", + "# We need to decide whether to take y or SECP_P - y.", + "if ids.v % 2 == y % 2:", + " value = y", + "else:", + " value = (-y) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "", + "y = pack(ids.point.y, PRIME) % SECP_P", + "# The modulo operation in python always returns a nonnegative number.", + "value = (-y) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "from starkware.python.math_utils import ec_double_slope", + "", + "# Compute the slope.", + "x = pack(ids.point.x, PRIME)", + "y = pack(ids.point.y, PRIME)", + "value = slope = ec_double_slope(point=(x, y), alpha=0, p=SECP_P)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack", + "from starkware.python.math_utils import line_slope", + "", + "# Compute the slope.", + "x0 = pack(ids.point0.x, PRIME)", + "y0 = pack(ids.point0.y, PRIME)", + "x1 = pack(ids.point1.x, PRIME)", + "y1 = pack(ids.point1.y, PRIME)", + "value = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "from starkware.cairo.common.cairo_secp.secp_utils import split", + "", + "segments.write_arg(ids.res.address_, split(value))" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -421,6 +577,12 @@ "ids.loop_temps.should_continue = 1 if current_access_indices else 0" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "ids.low = (ids.x.d0 + ids.x.d1 * ids.BASE) & ((1 << 128) - 1)" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -462,6 +624,12 @@ "positions = positions_dict[ids.value][::-1]" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = (ids.scalar % PRIME) % 2" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -492,6 +660,24 @@ "memory[ap] = segments.add()" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = to_felt_or_relocatable(ids.n_bytes < ids.BYTES_IN_WORD)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = to_felt_or_relocatable(ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "memory[ap] = to_felt_or_relocatable(x == 0)" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -521,6 +707,13 @@ "current_access_index = new_access_index" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])", + "segments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])" + ] + }, { "allowed_expressions": [], "hint_lines": [ @@ -614,6 +807,24 @@ "syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr)" ] }, + { + "allowed_expressions": [], + "hint_lines": [ + "value = k = safe_div(res * b - a, N)" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "value = new_y = (slope * (x - new_x) - y) % SECP_P" + ] + }, + { + "allowed_expressions": [], + "hint_lines": [ + "value = new_y = (slope * (x0 - new_x) - y0) % SECP_P" + ] + }, { "allowed_expressions": [], "hint_lines": [ diff --git a/src/starkware/starknet/services/api/feeder_gateway/block_hash.py b/src/starkware/starknet/services/api/feeder_gateway/block_hash.py index 6d96a33c..3eba391f 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/block_hash.py +++ b/src/starkware/starknet/services/api/feeder_gateway/block_hash.py @@ -17,6 +17,7 @@ async def calculate_block_hash( parent_hash: int, block_number: int, global_state_root: bytes, + sequencer_address: int, block_timestamp: int, tx_hashes: Sequence[int], tx_signatures: Sequence[List[int]], @@ -70,7 +71,7 @@ def bytes_hash_function(x: bytes, y: bytes) -> bytes: data=[ block_number, from_bytes(global_state_root), - general_config.sequencer_address, + sequencer_address, block_timestamp, len(tx_hashes), # Number of transactions. tx_commitment, # Transaction commitment. diff --git a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py index 83c64522..48a9acde 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py +++ b/src/starkware/starknet/services/api/feeder_gateway/feeder_gateway_client.py @@ -1,11 +1,10 @@ import json from typing import Any, Dict, List, Optional, Union -from typing_extensions import Literal - from services.everest.api.feeder_gateway.feeder_gateway_client import EverestFeederGatewayClient from starkware.starknet.definitions import fields from starkware.starknet.services.api.feeder_gateway.response_objects import ( + BlockIdentifier, StarknetBlock, TransactionInfo, TransactionReceipt, @@ -16,7 +15,6 @@ CastableToHash = Union[int, str] JsonObject = Dict[str, Any] -BlockIdentifier = Union[int, Literal["pending"]] class FeederGatewayClient(EverestFeederGatewayClient): diff --git a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py index 911115a1..b1670bba 100644 --- a/src/starkware/starknet/services/api/feeder_gateway/response_objects.py +++ b/src/starkware/starknet/services/api/feeder_gateway/response_objects.py @@ -36,7 +36,10 @@ from starkware.starkware_utils.validated_dataclass import ValidatedDataclass from starkware.starkware_utils.validated_fields import sequential_id_metadata -BlockIdentifier = Union[int, Literal["pending"]] +BlockNumber = int +LatestBlock = Literal["latest"] +PendingBlock = Literal["pending"] +BlockIdentifier = Union[BlockNumber, LatestBlock, PendingBlock] OptionalBlockIdentifier = Optional[BlockIdentifier] TBlockInfo = TypeVar("TBlockInfo", bound="StarknetBlock") @@ -585,6 +588,7 @@ class StarknetBlock(BaseResponseObject): block_number: Optional[int] = field(metadata=fields.default_optional_block_number_metadata) state_root: Optional[bytes] = field(metadata=fields.optional_state_root_metadata) status: Optional[BlockStatus] + gas_price: int = field(metadata=fields.gas_price_metadata) transactions: Tuple[TransactionSpecificInfo, ...] = field( metadata=dict( marshmallow_field=VariadicLengthTupleField( @@ -593,6 +597,7 @@ class StarknetBlock(BaseResponseObject): ) ) timestamp: int = field(metadata=fields.timestamp_metadata) + sequencer_address: Optional[int] = field(metadata=fields.optional_sequencer_address_metadata) transaction_receipts: Optional[Tuple[TransactionExecution, ...]] = field( metadata=dict( marshmallow_field=VariadicLengthTupleField( @@ -610,8 +615,10 @@ def create( state_root: Optional[bytes], transactions: Iterable[InternalTransaction], timestamp: int, + sequencer_address: Optional[int], transaction_receipts: Optional[Tuple[TransactionExecution, ...]], status: Optional[BlockStatus], + gas_price: int, ) -> TBlockInfo: return cls( block_hash=block_hash, @@ -622,8 +629,10 @@ def create( TransactionSpecificInfo.from_internal(internal_tx=tx) for tx in transactions ), timestamp=timestamp, + sequencer_address=sequencer_address, transaction_receipts=transaction_receipts, status=status, + gas_price=gas_price, ) def __post_init__(self): diff --git a/src/starkware/starknet/solidity/StarknetMessaging.sol b/src/starkware/starknet/solidity/StarknetMessaging.sol index d0b03fb7..61572d77 100644 --- a/src/starkware/starknet/solidity/StarknetMessaging.sol +++ b/src/starkware/starknet/solidity/StarknetMessaging.sol @@ -10,7 +10,7 @@ import "contracts/starkware/solidity/libraries/NamedStorage.sol"; to the latter pipe while interacting with L2. */ contract StarknetMessaging is IStarknetMessaging { - /** + /* Random slot storage elements and accessors. */ string constant L1L2_MESSAGE_MAP_TAG = "STARKNET_1.0_MSGING_L1TOL2_MAPPPING_V2"; @@ -80,7 +80,7 @@ contract StarknetMessaging is IStarknetMessaging { uint256 selector, uint256[] calldata payload, uint256 nonce - ) internal returns (bytes32) { + ) internal view returns (bytes32) { return keccak256( abi.encodePacked( diff --git a/src/starkware/starknet/testing/state.py b/src/starkware/starknet/testing/state.py index 5e8b04ac..f94b1039 100644 --- a/src/starkware/starknet/testing/state.py +++ b/src/starkware/starknet/testing/state.py @@ -2,7 +2,11 @@ from typing import Dict, List, Optional, Tuple, Union from starkware.cairo.lang.vm.crypto import pedersen_hash_func -from starkware.starknet.business_logic.execution.objects import Event, TransactionExecutionInfo +from starkware.starknet.business_logic.execution.objects import ( + CallInfo, + Event, + TransactionExecutionInfo, +) from starkware.starknet.business_logic.internal_transaction import ( InternalDeploy, InternalInvokeFunction, @@ -103,6 +107,41 @@ async def deploy( return tx.contract_address, tx_execution_info + async def call_raw( + self, + contract_address: CastableToAddress, + selector: Union[int, str], + calldata: List[int], + caller_address: int, + max_fee: int, + signature: Optional[List[int]] = None, + entry_point_type: EntryPointType = EntryPointType.EXTERNAL, + nonce: Optional[int] = None, + version: int = constants.QUERY_VERSION, + ) -> CallInfo: + """ + Calls a function on a contract and returns its CallInfo without modifying the state. + """ + tx = create_invoke_function( + contract_address=contract_address, + selector=selector, + calldata=calldata, + caller_address=caller_address, + max_fee=max_fee, + version=version, + signature=signature, + entry_point_type=entry_point_type, + nonce=nonce, + chain_id=self.general_config.chain_id.value, + only_query=True, + ) + + return await tx.execute( + state=copy.deepcopy(self.state), + general_config=self.general_config, + only_query=True, + ) + async def invoke_raw( self, contract_address: CastableToAddress, @@ -124,28 +163,17 @@ async def invoke_raw( signature - a list of integers to pass as signature to the invoked function. """ - if isinstance(contract_address, str): - contract_address = int(contract_address, 16) - assert isinstance(contract_address, int) - - if isinstance(selector, str): - selector = get_selector_from_name(selector) - assert isinstance(selector, int) - - if signature is None: - signature = [] - - tx = InternalInvokeFunction.create( + tx = create_invoke_function( contract_address=contract_address, - entry_point_selector=selector, - entry_point_type=entry_point_type, + selector=selector, calldata=calldata, + caller_address=caller_address, max_fee=max_fee, + version=constants.TRANSACTION_VERSION, signature=signature, - caller_address=caller_address, + entry_point_type=entry_point_type, nonce=nonce, chain_id=self.general_config.chain_id.value, - version=constants.TRANSACTION_VERSION, ) with self.state.copy_and_apply() as state_copy: @@ -178,3 +206,42 @@ def consume_message_hash(self, message_hash: str): ), f"Message of hash {message_hash} is fully consumed." self._l2_to_l1_messages[message_hash] -= 1 + + +def create_invoke_function( + contract_address: CastableToAddress, + selector: Union[int, str], + calldata: List[int], + caller_address: int, + max_fee: int, + version: int, + signature: Optional[List[int]], + entry_point_type: EntryPointType, + nonce: Optional[int], + chain_id: int, + only_query: bool = False, +) -> InternalInvokeFunction: + + if isinstance(contract_address, str): + contract_address = int(contract_address, 16) + assert isinstance(contract_address, int) + + if isinstance(selector, str): + selector = get_selector_from_name(selector) + assert isinstance(selector, int) + + signature = [] if signature is None else signature + + return InternalInvokeFunction.create( + contract_address=contract_address, + entry_point_selector=selector, + entry_point_type=entry_point_type, + calldata=calldata, + max_fee=max_fee, + signature=signature, + caller_address=caller_address, + nonce=nonce, + chain_id=chain_id, + version=version, + only_query=only_query, + ) diff --git a/src/starkware/starknet/third_party/__init__.py b/src/starkware/starknet/third_party/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/starkware/starknet/third_party/open_zeppelin/starknet_contracts.py b/src/starkware/starknet/third_party/open_zeppelin/starknet_contracts.py index 08dfae5c..ccbabd61 100644 --- a/src/starkware/starknet/third_party/open_zeppelin/starknet_contracts.py +++ b/src/starkware/starknet/third_party/open_zeppelin/starknet_contracts.py @@ -1,4 +1,5 @@ import os.path + from starkware.starknet.services.api.contract_definition import ContractDefinition DIR = os.path.dirname(__file__) diff --git a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py index c745d3b0..7f57179e 100644 --- a/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py +++ b/src/starkware/starkware_utils/commitment_tree/patricia_tree/patricia_tree.py @@ -1,4 +1,4 @@ -from typing import Collection, Dict, Optional, Tuple, Type +from typing import Collection, Dict, List, Optional, Tuple, Type import marshmallow_dataclass @@ -77,3 +77,24 @@ async def update( # In case root is an edge node, its fact must be explicitly written to DB. root_hash = await updated_virtual_root_node.commit(ffc=ffc, facts=facts) return PatriciaTree(root=root_hash, height=updated_virtual_root_node.height) + + async def get_diff_between_patricia_trees( + self, + other: "PatriciaTree", + ffc: FactFetchingContext, + storage_tree_height: int, + fact_cls: Type[TLeafFact], + ) -> List[Tuple[int, TLeafFact, TLeafFact]]: + """ + Returns a list of (key, old_fact, new_fact) that are different + between this tree and another. + + The height of the two trees must be equal. + + If the 'facts' argument is not None, this dictionary is filled with facts read from the DB. + """ + self_node, other_node = [ + VirtualPatriciaNode.from_hash(hash_value=hash_value, height=storage_tree_height) + for hash_value in (self.root, other.root) + ] + return await self_node.get_diff_between_trees(other=other_node, ffc=ffc, fact_cls=fact_cls) diff --git a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py index da11be38..39ba18e7 100644 --- a/src/starkware/starkware_utils/marshmallow_dataclass_fields.py +++ b/src/starkware/starkware_utils/marshmallow_dataclass_fields.py @@ -88,6 +88,10 @@ class IntAsHex(mfields.Field): default_error_messages = {"invalid": 'Expected hex string, got: "{input}".'} + def __init__(self, support_decimal_loading: bool = False, **kwargs): + super().__init__(**kwargs) + self.support_decimal_loading = support_decimal_loading + def _serialize(self, value, attr, obj, **kwargs): if value is None: return None @@ -96,10 +100,13 @@ def _serialize(self, value, attr, obj, **kwargs): return hex(value) def _deserialize(self, value, attr, data, **kwargs): - if re.match("^0x[0-9a-f]+$", value) is None: - self.fail("invalid", input=value) + if re.match("^0x[0-9a-f]+$", value) is not None: + return int(value, 16) + + if self.support_decimal_loading and re.match("^[0-9]+$", value) is not None: + return int(value) - return int(value, 16) + self.fail("invalid", input=value) class BytesAsHex(mfields.Field): diff --git a/src/starkware/storage/gated_storage.py b/src/starkware/storage/gated_storage.py index 4e9ed06e..0cfc3ee1 100644 --- a/src/starkware/storage/gated_storage.py +++ b/src/starkware/storage/gated_storage.py @@ -29,14 +29,15 @@ async def create_from_config( async def _compress_value(self, key: bytes, value: bytes) -> Tuple[bytes, bytes]: """ - In case that the length of the value is greater than the limit, stores the value in the - second storage, with a unique key and returns the new value that will be stored to the first - storage which indicates that the original value is stored in storage1. + In case that the length of the key + the length of the value is greater than the limit, + stores the value in the second storage, with a unique key and returns the new value that + will be stored to the first storage which indicates that the original value is stored in + storage1. """ if value[: len(MAGIC_HEADER)] != MAGIC_HEADER: # If the value starts with MAGIC_HEADER, treat the value as a large value; Hence, it # will be stored in the second storage. - if len(value) <= self.limit: + if len(key) + len(value) <= self.limit: return key, value ukey = generate_unique_key(