Review notes for this patch (free text before the first "Index:" line is
normally skipped by patch tools, like a commit message).

Summary:
- src/query.cr: Query#initialize now appends '.' to the domain unless it
  already ends with one.
- src/resolver.cr: the QNAME encoder now writes a length byte before each
  label, Resolver::Error gains INVALID_REPLY, TRUNCATED and per-RCODE
  variants, and handle_packet starts validating replies. The checks are:
  reply ID is a pending query, sender matches the nameserver used, QR is
  set, Opcode matches the request, and TC is clear. It then maps RCODE to
  an error.
- For RCODE 2, 4, 5 or an unknown RCODE, handle_packet retries with the
  next configured nameserver, if there is one.

NOTE(review): this patch has been flattened onto a few physical lines, so
the hunks cannot be applied as-is. The original line breaks, including
blank context lines, must be restored before applying.

NOTE(review) src/query.cr:
- The appended trailing '.' makes split('.') in the encoder yield a final
  empty label. That label is written as a zero length byte, i.e. the QNAME
  root terminator.
- For an input of "" or ".", the stored domain is "." and split('.')
  appears to yield two empty labels, i.e. two zero bytes. Confirm how root
  queries should be encoded.

NOTE(review) Resolver#handle_packet:
- The first `io = IO::Memory.new(packet, writeable: false)` is overwritten
  before it is ever read.
- `begin context = @queries[id] rescue KeyError return end` could be
  written as `context = @queries[id]? || return`.
- The context is removed from @queries before packet[2] and packet[3] are
  read. A 3-byte reply raises IndexError there, the outer
  `rescue IndexError` swallows it, and the query's block is never called.
- The QDCOUNT..ARCOUNT fields are read with IO#read_bytes. That raises
  IO::EOFError (per the Crystal IO API), not IndexError, on a reply
  shorter than 12 bytes.
- handle_packet is called outside the begin/rescue that wraps
  sock.receive, so that IO::EOFError propagates out of the receive loop.
- `adcount` presumably means ANCOUNT.
- qdcount/adcount/nscount/arcount are not used yet. For RCODE 0 the
  method returns without calling context.block, so answer parsing and
  delivery are still missing.

Index: src/query.cr ================================================================== --- src/query.cr +++ src/query.cr @@ -5,9 +5,11 @@ class Query getter domain : String getter dns_class : DNSClass getter rr_type : RRType - def initialize(@domain : String, @dns_class : DNSClass, @rr_type : RRType) + def initialize(domain : String, @dns_class : DNSClass, @rr_type : RRType) + domain += '.' unless domain.ends_with?('.') + @domain = domain end end end Index: src/resolver.cr ================================================================== --- src/resolver.cr +++ src/resolver.cr @@ -2,10 +2,13 @@ require "./query" require "./response" require "./rr" require "./settings" + +# TODO: Cancel timer +# TODO: TCP module AsyncDNS class Resolver @v6_sock : UDPSocket? @v4_sock : UDPSocket? @@ -12,10 +15,18 @@ getter settings : Settings enum Error NO_NAME_SERVER + INVALID_REPLY + TRUNCATED + SERVER_INVALID_FORMAT + SERVER_FAILURE + SERVER_NAME_ERROR + SERVER_NOT_IMPLEMENTED + SERVER_REFUSED + UNKNOWN end private class Context getter query : Query getter id : UInt16 @@ -49,10 +60,11 @@ @query.domain.split('.').each do |component| if component.bytesize > 63 || io.size + component.bytesize > 512 raise ArgumentError.new("Domain component too long") end + io.write_bytes(component.bytesize.to_u8) io << component end # QTYPE io.write_bytes(@query.rr_type.to_u16, IO::ByteFormat::BigEndian) @@ -68,13 +80,12 @@ @queries = Hash(UInt16, Context).new @tcp_queries = Hash(Socket, Context).new end def resolve(query : Query, &block : Response | Error ->) : Nil - id : UInt16 while true - id = Random::Secure.rand(UInt16::MIN..UInt16::MAX) + id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16 break unless @queries.has_key?(id) end if query.domain.bytesize > 253 raise ArgumentError.new("Queried domain is too long") @@ -92,11 +103,10 @@ @queries[context.id] = context ns = context.settings.nameservers[context.ns_index] context.used_ns = used_ns = Socket::IPAddress.new(ns, 53) - sock : 
UDPSocket case used_ns.family when Socket::Family::INET6 if @v6_sock.nil? @v6_sock = s = UDPSocket.new s.bind "::", 0 @@ -122,24 +132,93 @@ until sock.closed? begin len, addr = sock.receive(buf) rescue ex : IO::Error - p ex break if sock.closed? && ex.os_error.nil? raise ex end handle_packet(buf[0, len], addr) end end end - def handle_packet(packet : Bytes, sender : Socket::IPAddress) - p packet + def handle_packet(packet : Bytes, sender : Socket::IPAddress) : Nil + io = IO::Memory.new(packet, writeable: false) + + begin + id = packet[0].to_u16 << 8 | packet[1] + + begin + context = @queries[id] + rescue KeyError + return + end + + return if sender != context.used_ns + + @queries.delete(id) + + if packet[2] & 0x80 == 0 || # QR + packet[2] & 0x78 != context.raw_data[2] & 0x78 # Opcode + context.block.call(Error::INVALID_REPLY) + return + end + + if packet[2] & 0x02 != 0 # TC + # TODO: Switch to TCP + context.block.call(Error::TRUNCATED) + return + end + + case packet[3] & 0x0F # RCODE + when 0 + error = nil + try_next_ns = false + when 1 + error = Error::SERVER_INVALID_FORMAT + try_next_ns = false + when 2 + error = Error::SERVER_FAILURE + try_next_ns = true + when 3 + error = Error::SERVER_NAME_ERROR + try_next_ns = false + when 4 + error = Error::SERVER_NOT_IMPLEMENTED + try_next_ns = true + when 5 + error = Error::SERVER_REFUSED + try_next_ns = true + else + error = Error::UNKNOWN + try_next_ns = true + end + + if try_next_ns + if context.ns_index + 1 < context.settings.nameservers.size + context.ns_index += 1 + send(context) + return + end + end + + if error + context.block.call(error) + return + end + + io = IO::Memory.new(packet[4, packet.size - 4], writeable: false) + qdcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian) + adcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian) + nscount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian) + arcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian) + rescue IndexError + end end def stop : Nil 
@v6_sock.try { |s| s.close } @v4_sock.try { |s| s.close } end end end