简体   繁体   English

在 Python3 中创建一个带有网站阻止的 DNS 服务器?

[英]Create a DNS Server in Python3 with website blocking?

Can I create a DNS server in Python 3 that blocks or redirects certain website URLs? 我可以使用阻止/重定向某些网站 URL 的 python3 创建一个 DNS 服务器吗? I have tried using dnslib, but I can't seem to get it to block URLs. 我曾尝试使用 dnslib,但似乎无法阻止 URL。

I am attempting to create a program which can block websites at DNS level, and I am aware of pi-hole but I would like to create my own.我正在尝试创建一个可以在 DNS 级别阻止网站的程序,我知道 pi-hole 但我想创建自己的。

My code so far (modified example code):到目前为止我的代码(修改后的示例代码):

# -*- coding: utf-8 -*-

"""
    InterceptResolver - proxy requests to upstream server
                        (optionally intercepting)
"""
from __future__ import print_function

import binascii,copy,socket,struct,sys

from dnslib import DNSRecord,RR,QTYPE,RCODE,parse_time
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
from dnslib.label import DNSLabel

class InterceptResolver(BaseResolver):

    """
        Resolver that answers selected queries from locally configured
        records and proxies everything else to an upstream DNS server.
    """

    def __init__(self,address,port,ttl,intercept,skip,nxdomain,forward,all_qtypes,timeout=0):
        """
            address/port    - upstream server used when nothing else applies
            ttl             - default ttl applied to intercept records
            intercept       - list of wildcard RRs (zone format) answered
                              locally; an entry of '-' reads the zone text
                              from stdin
            skip            - list of wildcard labels never intercepted
            nxdomain        - list of wildcard labels answered with NXDOMAIN
            forward         - list of '<label>:<server>[:<port>]' entries;
                              the port defaults to 53 when omitted
            all_qtypes      - when set, a qname match is returned even if no
                              record of the requested qtype was added
            timeout         - timeout passed to the upstream request
        """
        self.address = address
        self.port = port
        self.ttl = parse_time(ttl)
        self.skip = skip
        self.nxdomain = nxdomain

        def parse_forward(spec):
            # "<label>:<server>[:<port>]" -> (label, server, port)
            label, _, target = spec.partition(':')
            host, _, port_text = target.partition(':')
            return (label, host, int(port_text or '53'))

        self.forward = [parse_forward(spec) for spec in forward]
        self.all_qtypes = all_qtypes
        self.timeout = timeout
        self.zone = []
        for entry in intercept:
            zone_text = sys.stdin.read() if entry == '-' else entry
            self.zone.extend((rr.rname,QTYPE[rr.rtype],rr)
                             for rr in RR.fromZone(zone_text,ttl=self.ttl))

    def resolve(self,request,handler):
        """
            Build the reply for 'request'.

            Order of checks: local intercept records (unless the qname is
            on the skip list), then the NXDOMAIN list, then the all_qtypes
            short-circuit, and finally a proxied lookup (through a matching
            forwarder if any, otherwise the default upstream) when no local
            answer was added.
        """
        reply = request.reply()
        qname = request.q.qname
        qtype = QTYPE[request.q.qtype]

        def matches_any(globs):
            return any(qname.matchGlob(g) for g in globs)

        # The skip list guards both the local zone and the forwarders
        skipped = matches_any(self.skip)

        matched = False
        if not skipped:
            for name,rtype,rr in self.zone:
                if not qname.matchGlob(name):
                    continue
                matched = True
                if qtype in (rtype,'ANY','CNAME'):
                    # Copy so the stored record keeps its wildcard name
                    answer = copy.copy(rr)
                    answer.rname = qname
                    reply.add_answer(answer)

        if matches_any(self.nxdomain):
            reply.header.rcode = RCODE.NXDOMAIN
            return reply
        if matched and self.all_qtypes:
            return reply

        # Pick the server to proxy to; the last matching forwarder wins
        upstream, upstream_port = self.address,self.port
        if not skipped:
            for name, ip, port in self.forward:
                if qname.matchGlob(name):
                    upstream, upstream_port = ip, port

        if reply.rr:
            return reply

        try:
            # Use the same transport towards upstream as the client used
            proxy_r = request.send(upstream,upstream_port,
                            tcp=(handler.protocol != 'udp'),
                            timeout=self.timeout)
            reply = DNSRecord.parse(proxy_r)
        except socket.timeout:
            reply.header.rcode = RCODE.SERVFAIL
        return reply

# Script entry point: parse the command line, build the InterceptResolver,
# start a UDP server (and a TCP server while tcpEnabled is True) on the same
# address/port, then loop for as long as the UDP server thread is alive.
if __name__ == '__main__':

    import argparse,sys,time

    p = argparse.ArgumentParser(description="DNS Intercept Proxy")
    p.add_argument("--port","-p",type=int,default=53,
                    metavar="<port>",
                    help="Local proxy port (default:53)")
    p.add_argument("--address","-a",default="",
                    metavar="<address>",
                    help="Local proxy listen address (default:all)")
    p.add_argument("--upstream","-u",default="8.8.8.8:53",
            metavar="<dns server:port>",
                    help="Upstream DNS server:port (default:8.8.8.8:53)")
    p.add_argument("--intercept","-i",action="append",
                    metavar="<zone record>",
                    help="Intercept requests matching zone record (glob) ('-' for stdin)")
    p.add_argument("--skip","-s",action="append",
                    metavar="<label>",
                    help="Don't intercept matching label (glob)")
    p.add_argument("--nxdomain","-x",action="append",
                    metavar="<label>",
                    help="Return NXDOMAIN (glob)")
    p.add_argument("--forward","-f",action="append",
                   metavar="<label:dns server:port>",
                   help="forward requests matching label (glob) to dns server")
    p.add_argument("--ttl","-t",default="60s",
                    metavar="<ttl>",
                    help="Intercept TTL (default: 60s)")
    p.add_argument("--timeout","-o",type=float,default=5,
                    metavar="<timeout>",
                    help="Upstream timeout (default: 5s)")
    p.add_argument("--all-qtypes",action='store_true',default=False,
                   help="Return an empty response if qname matches, but qtype doesn't")
    p.add_argument("--log",default="request,reply,truncated,error",
                    help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)")
    p.add_argument("--log-prefix",action='store_true',default=False,
                    help="Log prefix (timestamp/handler/resolver) (default: False)")
    args = p.parse_args()

    # Split "<server>[:<port>]"; the port falls back to 53 when omitted
    args.dns,_,args.dns_port = args.upstream.partition(':')
    args.dns_port = int(args.dns_port or 53)
    # Hard-coded switch: the TCP listener is always started alongside UDP
    tcpEnabled = True
    
    # The "append" options are None when not given, hence the "or []"
    resolver = InterceptResolver(args.dns,
                                 args.dns_port,
                                 args.ttl,
                                 args.intercept or [],
                                 args.skip or [],
                                 args.nxdomain or [],
                                 args.forward or [],
                                 args.all_qtypes,
                                 args.timeout)
    logger = DNSLogger(args.log,args.log_prefix)

    print("Starting Intercept Proxy (%s:%d -> %s:%d) [%s]" % (
                        args.address or "*",args.port,
                        args.dns,args.dns_port,
                        "UDP/TCP" if tcpEnabled else "UDP"))

    # Echo the effective configuration held by the resolver
    for rr in resolver.zone:
        print("    | ",rr[2].toZone(),sep="")
    if resolver.nxdomain:
        print("    NXDOMAIN:",", ".join(resolver.nxdomain))
    if resolver.skip:
        print("    Skipping:",", ".join(resolver.skip))
    if resolver.forward:
        print("    Forwarding:")
        for i in resolver.forward:
            print("    | ","%s:%s:%s" % i,sep="")

    # With every entry commented out this assigns an empty dict to
    # DNSHandler.log.
    # NOTE(review): nothing else in this script reads DNSHandler.log -
    # confirm whether the assignment is still needed.
    DNSHandler.log = {
        #'log_request',       # DNS Request
        #'log_reply',        # DNS Response
        #'log_truncated',    # Truncated
        #'log_error',        # Decoding error
    }

    udp_server = DNSServer(resolver,
                           port=args.port,
                           address=args.address,
                           logger=logger)
    udp_server.start_thread()

    if tcpEnabled:
        tcp_server = DNSServer(resolver,
                               port=args.port,
                               address=args.address,
                               tcp=True,
                               logger=logger)
        tcp_server.start_thread()

    # Keep the main thread running while the UDP server thread is alive.
    # Each pass prints the logger's whole currentLog list, which is never
    # cleared here, so earlier entries are printed again every second.
    while udp_server.isAlive():
        time.sleep(1)
        #print(DNSHandler.log)
        #print("LOG START")
        print(logger.currentLog)
        #print("LOG END")

I have also modified the server.py file slightly in the dnslib library:我还在 dnslib 库中稍微修改了 server.py 文件:

# -*- coding: utf-8 -*-

"""
    DNS server framework - intended to simplify creation of custom resolvers.

    Comprises the following components:

        DNSServer   - socketserver wrapper (in most cases you should just
                      need to pass this an appropriate resolver instance
                      and start in either foreground/background)

        DNSHandler  - handler instantiated by DNSServer to handle requests
                      The 'handle' method deals with the sending/receiving
                      packets (handling TCP length prefix) and delegates
                      the protocol handling to 'get_reply'. This decodes
                      packet, hands off a DNSRecord to the Resolver instance,
                      and encodes the returned DNSRecord.

                      In most cases you dont need to change DNSHandler unless
                      you need to get hold of the raw protocol data in the
                      Resolver

        DNSLogger   - The class provides a default set of logging functions for
                      the various stages of the request handled by a DNSServer
                      instance which are enabled/disabled by flags in the 'log'
                      class variable.

        Resolver    - Instance implementing a 'resolve' method that receives
                      the decodes request packet and returns a response.

                      To implement a custom resolver in most cases all you need
                      is to implement this interface.

                      Note that there is only a single instance of the Resolver
                      so need to be careful about thread-safety and blocking

        The following examples use the server framework:

            fixedresolver.py    - Simple resolver which will respond to all
                                  requests with a fixed response
            zoneresolver.py     - Resolver which will take a standard zone
                                  file input
            shellresolver.py    - Example of a dynamic resolver
            proxy.py            - DNS proxy
            intercept.py        - Intercepting DNS proxy

        >>> resolver = BaseResolver()
        >>> logger = DNSLogger(prefix=False)
        >>> server = DNSServer(resolver,port=8053,address="localhost",logger=logger)
        >>> server.start_thread()
        >>> q = DNSRecord.question("abc.def")
        >>> a = q.send("localhost",8053)
        Request: [...] (udp) / 'abc.def.' (A)
        Reply: [...] (udp) / 'abc.def.' (A) / NXDOMAIN
        >>> print(DNSRecord.parse(a))
        ;; ->>HEADER<<- opcode: QUERY, status: NXDOMAIN, id: ...
        ;; flags: qr aa rd ra; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
        ;; QUESTION SECTION:
        ;abc.def.                       IN      A
        >>> server.stop()

        >>> class TestResolver:
        ...     def resolve(self,request,handler):
        ...         reply = request.reply()
        ...         reply.add_answer(*RR.fromZone("abc.def. 60 A 1.2.3.4"))
        ...         return reply
        >>> resolver = TestResolver()
        >>> server = DNSServer(resolver,port=8053,address="localhost",logger=logger,tcp=True)
        >>> server.start_thread()
        >>> a = q.send("localhost",8053,tcp=True)
        Request: [...] (tcp) / 'abc.def.' (A)
        Reply: [...] (tcp) / 'abc.def.' (A) / RRs: A
        >>> print(DNSRecord.parse(a))
        ;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: ...
        ;; flags: qr aa rd ra; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
        ;; QUESTION SECTION:
        ;abc.def.                       IN      A
        ;; ANSWER SECTION:
        abc.def.                60      IN      A       1.2.3.4
        >>> server.stop()


"""
from __future__ import print_function

import binascii,socket,struct,threading,time

try:
    import socketserver
except ImportError:
    import SocketServer as socketserver

from dnslib import DNSRecord,DNSError,QTYPE,RCODE,RR

class BaseResolver(object):
    """
        Minimal resolver: DNSHandler passes the decoded request (a
        DNSRecord instance) to 'resolve' and sends back whatever DNSRecord
        it returns.

        Custom resolvers normally only need to replace 'resolve' (see
        fixedresolver/zoneresolver/shellresolver for examples).

        One resolver instance is shared by every DNSHandler instance, so
        implementations have to consider blocking & thread safety.
    """
    def resolve(self,request,handler):
        """
            Default behaviour - answer every request with NXDOMAIN
        """
        nxdomain_reply = request.reply()
        nxdomain_reply.header.rcode = RCODE.NXDOMAIN
        return nxdomain_reply

class DNSHandler(socketserver.BaseRequestHandler):
    """
        socketserver request handler shared by the TCP and UDP servers.
        Reads one DNS request (TCP requests carry a two byte length
        prefix), asks the resolver stored on the server object for a
        reply and writes the reply back over the same transport.
    """

    udplen = 0                  # Max udp packet length (0 = ignore)

    def handle(self):
        """
            Receive one request, resolve it and send the reply.
            A DNSError raised while building or sending the reply is
            reported through the server's logger.
        """
        is_stream = self.server.socket_type == socket.SOCK_STREAM
        self.protocol = 'tcp' if is_stream else 'udp'

        if is_stream:
            data = self.request.recv(8192)
            if len(data) < 2:
                # Not even the length prefix arrived
                self.server.logger.log_error(self,"Request Truncated")
                return
            (length,) = struct.unpack("!H",bytes(data[:2]))
            # Keep reading until the announced payload length is buffered
            # or the peer stops sending
            while len(data) - 2 < length:
                chunk = self.request.recv(8192)
                if not chunk:
                    break
                data += chunk
            data = data[2:]
        else:
            # For UDP the request is a (payload, socket) pair
            data,connection = self.request

        self.server.logger.log_recv(self,data)

        try:
            rdata = self.get_reply(data)
            self.server.logger.log_send(self,rdata)

            if self.protocol != 'tcp':
                connection.sendto(rdata,self.client_address)
            else:
                # TCP replies get the same two byte length prefix
                framed = struct.pack("!H",len(rdata)) + rdata
                self.request.sendall(framed)

        except DNSError as e:
            self.server.logger.log_error(self,e)

    def get_reply(self,data):
        """
            Decode 'data', hand the request to the resolver and return the
            packed reply. UDP replies longer than 'udplen' (when udplen is
            non-zero) are replaced by their truncated form.
        """
        request = DNSRecord.parse(data)
        self.server.logger.log_request(self,request)

        reply = self.server.resolver.resolve(request,self)
        self.server.logger.log_reply(self,reply)

        rdata = reply.pack()
        oversized = self.udplen and len(rdata) > self.udplen
        if self.protocol == 'udp' and oversized:
            truncated_reply = reply.truncate()
            rdata = truncated_reply.pack()
            self.server.logger.log_truncated(self,truncated_reply)

        return rdata

class DNSLogger:

    """
        Default logger used by DNSServer. Each stage of request handling
        has its own hook; hooks that are not enabled through the 'log'
        constructor argument are replaced by a no-op on the instance.

        Every message produced by an enabled hook is appended to the
        'currentLog' list and is additionally printed when 'verbose' is
        set. The most recent full record dump is kept in 'dataLog'.

        To customise logging create an object which implements the
        DNSLogger interface and pass the instance to DNSServer.

        The methods which the logger instance must implement are:

            log_recv          - Raw packet received
            log_send          - Raw packet sent
            log_request       - DNS Request
            log_reply         - DNS Response
            log_truncated     - Truncated
            log_error         - Decoding error
            log_data          - Dump full request/response
    """

    def __init__(self,log="",prefix=True,verbose=False):
        """
            Selectively enable log hooks depending on log argument
            (comma separated list of hooks to enable/disable)

            - If empty enable default log hooks
            - If entry starts with '+' (eg. +send,+recv) enable hook
            - If entry starts with '-' (eg. -data) disable hook
            - If entry doesn't start with +/- replace defaults
            - Empty entries (eg. "request,,reply") are ignored

            Prefix argument enables/disables log prefix
            Verbose argument additionally prints each message
        """
        self.verbose = verbose
        self.currentLog = []
        # Defined up front so it can be read before the first dump
        self.dataLog = ""
        default = ["request","reply","truncated","error"]
        # Dropping empty entries keeps the s[0] test below from raising
        # IndexError on input such as "request,,reply"
        log = [ s for s in log.split(",") if s ] if log else []
        enabled = set([ s for s in log if s[0] not in '+-'] or default)
        # All '+' entries are applied before any '-' entry, so a hook
        # that is both added and removed ends up disabled
        for l in log:
            if l.startswith('+'):
                enabled.add(l[1:])
        for l in log:
            if l.startswith('-'):
                enabled.discard(l[1:])
        for l in ['log_recv','log_send','log_request','log_reply',
                  'log_truncated','log_error','log_data']:
            if l[4:] not in enabled:
                setattr(self,l,self.log_pass)
        self.prefix = prefix

    def _record(self,log):
        # Store the message and echo it when verbose output is requested
        self.currentLog.append(log)
        if self.verbose:
            print(log)

    def log_pass(self,*args):
        """
            No-op installed in place of disabled hooks
        """
        pass

    def log_prefix(self,handler):
        """
            Return "<timestamp> [<handler class>:<resolver class>] " when
            the prefix is enabled, otherwise an empty string
        """
        if self.prefix:
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                               handler.__class__.__name__,
                               handler.server.resolver.__class__.__name__)
        else:
            return ""

    def log_recv(self,handler,data):
        """
            Record the raw packet received from the client (hex encoded)
        """
        self._record("%sReceived: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_send(self,handler,data):
        """
            Record the raw packet sent to the client (hex encoded)
        """
        self._record("%sSent: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_request(self,handler,request):
        """
            Record the decoded request (qname and qtype), then dump it
        """
        self._record("%sRequest: [%s:%d] (%s) / '%s' (%s)" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    request.q.qname,
                    QTYPE[request.q.qtype]))
        self.log_data(request)

    def log_reply(self,handler,reply):
        """
            Record the reply: the answer RR types for NOERROR replies,
            otherwise the response code. The reply is then dumped.
        """
        if reply.header.rcode == RCODE.NOERROR:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr]))
        else:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    RCODE[reply.header.rcode])
        self._record(log)
        self.log_data(reply)

    def log_truncated(self,handler,reply):
        """
            Record a truncated reply, then dump it
        """
        self._record("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self,handler,e):
        """
            Record a request that could not be processed
        """
        self._record("%sInvalid Request: [%s:%d] (%s) :: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    e))

    def log_data(self,dnsobj):
        """
            Print the full zone-format dump of 'dnsobj' and keep the same
            text in 'dataLog'
        """
        # The dump used to be assembled with str(sep=""), which raises
        # TypeError; build the text once and use it for both outputs
        dump = "\n%s\n" % dnsobj.toZone("    ")
        print(dump)
        self.dataLog = dump


class UDPServer(socketserver.ThreadingMixIn,socketserver.UDPServer):
    # UDP socketserver combined with ThreadingMixIn; address reuse is
    # switched on for the listening socket.
    allow_reuse_address = True

class TCPServer(socketserver.ThreadingMixIn,socketserver.TCPServer):
    # TCP socketserver combined with ThreadingMixIn; address reuse is
    # switched on for the listening socket.
    allow_reuse_address = True

class DNSServer(object):

    """
        Convenience wrapper for socketserver instance allowing
        either UDP/TCP server to be started in blocking mode
        or as a background thread.

        Processing is delegated to custom resolver (instance) and
        optionally custom logger (instance), handler (class), and
        server (class)

        In most cases only a custom resolver instance is required
        (and possibly logger)
    """
    def __init__(self,resolver,
                      address="",
                      port=53,
                      tcp=False,
                      logger=None,
                      handler=DNSHandler,
                      server=None):
        """
            resolver:   resolver instance
            address:    listen address (default: "")
            port:       listen port (default: 53)
            tcp:        UDP (false) / TCP (true) (default: False)
            logger:     logger instance (default: DNSLogger)
            handler:    handler class (default: DNSHandler)
            server:     socketserver class (default: UDPServer/TCPServer)
        """
        # No background thread exists until start_thread() is called;
        # keeping the attribute defined lets isAlive() answer safely
        self.thread = None
        if not server:
            server = TCPServer if tcp else UDPServer
        self.server = server((address,port),handler)
        # The handler reaches the resolver and logger through the server
        self.server.resolver = resolver
        self.server.logger = logger or DNSLogger()

    def start(self):
        """
            Serve requests in the calling thread (blocks)
        """
        self.server.serve_forever()

    def start_thread(self):
        """
            Serve requests from a daemon thread and return immediately
        """
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """
            Stop the serve loop and release the listening socket.
            The server cannot be started again after this call.
        """
        self.server.shutdown()
        # shutdown() only ends serve_forever(); without server_close()
        # the listening socket stays open
        self.server.server_close()

    def isAlive(self):
        """
            Return True while the background thread is running and
            False when it has finished or was never started
        """
        return self.thread is not None and self.thread.is_alive()

# Run the module's doctests when executed as a script; the exit status is 0
# when every example passes and 1 otherwise.
if __name__ == "__main__":
    import doctest,sys
    results = doctest.testmod(optionflags=doctest.ELLIPSIS)
    sys.exit(1 if results.failed else 0)

I solved it by adding an if statement inside the intercept class itself; when its condition is true, the DNS answer is redirected to 0.0.0.0. 我通过在拦截类本身中添加一个 if 语句并在 if 语句为真时将 DNS 重定向到 0.0.0.0 来做到这一点的。 I am aware that this code could be optimised and that parts of it could be removed. 我知道可以优化此代码并且可以删除其中的某些部分。

# -*- coding: utf-8 -*-

"""
    InterceptResolver - proxy requests to upstream server
                        (optionally intercepting)
"""
from __future__ import print_function

import binascii,copy,socket,struct,sys

from dnslib import DNSRecord,RR,QTYPE,RCODE,parse_time
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
from dnslib.label import DNSLabel

# Custom DNSLogger class with variable verbose setting
class variableVerboseDNSLogger(DNSLogger):

    """
        DNSLogger variant whose output targets are selectable: each
        message produced by an enabled hook is printed when 'verbose' is
        set and appended to the 'currentLog' list when 'logToList' is set.
        The most recent full record dump is kept in 'dataLog'.

        Hooks that are not enabled through the 'log' constructor argument
        are replaced by a no-op on the instance.

        The hooks implemented here are:

            log_recv          - Raw packet received
            log_send          - Raw packet sent
            log_request       - DNS Request
            log_reply         - DNS Response
            log_truncated     - Truncated
            log_error         - Decoding error
            log_data          - Dump full request/response
    """

    def __init__(self,log="",prefix=True,verbose=False,logToList=False):
        """
            Selectively enable log hooks depending on log argument
            (comma separated list of hooks to enable/disable)

            - If empty enable default log hooks
            - If entry starts with '+' (eg. +send,+recv) enable hook
            - If entry starts with '-' (eg. -data) disable hook
            - If entry doesn't start with +/- replace defaults
            - Empty entries (eg. "request,,reply") are ignored

            Prefix argument enables/disables log prefix
            Verbose argument prints each message
            logToList argument appends each message to currentLog
        """
        self.verbose = verbose
        self.logToList = logToList
        self.currentLog = []
        # Defined up front so it can be read before the first dump
        self.dataLog = ""
        default = ["request","reply","truncated","error"]
        # Dropping empty entries keeps the s[0] test below from raising
        # IndexError on input such as "request,,reply"
        log = [ s for s in log.split(",") if s ] if log else []
        enabled = set([ s for s in log if s[0] not in '+-'] or default)
        # All '+' entries are applied before any '-' entry, so a hook
        # that is both added and removed ends up disabled
        for l in log:
            if l.startswith('+'):
                enabled.add(l[1:])
        for l in log:
            if l.startswith('-'):
                enabled.discard(l[1:])
        for l in ['log_recv','log_send','log_request','log_reply',
                  'log_truncated','log_error','log_data']:
            if l[4:] not in enabled:
                setattr(self,l,self.log_pass)
        self.prefix = prefix

    def _emit(self,log):
        # Deliver one message to whichever outputs are switched on
        if self.verbose:
            print(log)
        if self.logToList:
            self.currentLog.append(log)

    def log_pass(self,*args):
        """
            No-op installed in place of disabled hooks
        """
        pass

    def log_prefix(self,handler):
        """
            Return "<timestamp> [<handler class>:<resolver class>] " when
            the prefix is enabled, otherwise an empty string
        """
        # Imported here because this script's module-level import line
        # does not include 'time'
        import time
        if self.prefix:
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                               handler.__class__.__name__,
                               handler.server.resolver.__class__.__name__)
        else:
            return ""

    def log_recv(self,handler,data):
        """
            Report the raw packet received from the client (hex encoded)
        """
        self._emit("%sReceived: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_send(self,handler,data):
        """
            Report the raw packet sent to the client (hex encoded)
        """
        self._emit("%sSent: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_request(self,handler,request):
        """
            Report the decoded request (qname and qtype), then dump it
        """
        self._emit("%sRequest: [%s:%d] (%s) / '%s' (%s)" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    request.q.qname,
                    QTYPE[request.q.qtype]))
        self.log_data(request)

    def log_reply(self,handler,reply):
        """
            Report the reply: the answer RR types for NOERROR replies,
            otherwise the response code. The reply is then dumped.
        """
        if reply.header.rcode == RCODE.NOERROR:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr]))
        else:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    RCODE[reply.header.rcode])
        self._emit(log)
        self.log_data(reply)

    def log_truncated(self,handler,reply):
        """
            Report a truncated reply, then dump it
        """
        self._emit("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self,handler,e):
        """
            Report a request that could not be processed
        """
        self._emit("%sInvalid Request: [%s:%d] (%s) :: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    e))

    def log_data(self,dnsobj):
        """
            Print the full zone-format dump of 'dnsobj' and keep the same
            text in 'dataLog'
        """
        # The dump used to be assembled with str(sep=""), which raises
        # TypeError; build the text once and use it for both outputs
        dump = "\n%s\n" % dnsobj.toZone("    ")
        print(dump)
        self.dataLog = dump


class InterceptResolver(BaseResolver):

    """
        Intercepting resolver
        Proxy requests to the upstream server, except for names listed in
        BLOCKED_NAMES, which are answered locally without contacting the
        upstream server.
    """

    # Fully-qualified names (trailing dot) that are answered locally.
    # Matching is an exact comparison against the query name.
    BLOCKED_NAMES = ("google.com.",)
    # Address returned in the A record for a blocked name.
    SINKHOLE_IP = "0.0.0.0"

    def __init__(self,address,port,ttl,intercept,skip,nxdomain,forward,all_qtypes,timeout=0):
        """
            address/port    - upstream server
            ttl             - default ttl for intercept records
            intercept       - list of wildcard RRs to respond to (zone format)
            skip            - list of wildcard labels to skip
            nxdomain        - list of wildcard labels to return NXDOMAIN
            forward         - list of wildcard labels to forward
            all_qtypes      - intercept all qtypes if qname matches.
            timeout         - timeout for upstream server(s)

            Only address, port, ttl and timeout are consulted by
            resolve(); the remaining settings are parsed and stored.
        """
        self.address = address
        self.port = port
        self.ttl = parse_time(ttl)
        self.skip = skip
        self.nxdomain = nxdomain
        # Each forward entry is "label:server[:port]"; port defaults to 53.
        self.forward = []
        for i in forward:
            qname, _, upstream = i.partition(':')
            upstream_ip, _, upstream_port = upstream.partition(':')
            self.forward.append((qname, upstream_ip, int(upstream_port or '53')))
        self.all_qtypes = all_qtypes
        self.timeout = timeout
        # Parsed intercept records as (name, qtype name, RR) tuples.
        self.zone = []
        for i in intercept:
            if i == '-':
                # '-' means: read the zone text from standard input.
                i = sys.stdin.read()
            for rr in RR.fromZone(i,ttl=self.ttl):
                self.zone.append((rr.rname,QTYPE[rr.rtype],rr))

    def resolve(self,request,handler):
        """
            Produce the reply for a request.

            request - parsed DNS query
            handler - request handler; handler.protocol ('udp' or other)
                      selects how the upstream server is contacted

            Blocked names get a locally built reply: an A record pointing
            at SINKHOLE_IP for A questions, an empty answer section for
            any other question type. Everything else is forwarded to the
            upstream server and its parsed response is returned. If the
            upstream server times out, a SERVFAIL reply is returned.
        """
        qname = request.q.qname
        qtype = QTYPE[request.q.qtype]

        print("QNAME label= " + str(qname) + "\n")

        # Decide locally first, so blocked names never reach the upstream.
        if any(qname == name for name in self.BLOCKED_NAMES):
            reply = request.reply()
            if qtype == 'A':
                # add_answer() takes RR objects, not zone text; passing a
                # plain string leaves an entry that cannot be rendered or
                # packed. RR.fromZone() builds the records from the text.
                zone = "%s IN A %s" % (qname, self.SINKHOLE_IP)
                reply.add_answer(*RR.fromZone(zone,ttl=self.ttl))
            print("REPLY = " + str(reply))
            return reply

        # Send to upstream
        try:
            if handler.protocol == 'udp':
                proxy_r = request.send(self.address,self.port,
                                timeout=self.timeout)
            else:
                proxy_r = request.send(self.address,self.port,
                                tcp=True,timeout=self.timeout)
            return DNSRecord.parse(proxy_r)
        except socket.timeout:
            reply = request.reply()
            reply.header.rcode = getattr(RCODE,'SERVFAIL')
            return reply

if __name__ == '__main__':

    # Script entry point: start the intercept proxy on UDP (and TCP when
    # enabled) and keep the process alive while the UDP server runs.

    import argparse,sys,time

    # Most of these don't do anything so dont use them
    p = argparse.ArgumentParser(description="DNS Intercept Proxy, please ignore arguments and run it")
    p.add_argument("--intercept","-i",action="append",
                    metavar="<zone record>",
                    help="Intercept requests matching zone record (glob) ('-' for stdin)")
    p.add_argument("--skip","-s",action="append",
                    metavar="<label>",
                    help="Don't intercept matching label (glob)")
    p.add_argument("--nxdomain","-x",action="append",
                    metavar="<label>",
                    help="Return NXDOMAIN (glob)")
    p.add_argument("--forward","-f",action="append",
                   metavar="<label:dns server:port>",
                   help="forward requests matching label (glob) to dns server")
    p.add_argument("--ttl","-t",default="60s",
                    metavar="<ttl>",
                    help="Intercept TTL (default: 60s)")
    p.add_argument("--timeout","-o",type=float,default=5,
                    metavar="<timeout>",
                    help="Upstream timeout (default: 5s)")
    p.add_argument("--all-qtypes",action='store_true',default=False,
                   help="Return an empty response if qname matches, but qtype doesn't")
    p.add_argument("--log",default="request,reply,truncated,error",
                    help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)")
    p.add_argument("--log-prefix",action='store_true',default=False,
                    help="Log prefix (timestamp/handler/resolver) (default: False)")
    args = p.parse_args()

    #'args.dns,_,args.dns_port = args.upstream.partition(':')
    #args.dns_port = int(args.dns_port or 53)
    tcpEnabled = True
    # Upstream server the proxy forwards to. The value below is a
    # placeholder and must be replaced with a real address.
    externalDNS = "***.***.***.***"
    externalDNSPort = 53
    # Local address/port the proxy listens on; "" is shown as "*" in the
    # startup banner.
    listenAddress = ""
    port = 53

    resolver = InterceptResolver(address = externalDNS,
                                 port = externalDNSPort,
                                 ttl = "60s",
                                 intercept = args.intercept or [],
                                 skip = args.skip or [],
                                 nxdomain = args.nxdomain or [],
                                 forward = args.forward or [],
                                 all_qtypes = args.all_qtypes,
                                 timeout = args.timeout)
    
    logger = variableVerboseDNSLogger(log = args.log,
                                      prefix = args.log_prefix,
                                      verbose = False,
                                      logToList = True)

    # The left-hand side of the arrow is the local listen address; it used
    # to print the upstream address on both sides.
    print("Starting Intercept Proxy (%s:%d -> %s:%d) [%s]" % (
                        listenAddress or "*",port,
                        externalDNS,externalDNSPort,
                        "UDP/TCP" if tcpEnabled else "UDP"))

    for rr in resolver.zone:
        print("    | ",rr[2].toZone(),sep="")
    if resolver.nxdomain:
        print("    NXDOMAIN:",", ".join(resolver.nxdomain))
    if resolver.skip:
        print("    Skipping:",", ".join(resolver.skip))
    if resolver.forward:
        print("    Forwarding:")
        for i in resolver.forward:
            print("    | ","%s:%s:%s" % i,sep="")

    # NOTE(review): with every entry commented out this assigns an empty
    # dict to DNSHandler.log; confirm whether anything reads it.
    DNSHandler.log = {
        #'log_request',       # DNS Request
        #'log_reply',        # DNS Response
        #'log_truncated',    # Truncated
        #'log_error',        # Decoding error
    }

    udp_server = DNSServer(resolver,
                           port=port,
                           address=listenAddress,
                           logger=logger)
    udp_server.start_thread()

    if tcpEnabled:
        tcp_server = DNSServer(resolver,
                               port=port,
                               address=listenAddress,
                               tcp=True,
                               logger=logger)
        tcp_server.start_thread()

    # Poll once per second while the UDP server is alive, discarding the
    # log lines collected since the previous pass.
    while udp_server.isAlive():
        time.sleep(1)
        #print(DNSHandler.log)
        #print("LOG START")
        #print(logger.currentLog)
        logger.currentLog = []
        #print("LOG END")

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM