Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement user-defined type guard method #1501

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion lib/steep/ast/types/logic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Types
module Logic
class Base
extend SharedInstance

def subst(s)
self
end
Expand Down Expand Up @@ -53,6 +53,28 @@ class ArgEqualsReceiver < Base
class ArgIsAncestor < Base
end

class Guard < Base
PATTERN = /\Aguard:\s*(self)\s+(is)\s+(\S+)\s*\Z/

attr_reader :subject
attr_reader :operator
attr_reader :type

def initialize(subject:, operator:, type:)
@subject = subject
@operator = operator
@type = type
end

def ==(other)
super && subject == other.subject && operator == other.operator && type == other.type
end

def hash
self.class.hash ^ subject.hash ^ operator.hash ^ type.hash
end
end

class Env < Base
attr_reader :truthy, :falsy, :type

Expand Down
36 changes: 36 additions & 0 deletions lib/steep/interface/builder.rb
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def singleton_shape(type_name)
overloads = method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
method_type = factory.method_type(type_def.type)
method_type = replace_guard_method(definition, type_def, method_type)
method_type = replace_primitive_method(method_name, type_def, method_type)
method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Builtin::Class.instance_type }
method_type = add_implicitly_returns_nil(type_def.annotations, method_type)
Expand Down Expand Up @@ -313,6 +314,7 @@ def object_shape(type_name)
overloads = method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
method_type = factory.method_type(type_def.type)
method_type = replace_guard_method(definition, type_def, method_type)
method_type = replace_primitive_method(method_name, type_def, method_type)
if type_name.class?
method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Types::Name::Singleton.new(name: type_name) }
Expand Down Expand Up @@ -725,6 +727,40 @@ def proc_shape(proc, proc_shape)
shape
end

def replace_guard_method(definition, method_def, method_type)
match = method_def.annotations.filter_map { AST::Types::Logic::Guard::PATTERN.match(_1.string) }.first
if match
subject = match[1] or raise
operator = match[2] or raise
type_name = match[3] or raise

type = RBS::Parser.parse_type(type_name)
raise "Unknown type: #{type_name}" unless type

context = context_from(definition.type_name)
type = type.map_type_name { factory.absolute_type_name(_1, context: context) or raise "Unknown type: #{_1}" }
guard = AST::Types::Logic::Guard.new(subject: subject, operator: operator, type: type)
definition.type_name
method_type.with(
type: method_type.type.with(return_type: guard)
)
else
method_type
end
rescue => exn
Steep.logger.error { exn.message }
method_type
end

def context_from(type_name)
if type_name.namespace == RBS::Namespace.root
[nil, type_name]
else
parent = context_from(type_name.namespace.to_type_name)
[parent, type_name]
end
end
Comment on lines +755 to +762
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not find such a helper method in Steep and RBS.
I'm not sure where the best place is.


def replace_primitive_method(method_name, method_def, method_type)
defined_in = method_def.defined_in
member = method_def.member
Expand Down
52 changes: 47 additions & 5 deletions lib/steep/type_inference/logic_type_interpreter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,39 @@ def evaluate_method_call(env:, type:, receiver:, arguments:)
truthy_result.update_type { FALSE }
]
end

when AST::Types::Logic::Guard
if receiver
receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver)) || raise

# TODO: Expand the type params to the actual types
# TODO: Support argument types (ex. `self is arg1`)
# TODO: Support class' type param types (ex. `self is T`)
# TODO: Support method's type param types (ex. `self is T`)
# TODO: Support is_a operator
# TODO: Ensure the type exists

if no_subtyping?(sub_type: factory.type(type.type), super_type: receiver_type)
Steep.logger.error { "Type guard failed: #{receiver_type} is not a subtype of #{type.type}" }
return nil
end

truthy_type, falsy_type = type_guard_type_case_select(receiver_type, type.type)
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_type || receiver_type,
falsy_type: falsy_type || UNTYPED
)

truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! unless truthy_type

falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless falsy_type

[truthy_result, falsy_result]
end
end
end

Expand Down Expand Up @@ -494,25 +527,34 @@ def literal_var_type_case_select(value_node, arg_type)
end
end

def type_case_select(type, klass)
truth_types, false_types = type_case_select0(type, klass)
def type_guard_type_case_select(type, guard_type)
instance_type = factory.type(guard_type)
truth_types, false_types = type_case_select0(type, instance_type)

[
truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types),
false_types.empty? ? nil : AST::Types::Union.build(types: false_types)
]
end

def type_case_select0(type, klass)
def type_case_select(type, klass)
instance_type = factory.instance_type(klass)
truth_types, false_types = type_case_select0(type, instance_type)

[
truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types),
false_types.empty? ? nil : AST::Types::Union.build(types: false_types)
]
end

def type_case_select0(type, instance_type)
case type
when AST::Types::Union
truthy_types = [] # :Array[AST::Types::t]
falsy_types = [] #: Array[AST::Types::t]

type.types.each do |ty|
truths, falses = type_case_select0(ty, klass)
truths, falses = type_case_select0(ty, instance_type)

if truths.empty?
falsy_types.push(ty)
Expand All @@ -529,7 +571,7 @@ def type_case_select0(type, klass)
if ty == type
[[type], [type]]
else
type_case_select0(ty, klass)
type_case_select0(ty, instance_type)
end

when AST::Types::Any, AST::Types::Top, AST::Types::Var
Expand Down
2 changes: 1 addition & 1 deletion sig/steep/ast/types.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module Steep
| Intersection | Record | Tuple | Union
| Name::Alias | Name::Instance | Name::Interface | Name::Singleton
| Proc | Var
| Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Env
| Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Guard | Logic::Env

# Variables and special types that is subject for substitution
#
Expand Down
13 changes: 13 additions & 0 deletions sig/steep/ast/types/logic.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ module Steep
class ArgIsAncestor < Base
end

# A type for type guard.
class Guard < Base
PATTERN: Regexp

attr_reader subject: String
attr_reader operator: String
attr_reader type: RBS::Types::t

def self.new: (subject: String, operator: String, type: RBS::Types::t) -> Guard

def initialize: (subject: String, operator: String, type: RBS::Types::t) -> void
end

# A type with truthy/falsy type environment.
class Env < Base
attr_reader truthy: TypeInference::TypeEnv
Expand Down
4 changes: 4 additions & 0 deletions sig/steep/interface/builder.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ module Steep

def method_name_for: (RBS::Definition::Method::TypeDef, Symbol name) -> method_name

def replace_guard_method: (RBS::Definition, RBS::Definition::Method::TypeDef, MethodType) -> MethodType

def context_from: (RBS::TypeName) -> RBS::Resolver::context

def replace_primitive_method: (method_name, RBS::Definition::Method::TypeDef, MethodType) -> MethodType

def replace_kernel_class: (method_name, RBS::Definition::Method::TypeDef, MethodType) { () -> AST::Types::t } -> MethodType
Expand Down
4 changes: 3 additions & 1 deletion sig/steep/type_inference/logic_type_interpreter.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ module Steep
#
def literal_var_type_case_select: (Parser::AST::Node value_node, AST::Types::t arg_type) -> [Array[AST::Types::t], Array[AST::Types::t]]?

def type_guard_type_case_select: (AST::Types::t `type`, RBS::Types::t guard_type) -> [AST::Types::t?, AST::Types::t?]

def type_case_select: (AST::Types::t `type`, RBS::TypeName klass) -> [AST::Types::t?, AST::Types::t?]

def type_case_select0: (AST::Types::t `type`, RBS::TypeName klass) -> [Array[AST::Types::t], Array[AST::Types::t]]
def type_case_select0: (AST::Types::t `type`, AST::Types::t instance_type) -> [Array[AST::Types::t], Array[AST::Types::t]]

def try_convert: (AST::Types::t, Symbol) -> AST::Types::t?

Expand Down
8 changes: 8 additions & 0 deletions sig/test/type_check_test.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ class TypeCheckTest < Minitest::Test

def test_type_narrowing__local_variable_safe_navigation_operator: () -> untyped

def test_type_guard__self_is_TYPE: () -> untyped

def test_type_guard__self_is_TYPE_no_subtyping: () -> untyped

def test_type_guard__self_is_TYPE_unknown: () -> untyped

def test_type_guard__self_is_TYPE_singleton: () -> untyped

def test_argument_error__unexpected_unexpected_positional_argument: () -> untyped

def test_type_assertion__type_error: () -> untyped
Expand Down
Loading