Skip to content

Commit

Permalink
Get rid of Deref and DerefMut impls for TypeCheckingContext
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Sep 17, 2024
1 parent 211e0d9 commit 9f2403d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 51 deletions.
1 change: 0 additions & 1 deletion src/flattening/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ mod walk;

use crate::alloc::UUIDAllocator;
use crate::prelude::*;
use crate::typing::type_inference::TypeVariableIDMarker;

use std::ops::Deref;

Expand Down
84 changes: 34 additions & 50 deletions src/flattening/typechecking.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::prelude::*;

use std::ops::{Deref, DerefMut};

use walk::for_each_generative_input_in_template_args;

use crate::debug::SpanDebugger;
Expand Down Expand Up @@ -67,19 +65,6 @@ struct TypeCheckingContext<'l, 'errs> {
runtime_condition_stack: Vec<ConditionStackElem>,
}

impl<'l, 'errs> Deref for TypeCheckingContext<'l, 'errs> {
type Target = WorkingOnResolver<'l, 'errs, ModuleUUIDMarker, Module>;

fn deref(&self) -> &Self::Target {
&self.modules
}
}
impl<'l, 'errs> DerefMut for TypeCheckingContext<'l, 'errs> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.modules
}
}

impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
fn get_link_info<ID: Into<NameElem>>(&self, id: ID) -> Option<&LinkInfo> {
let ne: NameElem = id.into();
Expand All @@ -94,7 +79,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
port: PortID,
submodule_instr: FlatID,
) -> (&'s Declaration, FileUUID) {
let submodule_id = self.working_on.instructions[submodule_instr]
let submodule_id = self.modules.working_on.instructions[submodule_instr]
.unwrap_submodule()
.module_ref
.id;
Expand All @@ -105,7 +90,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {

fn get_type_of_port(&self, port: PortID, submodule_instr: FlatID) -> FullType {
let (decl, _file) = self.get_decl_of_module_port(port, submodule_instr);
let submodule_inst = self.working_on.instructions[submodule_instr].unwrap_submodule();
let submodule_inst = self.modules.working_on.instructions[submodule_instr].unwrap_submodule();
let submodule_module = &self.modules[submodule_inst.module_ref.id];
let port_interface = submodule_module.ports[port].domain;
let port_local_domain = submodule_inst.local_interface_domains[port_interface];
Expand All @@ -123,7 +108,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
) -> Option<SpanFile> {
match wire_ref_root {
WireReferenceRoot::LocalDecl(id, _) => {
let decl_root = self.working_on.instructions[*id].unwrap_wire_declaration();
let decl_root = self.modules.working_on.instructions[*id].unwrap_wire_declaration();
Some((decl_root.decl_span, self.errors.file))
}
WireReferenceRoot::NamedConstant(cst, _) => {
Expand All @@ -140,7 +125,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
fn get_type_of_wire_reference(&self, wire_ref: &WireReference) -> FullType {
let mut write_to_type = match &wire_ref.root {
WireReferenceRoot::LocalDecl(id, _) => {
let decl_root = self.working_on.instructions[*id].unwrap_wire_declaration();
let decl_root = self.modules.working_on.instructions[*id].unwrap_wire_declaration();
decl_root.typ.clone()
}
WireReferenceRoot::NamedConstant(cst, _) => {
Expand All @@ -155,7 +140,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
for p in &wire_ref.path {
match p {
&WireReferencePathElement::ArrayAccess { idx, bracket_span } => {
let idx_wire = self.working_on.instructions[idx].unwrap_wire();
let idx_wire = self.modules.working_on.instructions[idx].unwrap_wire();

write_to_type = self.type_checker.typecheck_array_access(
&write_to_type,
Expand All @@ -177,7 +162,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
self.runtime_condition_stack.pop().unwrap();
}
match &self.working_on.instructions[inst_id] {
match &self.modules.working_on.instructions[inst_id] {
Instruction::SubModule(_) => {}
Instruction::FuncCall(_) => {}
Instruction::Declaration(decl) => {
Expand All @@ -195,7 +180,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
Instruction::Write(conn) => {
let (decl, file) = match conn.to.root {
WireReferenceRoot::LocalDecl(decl_id, _) => {
let decl = self.working_on.instructions[decl_id].unwrap_wire_declaration();
let decl = self.modules.working_on.instructions[decl_id].unwrap_wire_declaration();
if decl.read_only {
self.errors
.error(conn.to_span, format!("'{}' is read-only", decl.name))
Expand Down Expand Up @@ -228,7 +213,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
if decl.identifier_type.is_generative() {
// Check that this generative declaration isn't used in a non-compiletime if
if let Some(root_flat) = conn.to.root.get_root_flat() {
let to_decl = self.working_on.instructions[root_flat]
let to_decl = self.modules.working_on.instructions[root_flat]
.unwrap_wire_declaration();

if self.runtime_condition_stack.len()
Expand Down Expand Up @@ -258,7 +243,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
}
Instruction::IfStatement(if_stmt) => {
let condition_wire = self.working_on.instructions[if_stmt.condition].unwrap_wire();
let condition_wire = self.modules.working_on.instructions[if_stmt.condition].unwrap_wire();
if let DomainType::Physical(domain) = condition_wire.typ.domain {
self.runtime_condition_stack.push(ConditionStackElem {
ends_at: if_stmt.else_end,
Expand Down Expand Up @@ -297,7 +282,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
[template_input.declaration_instruction]
.unwrap_wire_declaration();
let declared_here = Some((template_input_decl.name_span, self.errors.file));
let val_wire = self.working_on.instructions[*val].unwrap_wire();
let val_wire = self.modules.working_on.instructions[*val].unwrap_wire();
let target_abstract_type = template_input_decl
.typ_expr
.to_type_with_substitute(&global_ref.template_args);
Expand Down Expand Up @@ -325,7 +310,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {

self.typecheck_written_type(content_typ);

let idx_wire = self.working_on.instructions[*arr_idx].unwrap_wire();
let idx_wire = self.modules.working_on.instructions[*arr_idx].unwrap_wire();
self.type_checker.typecheck_and_generative::<true>(
&idx_wire.typ,
idx_wire.span,
Expand Down Expand Up @@ -355,24 +340,24 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
}

fn typecheck_visit_instruction(&mut self, instr_id: FlatID) {
match &self.working_on.instructions[instr_id] {
fn typecheck_visit_instruction(&self, instr_id: FlatID) {
match &self.modules.working_on.instructions[instr_id] {
Instruction::SubModule(sm) => {
self.typecheck_template_global(&sm.module_ref);
let md = &self.modules[sm.module_ref.id];
let local_interface_domains = md
.domain_names
.map(|_| self.type_checker.new_unknown_domain_id());

let Instruction::SubModule(sm) = &mut self.working_on.instructions[instr_id] else {
let Instruction::SubModule(sm) = &mut self.modules.working_on.instructions[instr_id] else {
unreachable!()
};
sm.local_interface_domains = local_interface_domains;
}
Instruction::Declaration(decl) => {
if let Some(latency_spec) = decl.latency_specifier {
let latency_spec_wire =
self.working_on.instructions[latency_spec].unwrap_wire();
self.modules.working_on.instructions[latency_spec].unwrap_wire();
self.type_checker.typecheck_and_generative::<true>(
&latency_spec_wire.typ,
latency_spec_wire.span,
Expand All @@ -398,7 +383,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
}
Instruction::IfStatement(stm) => {
let wire = &self.working_on.instructions[stm.condition].unwrap_wire();
let wire = &self.modules.working_on.instructions[stm.condition].unwrap_wire();
self.type_checker.typecheck_and_generative::<false>(
&wire.typ,
wire.span,
Expand All @@ -408,10 +393,9 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
);
}
Instruction::ForStatement(stm) => {
let loop_var =
&self.working_on.instructions[stm.loop_var_decl].unwrap_wire_declaration();
let start = &self.working_on.instructions[stm.start].unwrap_wire();
let end = &self.working_on.instructions[stm.end].unwrap_wire();
let loop_var = self.modules.working_on.instructions[stm.loop_var_decl].unwrap_wire_declaration();
let start = self.modules.working_on.instructions[stm.start].unwrap_wire();
let end = self.modules.working_on.instructions[stm.end].unwrap_wire();

self.type_checker.typecheck_and_generative::<true>(
&start.typ,
Expand All @@ -432,16 +416,16 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
let result_typ = match &w.source {
WireSource::WireRef(from_wire) => self.get_type_of_wire_reference(from_wire),
&WireSource::UnaryOp { op, right } => {
let right_wire = self.working_on.instructions[right].unwrap_wire();
let right_wire = self.modules.working_on.instructions[right].unwrap_wire();
self.type_checker.typecheck_unary_operator(
op,
&right_wire.typ,
right_wire.span,
)
}
&WireSource::BinaryOp { op, left, right } => {
let left_wire = self.working_on.instructions[left].unwrap_wire();
let right_wire = self.working_on.instructions[right].unwrap_wire();
let left_wire = self.modules.working_on.instructions[left].unwrap_wire();
let right_wire = self.modules.working_on.instructions[right].unwrap_wire();
self.type_checker.typecheck_binary_operator(
op,
&left_wire.typ,
Expand All @@ -452,7 +436,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
WireSource::Constant(value) => value.get_type_of_constant(),
};
let Instruction::Wire(w) = &mut self.working_on.instructions[instr_id] else {
let Instruction::Wire(w) = &mut self.modules.working_on.instructions[instr_id] else {
unreachable!()
};
w.typ = result_typ;
Expand All @@ -467,7 +451,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
let declared_here = (decl.decl_span, file);

// Typecheck the value with target type
let from_wire = self.working_on.instructions[*arg].unwrap_wire();
let from_wire = self.modules.working_on.instructions[*arg].unwrap_wire();

self.join_with_condition(&write_to_type.domain, from_wire.span.debug());
self.type_checker.typecheck_write_to(
Expand Down Expand Up @@ -499,7 +483,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
};

// Typecheck the value with target type
let from_wire = self.working_on.instructions[conn.from].unwrap_wire();
let from_wire = self.modules.working_on.instructions[conn.from].unwrap_wire();

from_wire.span.debug();
self.type_checker.typecheck_write_to(
Expand Down Expand Up @@ -528,7 +512,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
decl.typ.domain = DomainType::Physical(port.domain);
}

for elem_id in self.working_on.instructions.id_range() {
for elem_id in self.modules.working_on.instructions.id_range() {
self.control_flow_visit_instruction(elem_id);
self.typecheck_visit_instruction(elem_id);
}
Expand Down Expand Up @@ -563,12 +547,12 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
name: match *best_name {
BestName::NamedDomain => self.modules.working_on.domain_names[id].clone(),
BestName::SubModule(sm_instr, sm_domain) => {
let sm = self.working_on.instructions[sm_instr].unwrap_submodule();
let sm = self.modules.working_on.instructions[sm_instr].unwrap_submodule();
sm.module_ref.span.debug();
let sm_md = &self.modules[sm.module_ref.id];
format!("{}_{}", sm.get_name(&sm_md), sm_md.domain_names[sm_domain])
}
BestName::NamedWire(decl_id) => self.working_on.instructions[decl_id]
BestName::NamedWire(decl_id) => self.modules.working_on.instructions[decl_id]
.unwrap_wire_declaration()
.name
.clone(),
Expand All @@ -584,11 +568,11 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
let instruction_fanins = self.make_fanins();

let mut is_instance_used_map: FlatAlloc<bool, FlatIDMarker> =
self.working_on.instructions.map(|_| false);
self.modules.working_on.instructions.map(|_| false);

let mut wire_to_explore_queue: Vec<FlatID> = Vec::new();

for (_id, port) in &self.working_on.ports {
for (_id, port) in &self.modules.working_on.ports {
if !port.is_input {
is_instance_used_map[port.declaration_instruction] = true;
wire_to_explore_queue.push(port.declaration_instruction);
Expand All @@ -605,7 +589,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}

// Now produce warnings from the unused list
for (id, inst) in self.working_on.instructions.iter() {
for (id, inst) in self.modules.working_on.instructions.iter() {
if !is_instance_used_map[id] {
if let Instruction::Declaration(decl) = inst {
self.errors.warn(decl.name_span, "Unused Variable: This variable does not affect the output ports of this module");
Expand All @@ -617,9 +601,9 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
fn make_fanins(&self) -> FlatAlloc<Vec<FlatID>, FlatIDMarker> {
// Setup Wire Fanouts List for faster processing
let mut instruction_fanins: FlatAlloc<Vec<FlatID>, FlatIDMarker> =
self.working_on.instructions.map(|_| Vec::new());
self.modules.working_on.instructions.map(|_| Vec::new());

for (inst_id, inst) in self.working_on.instructions.iter() {
for (inst_id, inst) in self.modules.working_on.instructions.iter() {
let mut collector_func = |id| instruction_fanins[inst_id].push(id);
match inst {
Instruction::Write(conn) => {
Expand Down Expand Up @@ -649,7 +633,7 @@ impl<'l, 'errs> TypeCheckingContext<'l, 'errs> {
}
Instruction::IfStatement(stm) => {
for id in FlatIDRange::new(stm.then_start, stm.else_end) {
if let Instruction::Write(conn) = &self.working_on.instructions[id] {
if let Instruction::Write(conn) = &self.modules.working_on.instructions[id] {
if let Some(flat_root) = conn.to.root.get_root_flat() {
instruction_fanins[flat_root].push(stm.condition);
}
Expand Down
3 changes: 3 additions & 0 deletions src/typing/abstract_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl<'linker, 'errs> TypeUnifier<'linker, 'errs> {
}
}

// Compares two types, if they aren't the same then this produces an error.
pub fn typecheck_abstr(
&self,
found: &AbstractType,
Expand Down Expand Up @@ -386,6 +387,8 @@ impl<'linker, 'errs> TypeUnifier<'linker, 'errs> {
context: &str,
declared_here: Option<SpanFile>,
) {


self.typecheck_abstr(&found.typ, span, &expected, context, declared_here);

if MUST_BE_GENERATIVE && found.domain != DomainType::Generative {
Expand Down

0 comments on commit 9f2403d

Please sign in to comment.