Skip to content

Commit

Permalink
feat: add kind to local scope. Add kind to local scope and use to dis…
Browse files Browse the repository at this point in the history
…tinguish different scope types. It is also used to distinguish lvalues and rvalues in schema expr to provide more accurate completion in lsp.

Signed-off-by: he1pa <18012015693@163.com>
  • Loading branch information
He1pa committed Jan 3, 2024
1 parent b82e786 commit 9204199
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 21 deletions.
41 changes: 37 additions & 4 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::{
core::{
global_state::GlobalState,
package::ModuleInfo,
scope::{LocalSymbolScope, RootSymbolScope, ScopeKind, ScopeRef},
scope::{LocalSymbolScope, LocalSymbolScopeKind, RootSymbolScope, ScopeKind, ScopeRef},
symbol::SymbolRef,
},
ty::TypeRef,
Expand Down Expand Up @@ -146,9 +146,15 @@ impl<'ctx> AdvancedResolver<'ctx> {
self.ctx.scopes.push(scope_ref);
}

fn enter_local_scope(&mut self, filepath: &str, start: Position, end: Position) {
fn enter_local_scope(
&mut self,
filepath: &str,
start: Position,
end: Position,
kind: LocalSymbolScopeKind,
) {
let parent = *self.ctx.scopes.last().unwrap();
let local_scope = LocalSymbolScope::new(parent, start, end);
let local_scope = LocalSymbolScope::new(parent, start, end, kind);
let scope_ref = self.gs.get_scopes_mut().alloc_local_scope(local_scope);

match parent.get_kind() {
Expand Down Expand Up @@ -1306,7 +1312,16 @@ mod tests {
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
30,
41,
6,
6,
),
// __main__.Main schema config entry value scope
(
"src/advanced_resolver/test_data/schema_symbols.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
30,
20,
10,
),
// pkg.Person schema expr scope
Expand All @@ -1316,6 +1331,15 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
33,
21,
1,
),
// pkg.Person schema config entry value scope
(
"src/advanced_resolver/test_data/schema_symbols.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
34,
17,
6,
),
// __main__ package scope
Expand Down Expand Up @@ -1343,6 +1367,15 @@ mod tests {
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
12,
5,
2,
),
// import_test.a.Name config entry value scope
(
"src/advanced_resolver/test_data/import_test/a.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string()),
12,
21,
8,
),
];
Expand Down
53 changes: 47 additions & 6 deletions kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use kclvm_ast::walker::MutSelfTypedResultWalker;
use kclvm_error::{diagnostic::Range, Position};

use crate::{
core::symbol::{KCLSymbolSemanticInfo, SymbolRef, UnresolvedSymbol, ValueSymbol},
core::{
scope::LocalSymbolScopeKind,
symbol::{KCLSymbolSemanticInfo, SymbolRef, UnresolvedSymbol, ValueSymbol},
},
ty::{Type, SCHEMA_MEMBER_FUNCTIONS},
};

Expand Down Expand Up @@ -198,7 +201,12 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
};
}

self.enter_local_scope(&self.ctx.current_filename.clone().unwrap(), start, end);
self.enter_local_scope(
&self.ctx.current_filename.clone().unwrap(),
start,
end,
LocalSymbolScopeKind::SchemaDef,
);
let cur_scope = *self.ctx.scopes.last().unwrap();
self.gs
.get_scopes_mut()
Expand Down Expand Up @@ -337,6 +345,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
&self.ctx.current_filename.as_ref().unwrap().clone(),
start,
end,
LocalSymbolScopeKind::Quant,
);
let cur_scope = *self.ctx.scopes.last().unwrap();
for target in quant_expr.variables.iter() {
Expand Down Expand Up @@ -503,7 +512,12 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
Some(last) => last.get_end_pos(),
None => list_comp.elt.get_end_pos(),
};
self.enter_local_scope(&self.ctx.current_filename.clone().unwrap(), start, end);
self.enter_local_scope(
&self.ctx.current_filename.clone().unwrap(),
start,
end,
LocalSymbolScopeKind::List,
);
for comp_clause in &list_comp.generators {
self.walk_comp_clause(&comp_clause.node);
}
Expand All @@ -519,7 +533,12 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
Some(last) => last.get_end_pos(),
None => dict_comp.entry.value.get_end_pos(),
};
self.enter_local_scope(&self.ctx.current_filename.clone().unwrap(), start, end);
self.enter_local_scope(
&self.ctx.current_filename.clone().unwrap(),
start,
end,
LocalSymbolScopeKind::Dict,
);
for comp_clause in &dict_comp.generators {
self.walk_comp_clause(&comp_clause.node);
}
Expand Down Expand Up @@ -603,7 +622,12 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {

fn walk_lambda_expr(&mut self, lambda_expr: &'ctx ast::LambdaExpr) -> Self::Result {
let (start, end) = (self.ctx.start_pos.clone(), self.ctx.end_pos.clone());
self.enter_local_scope(&self.ctx.current_filename.clone().unwrap(), start, end);
self.enter_local_scope(
&self.ctx.current_filename.clone().unwrap(),
start,
end,
LocalSymbolScopeKind::Lambda,
);
if let Some(args) = &lambda_expr.args {
self.walk_arguments(&args.node);
}
Expand Down Expand Up @@ -984,12 +1008,19 @@ impl<'ctx> AdvancedResolver<'ctx> {
pub(crate) fn walk_config_entries(&mut self, entries: &'ctx [ast::NodeRef<ast::ConfigEntry>]) {
let (start, end) = (self.ctx.start_pos.clone(), self.ctx.end_pos.clone());

let schema_symbol = self.ctx.current_schema_symbol.take();
let kind = match &schema_symbol {
Some(_) => LocalSymbolScopeKind::SchemaConfig,
None => LocalSymbolScopeKind::Common,
};

self.enter_local_scope(
&self.ctx.current_filename.as_ref().unwrap().clone(),
start,
end,
kind,
);
let schema_symbol = self.ctx.current_schema_symbol.take();

if let Some(owner) = schema_symbol {
let cur_scope = self.ctx.scopes.last().unwrap();
self.gs
Expand All @@ -1005,7 +1036,17 @@ impl<'ctx> AdvancedResolver<'ctx> {
}
self.ctx.maybe_def = false;
}

let (start, end) = entry.node.value.get_span_pos();
self.enter_local_scope(
&self.ctx.current_filename.as_ref().unwrap().clone(),
start,
end,
LocalSymbolScopeKind::Common,
);

self.expr(&entry.node.value);
self.leave_scope();
}
self.leave_scope()
}
Expand Down
1 change: 1 addition & 0 deletions kclvm/sema/src/core/global_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ impl GlobalState {
scopes,
&self.symbols,
self.packages.get_module_info(scope.get_filename()),
false,
)
.values()
.into_iter()
Expand Down
59 changes: 49 additions & 10 deletions kclvm/sema/src/core/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub trait Scope {
fn get_children(&self) -> Vec<ScopeRef>;

fn contains_pos(&self, pos: &Position) -> bool;
fn get_range(&self) -> Option<(Position, Position)>;

fn get_owner(&self) -> Option<SymbolRef>;
fn look_up_def(
Expand All @@ -29,6 +30,7 @@ pub trait Scope {
scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
recursive: bool,
) -> HashMap<String, SymbolRef>;

fn dump(&self, scope_data: &ScopeData, symbol_data: &Self::SymbolData) -> Option<String>;
Expand Down Expand Up @@ -202,6 +204,7 @@ impl Scope for RootSymbolScope {
_scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
_recursive: bool,
) -> HashMap<String, SymbolRef> {
let mut all_defs_map = HashMap::new();
if let Some(owner) = symbol_data.get_symbol(self.owner) {
Expand Down Expand Up @@ -254,6 +257,10 @@ impl Scope for RootSymbolScope {
let val: serde_json::Value = serde_json::from_str(&output).unwrap();
Some(serde_json::to_string_pretty(&val).ok()?)
}

fn get_range(&self) -> Option<(Position, Position)> {
None
}
}

impl RootSymbolScope {
Expand Down Expand Up @@ -293,6 +300,19 @@ pub struct LocalSymbolScope {

pub(crate) start: Position,
pub(crate) end: Position,
pub(crate) kind: LocalSymbolScopeKind,
}

#[allow(unused)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalSymbolScopeKind {
List,
Dict,
Quant,
Lambda,
SchemaDef,
SchemaConfig,
Common,
}

impl Scope for LocalSymbolScope {
Expand Down Expand Up @@ -349,13 +369,9 @@ impl Scope for LocalSymbolScope {
scope_data: &ScopeData,
symbol_data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
recursive: bool,
) -> HashMap<String, SymbolRef> {
let mut all_defs_map = HashMap::new();
for def_ref in self.defs.values() {
if let Some(def) = symbol_data.get_symbol(*def_ref) {
all_defs_map.insert(def.get_name(), *def_ref);
}
}
if let Some(owner) = self.owner {
if let Some(owner) = symbol_data.get_symbol(owner) {
for def_ref in owner.get_all_attributes(symbol_data, module_info) {
Expand All @@ -368,10 +384,23 @@ impl Scope for LocalSymbolScope {
}
}
}
if let Some(parent) = scope_data.get_scope(self.parent) {
for (name, def_ref) in parent.get_all_defs(scope_data, symbol_data, module_info) {
if !all_defs_map.contains_key(&name) {
all_defs_map.insert(name, def_ref);

if self.kind == LocalSymbolScopeKind::SchemaConfig && !recursive {
return all_defs_map;
} else {
for def_ref in self.defs.values() {
if let Some(def) = symbol_data.get_symbol(*def_ref) {
all_defs_map.insert(def.get_name(), *def_ref);
}
}

if let Some(parent) = scope_data.get_scope(self.parent) {
for (name, def_ref) in
parent.get_all_defs(scope_data, symbol_data, module_info, true)
{
if !all_defs_map.contains_key(&name) {
all_defs_map.insert(name, def_ref);
}
}
}
}
Expand Down Expand Up @@ -430,10 +459,19 @@ impl Scope for LocalSymbolScope {
output.push_str("\n]\n}");
Some(output)
}

fn get_range(&self) -> Option<(Position, Position)> {
Some((self.start.clone(), self.end.clone()))
}
}

impl LocalSymbolScope {
pub fn new(parent: ScopeRef, start: Position, end: Position) -> Self {
pub fn new(
parent: ScopeRef,
start: Position,
end: Position,
kind: LocalSymbolScopeKind,
) -> Self {
Self {
parent,
owner: None,
Expand All @@ -442,6 +480,7 @@ impl LocalSymbolScope {
refs: vec![],
start,
end,
kind,
}
}

Expand Down
2 changes: 1 addition & 1 deletion kclvm/tools/src/LSP/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ mod tests {
CompletionResponse::List(_) => panic!("test failed"),
};

expected_labels.extend(["name", "age"]);
expected_labels = vec!["", "age", "math", "name", "subpkg"];
got_labels.sort();
expected_labels.sort();
assert_eq!(got_labels, expected_labels);
Expand Down

0 comments on commit 9204199

Please sign in to comment.