Skip to content

Commit

Permalink
specialization
Browse files Browse the repository at this point in the history
...
  • Loading branch information
tjjfvi committed Jan 13, 2025
1 parent af9ff0f commit 9a20e22
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 40 deletions.
10 changes: 10 additions & 0 deletions tests/programs/specializations.vi
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

use std::tuple::Pair;

pub fn main(&io: &IO) {
io.println([1, 2, 3, 4].to_string[; N32::to_string]());
io.println(["abc", "def", "ghi"].to_string[; String::to_string]());
io.println(['x', 'y', 'z'].to_string[; Char::to_string]());
io.println([true, false].to_string[; Bool::to_string]());
io.println([(1, 'a'), (2, 'b')].to_string[; Pair::to_string[; N32::to_string, Char::to_string]]());
}
2 changes: 1 addition & 1 deletion vine/examples/sub_min.vi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pub fn main(&io: &IO) {
let list = [4, 3, 7, 9];
sub_min(&list);
io.println(list.to_string[;N32::to_string]());
io.println(list.to_string[; N32::to_string]());
}

pub fn sub_min(&list: &List[N32]) {
Expand Down
4 changes: 3 additions & 1 deletion vine/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use class::Classes;
use ivy::ast::Net;
use vine_util::{interner::Interned, new_idx};

use crate::{diag::ErrorGuaranteed, resolver::DefId};
use crate::{diag::ErrorGuaranteed, resolver::DefId, specializer::RelId};

Check warning on line 12 in vine/src/ast.rs

View workflow job for this annotation

GitHub Actions / cspell

Unknown word (specializer)

new_idx!(pub Local; n => ["l{n}"]);
new_idx!(pub DynFnId; n => ["f{n}"]);
Expand Down Expand Up @@ -248,6 +248,8 @@ pub enum ExprKind<'core> {
Paren(B<Expr<'core>>),
#[class(value)]
Path(GenericPath<'core>),
#[class(value, synthetic)]
Rel(RelId),
#[class(place, resolved)]
Local(Local),
#[class(value, resolved)]
Expand Down
1 change: 0 additions & 1 deletion vine/src/checker/typeof_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ impl<'core> Checker<'core, '_> {
if path.generics.as_ref().is_some_and(|g| !g.impls.is_empty()) {

Check warning on line 141 in vine/src/checker/typeof_def.rs

View workflow job for this annotation

GitHub Actions / cspell

Unknown word (impls)
self.core.report(Diag::UnexpectedImplArgs { span });
}
// dbg!(&path);
let def_id = path.path.resolved.unwrap();
let def = &self.resolver.defs[def_id];
let Some(type_def) = &def.type_def else {
Expand Down
2 changes: 2 additions & 0 deletions vine/src/distiller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ impl<'core, 'r> Distiller<'core, 'r> {
ExprKind::F32(f) => Port::F32(*f),

ExprKind::Path(path) => Port::Const(path.path.resolved.unwrap()),
ExprKind::Rel(rel_id) => Port::Rel(*rel_id),

ExprKind::DynFn(dyn_fn) => {
let dyn_fn = self.dyn_fns[*dyn_fn].as_ref().unwrap();
stage.steps.push(Step::Transfer(Transfer::unconditional(dyn_fn.interface)));
Expand Down
75 changes: 46 additions & 29 deletions vine/src/emitter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::BTreeMap, mem::take};
use std::{collections::BTreeMap, fmt::Write, mem::take};

use ivy::ast::{Net, Nets, Tree};
use vine_util::idx::{Counter, IdxVec};
Expand All @@ -7,6 +7,7 @@ use crate::{
analyzer::usage::Usage,
ast::Local,
resolver::{AdtDef, Def, Resolver, ValueDefKind, VariantDef},
specializer::{Spec, SpecId},

Check warning on line 10 in vine/src/emitter.rs

View workflow job for this annotation

GitHub Actions / cspell

Unknown word (specializer)
vir::{
Interface, InterfaceId, InterfaceKind, Invocation, Port, Stage, StageId, Step, Transfer, VIR,
},
Expand All @@ -23,40 +24,38 @@ impl<'core, 'a> Emitter<'core, 'a> {
Emitter { nets: Nets::default(), resolver, dup_labels: Counter::default() }
}

pub fn emit_vir(&mut self, path: String, vir: &VIR) {
pub fn emit_vir(&mut self, path: String, vir: &VIR, specs: &IdxVec<SpecId, Spec>) {
let mut emitter = VirEmitter {
resolver: self.resolver,
path,
stages: &vir.stages,
interfaces: &vir.interfaces,
specs,
spec_id: SpecId::default(),
locals: BTreeMap::new(),
pairs: Vec::new(),
wire_offset: 0,
wires: Counter::default(),
dup_labels: self.dup_labels,
};

for stage in vir.stages.values() {
let interface = &vir.interfaces[stage.interface];
if interface.incoming != 0 && !interface.inline() {
emitter.wire_offset = 0;
emitter.wires.0 = stage.wires.0 .0;
let root = emitter.emit_interface(interface, true);
let root =
Tree::n_ary("enum", stage.header.iter().map(|p| emitter.emit_port(p)).chain([root]));
emitter._emit_stage(stage);
for (_, local) in take(&mut emitter.locals) {
emitter.finish_local(local);
for spec_id in specs.keys() {
emitter.spec_id = spec_id;
for stage in vir.stages.values() {
let interface = &vir.interfaces[stage.interface];
if interface.incoming != 0 && !interface.inline() {
emitter.wire_offset = 0;
emitter.wires.0 = stage.wires.0 .0;
let root = emitter.emit_interface(interface, true);
let root =
Tree::n_ary("enum", stage.header.iter().map(|p| emitter.emit_port(p)).chain([root]));
emitter._emit_stage(stage);
for (_, local) in take(&mut emitter.locals) {
emitter.finish_local(local);
}
let net = Net { root, pairs: take(&mut emitter.pairs) };
self.nets.insert(emitter.stage_name(stage.id), net);
}
let net = Net { root, pairs: take(&mut emitter.pairs) };
self.nets.insert(
if stage.id.0 == 0 {
emitter.path.clone()
} else {
format!("{}::{}", emitter.path, stage.id.0)
},
net,
);
}
}

Expand All @@ -66,7 +65,7 @@ impl<'core, 'a> Emitter<'core, 'a> {
pub fn emit_ivy(&mut self, def: &Def<'core>) {
if let Some(value_def) = &def.value_def {
match &value_def.kind {
ValueDefKind::Expr(_) => {}
ValueDefKind::Expr(_) | ValueDefKind::TraitSubitem(..) => {}
ValueDefKind::Ivy(net) => {
self.nets.insert(def.canonical.to_string(), net.clone());
}
Expand All @@ -85,7 +84,6 @@ impl<'core, 'a> Emitter<'core, 'a> {
);
self.nets.insert(def.canonical.to_string(), Net { root, pairs: Vec::new() });
}
ValueDefKind::TraitSubitem(..) => todo!(),
}
}
}
Expand All @@ -96,6 +94,8 @@ struct VirEmitter<'core, 'a> {
path: String,
stages: &'a IdxVec<StageId, Stage>,
interfaces: &'a IdxVec<InterfaceId, Interface>,
specs: &'a IdxVec<SpecId, Spec>,
spec_id: SpecId,
locals: BTreeMap<Local, LocalState>,
pairs: Vec<(Tree, Tree)>,
wire_offset: usize,
Expand Down Expand Up @@ -234,8 +234,8 @@ impl<'core, 'a> VirEmitter<'core, 'a> {
)
}

fn emit_stage_node(&self, stage: StageId) -> Tree {
Tree::Global(format!("{}::{}", self.path, stage.0))
fn emit_stage_node(&self, stage_id: StageId) -> Tree {
Tree::Global(self.stage_name(stage_id))
}

fn local(&mut self, local: Local) -> &mut LocalState {
Expand All @@ -244,7 +244,8 @@ impl<'core, 'a> VirEmitter<'core, 'a> {

fn emit_step(&mut self, step: &Step) {
let wire_offset = self.wire_offset;
let emit_port = |p| Self::_emit_port(wire_offset, self.resolver, p);
let spec_id = self.spec_id;
let emit_port = |p| Self::_emit_port(wire_offset, self.resolver, &self.specs[spec_id], p);
match step {
Step::Invoke(local, invocation) => match invocation {
Invocation::Erase => self.local(*local).erase(),
Expand Down Expand Up @@ -297,16 +298,21 @@ impl<'core, 'a> VirEmitter<'core, 'a> {
}

fn emit_port(&self, port: &Port) -> Tree {
Self::_emit_port(self.wire_offset, self.resolver, port)
Self::_emit_port(self.wire_offset, self.resolver, &self.specs[self.spec_id], port)
}

fn _emit_port(wire_offset: usize, resolver: &Resolver, port: &Port) -> Tree {
fn _emit_port(wire_offset: usize, resolver: &Resolver, spec: &Spec, port: &Port) -> Tree {
match port {
Port::Erase => Tree::Erase,
Port::N32(n) => Tree::N32(*n),
Port::F32(f) => Tree::F32(*f),
Port::Wire(w) => Tree::Var(format!("w{}", wire_offset + w.0)),
Port::Const(def) => Tree::Global(resolver.defs[*def].canonical.to_string()),
Port::Rel(rel) => {
let (def, spec, singular) = spec.rels[*rel];

Check warning on line 312 in vine/src/emitter.rs

View workflow job for this annotation

GitHub Actions / cspell

Unknown word (rels)
let path = &resolver.defs[def].canonical;
Tree::Global(if singular { path.to_string() } else { format!("{}::{}", path, spec.0) })
}
}
}

Expand All @@ -324,6 +330,17 @@ impl<'core, 'a> VirEmitter<'core, 'a> {
let label = format!("w{}", self.wires.next());
(Tree::Var(label.clone()), Tree::Var(label))
}

fn stage_name(&self, stage_id: StageId) -> String {
let mut name = self.path.clone();
if !self.specs[self.spec_id].singular {
write!(name, "::{}", self.spec_id.0).unwrap();
}
if stage_id.0 != 0 {
write!(name, "::{}", stage_id.0).unwrap();
}
name
}
}

#[derive(Default)]
Expand Down
7 changes: 5 additions & 2 deletions vine/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ pub mod normalizer;
pub mod parser;
pub mod repl;
pub mod resolver;
pub mod specializer;
pub mod vir;
pub mod visit;

use core::{Core, CoreArenas};
use std::path::PathBuf;

use ivy::ast::Nets;
use specializer::specialize;

use crate::{
analyzer::analyze, ast::Path, checker::Checker, distiller::Distiller, emitter::Emitter,
Expand Down Expand Up @@ -57,14 +59,15 @@ pub fn compile(config: Config) -> Result<Nets, String> {

core.bail()?;

let specializations = specialize(&mut resolver);
let mut distiller = Distiller::new(&resolver);
let mut emitter = Emitter::new(&resolver);
for (_, def) in &resolver.defs {
for (def_id, def) in &resolver.defs {
if matches_filter(&def.canonical, &config.items) {
if let Some(vir) = distiller.distill(def) {
let mut vir = normalize(&vir);
analyze(&mut vir);
emitter.emit_vir(def.canonical.to_string(), &vir);
emitter.emit_vir(def.canonical.to_string(), &vir, &specializations[def_id]);
} else {
emitter.emit_ivy(def);
}
Expand Down
12 changes: 8 additions & 4 deletions vine/src/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{
normalizer::normalize,
parser::VineParser,
resolver::{DefId, Resolver},
specializer::specialize,
vir::{InterfaceId, StageId},
};

Expand Down Expand Up @@ -78,13 +79,14 @@ impl<'core, 'ctx, 'ivm> Repl<'core, 'ctx, 'ivm> {

core.bail()?;

let specializations = specialize(&mut resolver);
let mut distiller = Distiller::new(&resolver);
let mut emitter = Emitter::new(&resolver);
for (_, def) in &resolver.defs {
for (def_id, def) in &resolver.defs {
if let Some(vir) = distiller.distill(def) {
let mut vir = normalize(&vir);
analyze(&mut vir);
emitter.emit_vir(def.canonical.to_string(), &vir);
emitter.emit_vir(def.canonical.to_string(), &vir, &specializations[def_id]);
} else {
emitter.emit_ivy(def);
}
Expand Down Expand Up @@ -183,7 +185,8 @@ impl<'core, 'ctx, 'ivm> Repl<'core, 'ctx, 'ivm> {
if let Some(vir) = distiller.distill(def) {
let mut vir = normalize(&vir);
analyze(&mut vir);
emitter.emit_vir(def.canonical.to_string(), &vir);
// emitter.emit_vir(def.canonical.to_string(), &vir);
todo!()
} else {
emitter.emit_ivy(def);
}
Expand All @@ -202,7 +205,8 @@ impl<'core, 'ctx, 'ivm> Repl<'core, 'ctx, 'ivm> {
for var in self.vars.values() {
vir.interfaces[InterfaceId(0)].wires.insert(var.local, (Usage::Mut, Usage::Mut));
}
emitter.emit_vir(name.clone(), &vir);
// emitter.emit_vir(name.clone(), &vir);
todo!();

self.host.insert_nets(&emitter.nets);

Expand Down
3 changes: 2 additions & 1 deletion vine/src/resolver/resolve_defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ impl<'core> VisitMut<'core, '_> for ResolveVisitor<'core, '_> {
self._visit_type(ty);
}

fn _visit_impl(&mut self, impl_: &mut Impl<'core>) {
fn visit_impl(&mut self, impl_: &mut Impl<'core>) {
if let ImplKind::Path(path) = &mut impl_.kind {
if path.generics.is_none() {
if let Some(ident) = path.path.as_ident() {
Expand All @@ -383,6 +383,7 @@ impl<'core> VisitMut<'core, '_> for ResolveVisitor<'core, '_> {
impl_.kind = ImplKind::Error(self.resolver.core.report(diag));
}
}
self._visit_impl(impl_);
}
}

Expand Down
Loading

0 comments on commit 9a20e22

Please sign in to comment.