require "socket"
require "./query"
require "./response"
require "./rr"
require "./settings"
# TODO: Cancel timer
# TODO: TCP
module AsyncDNS
class Resolver
# UDP socket used to reach IPv6 nameservers; nil until `send` first needs it.
@v6_sock : UDPSocket?
# UDP socket used to reach IPv4 nameservers; nil until `send` first needs it.
@v4_sock : UDPSocket?
# Resolver configuration. `resolve` checks its `nameservers` list and hands
# every query its own `dup` of these settings.
getter settings : Settings
# Failure reasons handed to a query's callback in place of a `Response`.
enum Error
# `resolve` was called while the settings list no nameservers.
NO_NAME_SERVER
# The reply's QR bit was clear or its opcode did not match the query's.
INVALID_REPLY
# The reply had the TC (truncated) bit set; TCP fallback is not implemented.
TRUNCATED
# The server answered with RCODE 1.
SERVER_INVALID_FORMAT
# The server answered with RCODE 2.
SERVER_FAILURE
# The server answered with RCODE 3.
SERVER_NAME_ERROR
# The server answered with RCODE 4.
SERVER_NOT_IMPLEMENTED
# The server answered with RCODE 5.
SERVER_REFUSED
# The server answered with any other non-zero RCODE.
UNKNOWN
end
# State of one in-flight query: the question, its transaction ID, the
# settings snapshot it runs under, the completion callback, and the
# pre-encoded wire-format request in `raw_data`.
private class Context
  getter query : Query
  getter id : UInt16
  getter settings : Settings
  getter block : Response | Error ->
  # Index into `settings.nameservers` of the server currently being asked.
  property ns_index : Int32
  property attempt : Int32
  # Address the request was last sent to; replies from anyone else are ignored.
  property used_ns : Socket::IPAddress | Nil
  # The encoded DNS request (header + single question).
  getter raw_data : Bytes

  # Builds the context and encodes the request packet.
  #
  # Raises `ArgumentError` if a label of `query.domain` is empty or longer
  # than 63 bytes.
  def initialize(@query : Query, @id : UInt16, @settings : Settings,
                 @block : Response | Error ->)
    @ns_index = 0
    @attempt = 0
    @used_ns = nil

    io = IO::Memory.new(512)

    # Header: ID, flags (only RD set), QDCOUNT = 1, ANCOUNT/NSCOUNT/ARCOUNT = 0.
    io.write_bytes(@id, IO::ByteFormat::BigEndian)
    io.write_bytes(0x100_u16, IO::ByteFormat::BigEndian)
    io.write_bytes(1_u16, IO::ByteFormat::BigEndian)
    3.times { io.write_bytes(0_u16, IO::ByteFormat::BigEndian) }

    # QNAME: a sequence of length-prefixed labels. A single trailing dot
    # (fully-qualified form) is dropped here because the root label is
    # always written explicitly below.
    name = @query.domain.rchop('.')
    unless name.empty?
      name.split('.').each do |component|
        # A zero-length label would be read by the server as the end of
        # the name, silently truncating the question.
        if component.empty?
          raise ArgumentError.new("Empty domain component")
        end
        if component.bytesize > 63 || io.size + component.bytesize > 512
          raise ArgumentError.new("Domain component too long")
        end
        io.write_bytes(component.bytesize.to_u8)
        io << component
      end
    end
    # Zero-length root label terminating the name.
    io.write_bytes(0_u8)

    # QTYPE
    io.write_bytes(@query.rr_type.to_u16, IO::ByteFormat::BigEndian)
    # QCLASS
    io.write_bytes(@query.dns_class.to_u16, IO::ByteFormat::BigEndian)

    @raw_data = io.to_slice
  end
end
# Creates a resolver with a freshly constructed `Settings` and empty tables
# of in-flight UDP and TCP queries.
def initialize
  @tcp_queries = {} of Socket => Context
  @queries = {} of UInt16 => Context
  @settings = Settings.new
end
# Starts resolving *query*; *block* later receives either a `Response` or an
# `Error`.
#
# Raises `ArgumentError` if the domain is longer than 253 bytes. When no
# nameserver is configured the block is called immediately with
# `Error::NO_NAME_SERVER`.
def resolve(query : Query, &block : Response | Error ->) : Nil
  # Draw random transaction IDs until one is not already in flight.
  id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16
  while @queries.has_key?(id)
    id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16
  end

  raise ArgumentError.new("Queried domain is too long") if query.domain.bytesize > 253

  if @settings.nameservers.empty?
    block.call(Error::NO_NAME_SERVER)
  else
    # Work on copies so later changes by the caller do not affect this query.
    send(Context.new(query.dup, id, settings.dup, block))
  end
end
# Sends *context*'s request to the nameserver selected by `ns_index` and
# records the context in `@queries` so the reply can be matched.
#
# The per-family UDP socket is created lazily. Raises `ArgumentError` if the
# nameserver address is neither IPv4 nor IPv6; socket errors propagate to the
# caller. The context is left unregistered whenever this method raises.
private def send(context : Context) : Nil
  ns = context.settings.nameservers[context.ns_index]
  context.used_ns = used_ns = Socket::IPAddress.new(ns, 53)

  sock = case used_ns.family
         when Socket::Family::INET6
           @v6_sock ||= open_socket(Socket::Family::INET6, "::")
         when Socket::Family::INET
           @v4_sock ||= open_socket(Socket::Family::INET, "0.0.0.0")
         else
           raise ArgumentError.new("Nameserver must be INET or INET6")
         end

  # Register only once the destination and socket are known to be usable,
  # and undo the registration if the datagram cannot be sent.
  @queries[context.id] = context
  begin
    sock.send(context.raw_data, used_ns)
  rescue ex
    @queries.delete(context.id)
    raise ex
  end
end

# Opens a UDP socket of the given *family*, binds it to *bind_address* on an
# ephemeral port and starts its receive loop. The socket is closed again if
# binding fails, so a broken socket is never returned.
private def open_socket(family : Socket::Family, bind_address : String) : UDPSocket
  # The family must be passed explicitly: `UDPSocket.new` defaults to INET,
  # which cannot be bound to an IPv6 address.
  sock = UDPSocket.new(family)
  begin
    sock.bind bind_address, 0
  rescue ex
    sock.close
    raise ex
  end
  start_recv_loop(sock)
  sock
end
# Spawns a fiber that reads datagrams from *sock* and passes each one, with
# its sender address, to `handle_packet` until the socket is closed.
def start_recv_loop(sock : UDPSocket) : Nil
  spawn do
    buffer = Bytes.new(0x10000)
    while !sock.closed?
      received = begin
        sock.receive(buffer)
      rescue ex : IO::Error
        # A close from another fiber interrupts the pending receive with an
        # error that carries no OS error code; treat that as a normal stop.
        break if sock.closed? && ex.os_error.nil?
        raise ex
      end
      size, sender = received
      handle_packet(buffer[0, size], sender)
    end
  end
end
# Processes one received datagram.
#
# The packet is matched to an in-flight query by transaction ID and sender
# address; unmatched packets are ignored. A matched reply is removed from
# `@queries` and then either reported to the query's callback as an `Error`
# or, for RCODEs 2, 4, 5 and unknown codes, retried against the next
# configured nameserver when one remains.
#
# NOTE: the success path (RCODE 0) currently stops after reading the section
# counts; the callback is not yet invoked with a `Response`.
def handle_packet(packet : Bytes, sender : Socket::IPAddress) : Nil
  # Too short to even carry a transaction ID: nothing to match it against.
  return if packet.size < 2

  id = packet[0].to_u16 << 8 | packet[1]
  context = @queries[id]?
  return if context.nil?
  # Ignore replies that do not come from the server that was asked.
  return if sender != context.used_ns
  @queries.delete(id)

  # A DNS header is 12 bytes. The query has already been removed, so a short
  # reply must be reported; otherwise the callback would never fire.
  if packet.size < 12
    context.block.call(Error::INVALID_REPLY)
    return
  end

  if packet[2] & 0x80 == 0 ||                       # QR
     packet[2] & 0x78 != context.raw_data[2] & 0x78 # Opcode
    context.block.call(Error::INVALID_REPLY)
    return
  end

  if packet[2] & 0x02 != 0 # TC
    # TODO: Switch to TCP
    context.block.call(Error::TRUNCATED)
    return
  end

  case packet[3] & 0x0F # RCODE
  when 0
    error = nil
    try_next_ns = false
  when 1
    error = Error::SERVER_INVALID_FORMAT
    try_next_ns = false
  when 2
    error = Error::SERVER_FAILURE
    try_next_ns = true
  when 3
    error = Error::SERVER_NAME_ERROR
    try_next_ns = false
  when 4
    error = Error::SERVER_NOT_IMPLEMENTED
    try_next_ns = true
  when 5
    error = Error::SERVER_REFUSED
    try_next_ns = true
  else
    error = Error::UNKNOWN
    try_next_ns = true
  end

  # Server-side problems may be specific to one nameserver, so move on to
  # the next one while any remain.
  if try_next_ns && context.ns_index + 1 < context.settings.nameservers.size
    context.ns_index += 1
    send(context)
    return
  end

  if error
    context.block.call(error)
    return
  end

  # Section counts follow the ID and flags words.
  io = IO::Memory.new(packet[4, packet.size - 4], writeable: false)
  qdcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
  ancount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
  nscount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
  arcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
  # TODO: parse the sections and invoke context.block with a Response.
end
# Closes both UDP sockets, which ends their receive loops.
#
# The cached references are cleared as well, so that a later `resolve`
# opens fresh sockets instead of trying to send on closed ones.
def stop : Nil
  @v6_sock.try &.close
  @v6_sock = nil
  @v4_sock.try &.close
  @v4_sock = nil
end
end
end