簡體   English   中英

在 Python3 中創建一個帶有網站阻止的 DNS 服務器?

[英]Create a DNS Server in Python3 with website blocking?

我可以用 Python 3 創建一個能夠阻止/重定向某些網站 URL 的 DNS 服務器嗎? 我曾嘗試使用 dnslib,但似乎無法用它阻止 URL。

我正在嘗試創建一個可以在 DNS 級別阻止網站的程序,我知道 pi-hole 但我想創建自己的。

到目前為止我的代碼(修改後的示例代碼):

# -*- coding: utf-8 -*-

"""
    InterceptResolver - proxy requests to upstream server
                        (optionally intercepting)
"""
from __future__ import print_function

import binascii,copy,socket,struct,sys

from dnslib import DNSRecord,RR,QTYPE,RCODE,parse_time
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
from dnslib.label import DNSLabel

class InterceptResolver(BaseResolver):

    """
        Resolver that answers selected names locally and proxies
        everything else to an upstream DNS server.

        Names are compared with glob patterns (DNSLabel.matchGlob).
    """

    def __init__(self,address,port,ttl,intercept,skip,nxdomain,forward,all_qtypes,timeout=0):
        """
            address/port    - upstream server used when nothing else applies
            ttl             - default ttl applied to intercept records
            intercept       - zone-format RRs (glob names) answered locally;
                              an entry of '-' reads the zone text from stdin
            skip            - glob labels that are never intercepted/forwarded
            nxdomain        - glob labels answered with NXDOMAIN
            forward         - 'label:server[:port]' entries; the port
                              defaults to 53 when omitted
            all_qtypes      - answer locally (possibly with no records) when
                              the name matches but the qtype does not
            timeout         - timeout passed to the upstream request
        """
        self.address, self.port = address, port
        self.ttl = parse_time(ttl)
        self.skip = skip
        self.nxdomain = nxdomain

        # Each forward spec becomes a (label, server, port) tuple.
        self.forward = []
        for spec in forward:
            label, _, server = spec.partition(':')
            host, _, port_text = server.partition(':')
            self.forward.append((label, host, int(port_text) if port_text else 53))

        self.all_qtypes = all_qtypes
        self.timeout = timeout

        # Each local record is stored as (name, qtype name, RR).
        self.zone = []
        for source in intercept:
            zone_text = sys.stdin.read() if source == '-' else source
            for rr in RR.fromZone(zone_text,ttl=self.ttl):
                self.zone.append((rr.rname,QTYPE[rr.rtype],rr))

    def resolve(self,request,handler):
        """
            Build the reply for 'request'.

            Order of precedence: local intercept records (unless the name
            is on the skip list), then the NXDOMAIN list, then a matching
            forwarder, and finally the default upstream server. An upstream
            timeout yields a SERVFAIL reply.
        """
        qname = request.q.qname
        qtype = QTYPE[request.q.qtype]
        reply = request.reply()
        skipped = any(qname.matchGlob(pattern) for pattern in self.skip)

        # Collect local answers; 'matched' records a name match even when
        # the qtype does not fit any record.
        matched = False
        if not skipped:
            for name,rtype,rr in self.zone:
                if not qname.matchGlob(name):
                    continue
                if qtype in (rtype,'ANY','CNAME'):
                    answer = copy.copy(rr)
                    answer.rname = qname
                    reply.add_answer(answer)
                matched = True

        if any(qname.matchGlob(pattern) for pattern in self.nxdomain):
            reply.header.rcode = RCODE.NXDOMAIN
            return reply

        if matched and self.all_qtypes:
            return reply

        # Pick the server to proxy to; the last matching forwarder wins.
        upstream, upstream_port = self.address, self.port
        if not skipped:
            for name, ip, port in self.forward:
                if qname.matchGlob(name):
                    upstream, upstream_port = ip, port

        # Local answers take priority over proxying.
        if reply.rr:
            return reply

        send_options = {'timeout': self.timeout}
        if handler.protocol != 'udp':
            send_options['tcp'] = True
        try:
            proxy_r = request.send(upstream,upstream_port,**send_options)
            reply = DNSRecord.parse(proxy_r)
        except socket.timeout:
            reply.header.rcode = RCODE.SERVFAIL
        return reply

if __name__ == '__main__':

    # Command-line entry point: parse options, build the resolver, start
    # UDP (and TCP) servers in background threads and echo log entries.
    import argparse
    import sys
    import time

    p = argparse.ArgumentParser(description="DNS Intercept Proxy")
    p.add_argument("--port","-p",type=int,default=53,
                    metavar="<port>",
                    help="Local proxy port (default:53)")
    p.add_argument("--address","-a",default="",
                    metavar="<address>",
                    help="Local proxy listen address (default:all)")
    p.add_argument("--upstream","-u",default="8.8.8.8:53",
                    metavar="<dns server:port>",
                    help="Upstream DNS server:port (default:8.8.8.8:53)")
    p.add_argument("--intercept","-i",action="append",
                    metavar="<zone record>",
                    help="Intercept requests matching zone record (glob) ('-' for stdin)")
    p.add_argument("--skip","-s",action="append",
                    metavar="<label>",
                    help="Don't intercept matching label (glob)")
    p.add_argument("--nxdomain","-x",action="append",
                    metavar="<label>",
                    help="Return NXDOMAIN (glob)")
    p.add_argument("--forward","-f",action="append",
                    metavar="<label:dns server:port>",
                    help="forward requests matching label (glob) to dns server")
    p.add_argument("--ttl","-t",default="60s",
                    metavar="<ttl>",
                    help="Intercept TTL (default: 60s)")
    p.add_argument("--timeout","-o",type=float,default=5,
                    metavar="<timeout>",
                    help="Upstream timeout (default: 5s)")
    p.add_argument("--all-qtypes",action='store_true',default=False,
                    help="Return an empty response if qname matches, but qtype doesn't")
    p.add_argument("--log",default="request,reply,truncated,error",
                    help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)")
    p.add_argument("--log-prefix",action='store_true',default=False,
                    help="Log prefix (timestamp/handler/resolver) (default: False)")
    args = p.parse_args()

    # Split "server[:port]"; the port falls back to 53 when omitted.
    args.dns,_,args.dns_port = args.upstream.partition(':')
    args.dns_port = int(args.dns_port or 53)
    tcpEnabled = True

    # action="append" options are None when not given, hence the "or []".
    resolver = InterceptResolver(args.dns,
                                 args.dns_port,
                                 args.ttl,
                                 args.intercept or [],
                                 args.skip or [],
                                 args.nxdomain or [],
                                 args.forward or [],
                                 args.all_qtypes,
                                 args.timeout)
    logger = DNSLogger(args.log,args.log_prefix)

    print("Starting Intercept Proxy (%s:%d -> %s:%d) [%s]" % (
                        args.address or "*",args.port,
                        args.dns,args.dns_port,
                        "UDP/TCP" if tcpEnabled else "UDP"))

    # Show the effective configuration.
    for rr in resolver.zone:
        print("    | ",rr[2].toZone(),sep="")
    if resolver.nxdomain:
        print("    NXDOMAIN:",", ".join(resolver.nxdomain))
    if resolver.skip:
        print("    Skipping:",", ".join(resolver.skip))
    if resolver.forward:
        print("    Forwarding:")
        for i in resolver.forward:
            print("    | ","%s:%s:%s" % i,sep="")

    # NOTE(review): with every entry commented out this is an empty dict;
    # nothing visible here reads DNSHandler.log - confirm it is still needed.
    DNSHandler.log = {
        #'log_request',       # DNS Request
        #'log_reply',        # DNS Response
        #'log_truncated',    # Truncated
        #'log_error',        # Decoding error
    }

    udp_server = DNSServer(resolver,
                           port=args.port,
                           address=args.address,
                           logger=logger)
    udp_server.start_thread()

    if tcpEnabled:
        tcp_server = DNSServer(resolver,
                               port=args.port,
                               address=args.address,
                               tcp=True,
                               logger=logger)
        tcp_server.start_thread()

    while udp_server.isAlive():
        time.sleep(1)
        # Drain the logger's in-memory list instead of re-printing the
        # whole (ever growing) list every second. getattr() keeps the
        # loop working with a logger that has no 'currentLog' attribute.
        pending = getattr(logger,"currentLog",None)
        if pending:
            # Copy first, then drop exactly what was copied: handler
            # threads may append new entries while we are printing.
            batch = pending[:]
            del pending[:len(batch)]
            for entry in batch:
                print(entry)

我還在 dnslib 庫中稍微修改了 server.py 文件:

# -*- coding: utf-8 -*-

"""
    DNS server framework - intended to simplify creation of custom resolvers.

    Comprises the following components:

        DNSServer   - socketserver wrapper (in most cases you should just
                      need to pass this an appropriate resolver instance
                      and start in either foreground/background)

        DNSHandler  - handler instantiated by DNSServer to handle requests
                      The 'handle' method deals with the sending/receiving
                      packets (handling TCP length prefix) and delegates
                      the protocol handling to 'get_reply'. This decodes
                      packet, hands off a DNSRecord to the Resolver instance,
                      and encodes the returned DNSRecord.

                      In most cases you dont need to change DNSHandler unless
                      you need to get hold of the raw protocol data in the
                      Resolver

        DNSLogger   - The class provides a default set of logging functions for
                      the various stages of the request handled by a DNSServer
                      instance which are enabled/disabled by flags in the 'log'
                      class variable.

        Resolver    - Instance implementing a 'resolve' method that receives
                      the decodes request packet and returns a response.

                      To implement a custom resolver in most cases all you need
                      is to implement this interface.

                      Note that there is only a single instance of the Resolver
                      so need to be careful about thread-safety and blocking

        The following examples use the server framework:

            fixedresolver.py    - Simple resolver which will respond to all
                                  requests with a fixed response
            zoneresolver.py     - Resolver which will take a standard zone
                                  file input
            shellresolver.py    - Example of a dynamic resolver
            proxy.py            - DNS proxy
            intercept.py        - Intercepting DNS proxy

        >>> resolver = BaseResolver()
        >>> logger = DNSLogger(prefix=False)
        >>> server = DNSServer(resolver,port=8053,address="localhost",logger=logger)
        >>> server.start_thread()
        >>> q = DNSRecord.question("abc.def")
        >>> a = q.send("localhost",8053)
        Request: [...] (udp) / 'abc.def.' (A)
        Reply: [...] (udp) / 'abc.def.' (A) / NXDOMAIN
        >>> print(DNSRecord.parse(a))
        ;; ->>HEADER<<- opcode: QUERY, status: NXDOMAIN, id: ...
        ;; flags: qr aa rd ra; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
        ;; QUESTION SECTION:
        ;abc.def.                       IN      A
        >>> server.stop()

        >>> class TestResolver:
        ...     def resolve(self,request,handler):
        ...         reply = request.reply()
        ...         reply.add_answer(*RR.fromZone("abc.def. 60 A 1.2.3.4"))
        ...         return reply
        >>> resolver = TestResolver()
        >>> server = DNSServer(resolver,port=8053,address="localhost",logger=logger,tcp=True)
        >>> server.start_thread()
        >>> a = q.send("localhost",8053,tcp=True)
        Request: [...] (tcp) / 'abc.def.' (A)
        Reply: [...] (tcp) / 'abc.def.' (A) / RRs: A
        >>> print(DNSRecord.parse(a))
        ;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: ...
        ;; flags: qr aa rd ra; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0
        ;; QUESTION SECTION:
        ;abc.def.                       IN      A
        ;; ANSWER SECTION:
        abc.def.                60      IN      A       1.2.3.4
        >>> server.stop()


"""
from __future__ import print_function

import binascii,socket,struct,threading,time

try:
    import socketserver
except ImportError:
    import SocketServer as socketserver

from dnslib import DNSRecord,DNSError,QTYPE,RCODE,RR

class BaseResolver(object):
    """
        Minimal resolver used as the base for custom resolvers.

        DNSHandler calls 'resolve' with the decoded request (a DNSRecord)
        and the handler instance, and sends back whatever DNSRecord the
        method returns. Subclasses normally override 'resolve' only.

        One resolver instance is shared by every DNSHandler, so
        implementations need to consider blocking and thread safety.
    """
    def resolve(self,request,handler):
        """
            Default behaviour: answer every request with NXDOMAIN.
        """
        response = request.reply()
        response.header.rcode = RCODE.NXDOMAIN
        return response

class DNSHandler(socketserver.BaseRequestHandler):
    """
        socketserver request handler shared by the UDP and TCP servers.

        'handle' reads the raw packet (TCP packets carry a 2-byte
        big-endian length prefix), 'get_reply' decodes it, asks the
        resolver stored on the server for an answer and encodes the
        answer again.
    """

    udplen = 0                  # Max udp packet length (0 = ignore)

    def handle(self):
        """
            Receive one request, obtain the reply bytes and send them back.
            Decoding problems (DNSError) are reported through the logger.
        """
        is_stream = self.server.socket_type == socket.SOCK_STREAM
        if is_stream:
            self.protocol = 'tcp'
            data = self.request.recv(8192)
            if len(data) < 2:
                # Not even the length prefix arrived.
                self.server.logger.log_error(self,"Request Truncated")
                return
            (length,) = struct.unpack("!H",bytes(data[:2]))
            # Keep reading until the announced payload is complete or
            # the peer stops sending.
            while len(data) - 2 < length:
                chunk = self.request.recv(8192)
                if not chunk:
                    break
                data += chunk
            data = data[2:]
        else:
            self.protocol = 'udp'
            # For UDP the request is a (payload, socket) pair.
            data,connection = self.request

        self.server.logger.log_recv(self,data)

        try:
            rdata = self.get_reply(data)
            self.server.logger.log_send(self,rdata)

            if self.protocol == 'tcp':
                # Re-attach the length prefix for the TCP transport.
                self.request.sendall(struct.pack("!H",len(rdata)) + rdata)
            else:
                connection.sendto(rdata,self.client_address)

        except DNSError as e:
            self.server.logger.log_error(self,e)

    def get_reply(self,data):
        """
            Decode 'data', resolve it and return the packed reply bytes.
            UDP replies longer than 'udplen' (when non-zero) are truncated.
        """
        logger = self.server.logger
        request = DNSRecord.parse(data)
        logger.log_request(self,request)

        reply = self.server.resolver.resolve(request,self)
        logger.log_reply(self,reply)

        rdata = reply.pack()
        oversized = self.udplen and len(rdata) > self.udplen
        if self.protocol == 'udp' and oversized:
            truncated_reply = reply.truncate()
            rdata = truncated_reply.pack()
            self.server.logger.log_truncated(self,truncated_reply)

        return rdata

class DNSLogger:

    """
        Default logger used by DNSServer.

        Every hook formats one line describing a stage of request
        handling, appends it to the in-memory list 'currentLog' and, when
        'verbose' is set, also prints it. Hooks that are not enabled are
        replaced on the instance by the no-op 'log_pass'.

        To customise logging create an object which implements the
        DNSLogger interface and pass the instance to DNSServer.

        The methods which a logger instance must implement are:

            log_recv          - Raw packet received
            log_send          - Raw packet sent
            log_request       - DNS Request
            log_reply         - DNS Response
            log_truncated     - Truncated
            log_error         - Decoding error
            log_data          - Dump full request/response

        NOTE: 'currentLog' is never trimmed by this class; a long-running
        caller is expected to drain it.
    """

    def __init__(self,log="",prefix=True,verbose=False):
        """
            Selectively enable log hooks depending on log argument
            (comma separated list of hooks to enable/disable)

            - If empty enable default log hooks
            - If entry starts with '+' (eg. +send,+recv) enable hook
            - If entry starts with '-' (eg. -data) disable hook
            - If entry doesn't start with +/- replace defaults
            - Empty entries (eg. "request,,reply") are ignored

            prefix  - enables/disables the timestamp/handler/resolver prefix
            verbose - also print each log line when True
        """
        self.verbose = verbose
        self.currentLog = []
        default = ["request","reply","truncated","error"]
        entries = log.split(",") if log else []
        # s[:1] instead of s[0]: an empty entry yields '' (which is "in"
        # '+-') and is skipped rather than raising IndexError.
        enabled = set([s for s in entries if s[:1] not in '+-'] or default)
        # Two passes on purpose: all '+' entries are applied before any
        # '-' entry, so a disable always wins over an enable.
        for entry in entries:
            if entry.startswith('+'):
                enabled.add(entry[1:])
        for entry in entries:
            if entry.startswith('-'):
                enabled.discard(entry[1:])
        for hook in ['log_recv','log_send','log_request','log_reply',
                     'log_truncated','log_error','log_data']:
            if hook[4:] not in enabled:
                setattr(self,hook,self.log_pass)
        self.prefix = prefix

    def _record(self,log):
        """Store one formatted line and print it when verbose."""
        self.currentLog.append(log)
        if self.verbose:
            print(log)

    def log_pass(self,*args):
        """No-op installed in place of disabled hooks."""
        pass

    def log_prefix(self,handler):
        """Return 'timestamp [handler:resolver] ' or '' when disabled."""
        if self.prefix:
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                               handler.__class__.__name__,
                               handler.server.resolver.__class__.__name__)
        else:
            return ""

    def log_recv(self,handler,data):
        """Log the raw bytes received from the client (hex encoded)."""
        self._record("%sReceived: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_send(self,handler,data):
        """Log the raw bytes sent back to the client (hex encoded)."""
        self._record("%sSent: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_request(self,handler,request):
        """Log the decoded question, then dump it via log_data."""
        self._record("%sRequest: [%s:%d] (%s) / '%s' (%s)" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    request.q.qname,
                    QTYPE[request.q.qtype]))
        self.log_data(request)

    def log_reply(self,handler,reply):
        """Log the reply: its RR types on NOERROR, else the rcode name."""
        if reply.header.rcode == RCODE.NOERROR:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr]))
        else:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    RCODE[reply.header.rcode])
        self._record(log)
        self.log_data(reply)

    def log_truncated(self,handler,reply):
        """Log a reply that had to be truncated, then dump it."""
        self._record("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self,handler,e):
        """Log a request that could not be handled."""
        self._record("%sInvalid Request: [%s:%d] (%s) :: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    e))

    def log_data(self,dnsobj):
        """
            Print the full zone-format dump of 'dnsobj' and keep the same
            text in 'dataLog'.
        """
        # Previously this ended with str(sep=""), which raises TypeError
        # (str() has no 'sep' argument) whenever the data hook is enabled.
        dump = "\n" + str(dnsobj.toZone("    ")) + "\n"
        print(dump)
        self.dataLog = dump


class UDPServer(socketserver.ThreadingMixIn,socketserver.UDPServer):
    """socketserver UDP server combined with ThreadingMixIn."""
    # Allow rebinding the listen address (socketserver's address-reuse flag).
    allow_reuse_address = True

class TCPServer(socketserver.ThreadingMixIn,socketserver.TCPServer):
    """socketserver TCP server combined with ThreadingMixIn."""
    # Allow rebinding the listen address (socketserver's address-reuse flag).
    allow_reuse_address = True

class DNSServer(object):

    """
        Convenience wrapper for a socketserver instance allowing
        either a UDP or TCP server to be started in blocking mode
        or as a background thread.

        Processing is delegated to a custom resolver (instance) and
        optionally a custom logger (instance), handler (class), and
        server (class)

        In most cases only a custom resolver instance is required
        (and possibly logger)
    """
    def __init__(self,resolver,
                      address="",
                      port=53,
                      tcp=False,
                      logger=None,
                      handler=DNSHandler,
                      server=None):
        """
            resolver:   resolver instance
            address:    listen address (default: "")
            port:       listen port (default: 53)
            tcp:        UDP (false) / TCP (true) (default: False)
            logger:     logger instance (default: DNSLogger)
            handler:    handler class (default: DNSHandler)
            server:     socketserver class (default: UDPServer/TCPServer)
        """
        if not server:
            server = TCPServer if tcp else UDPServer
        self.server = server((address,port),handler)
        # The handler reaches both objects through self.server.
        self.server.resolver = resolver
        self.server.logger = logger or DNSLogger()
        # Set by start_thread(); None means no background thread exists,
        # which lets isAlive() answer False instead of raising.
        self.thread = None

    def start(self):
        """Serve requests in the calling thread (blocks)."""
        self.server.serve_forever()

    def start_thread(self):
        """Serve requests in a daemon background thread."""
        self.thread = threading.Thread(target=self.server.serve_forever)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Stop the serve loop and release the listening socket."""
        self.server.shutdown()
        # shutdown() only ends serve_forever(); the socket stays open
        # until server_close() is called.
        self.server.server_close()

    def isAlive(self):
        """Return True while the background thread is running."""
        return self.thread is not None and self.thread.is_alive()

if __name__ == "__main__":
    # Run the module's doctests; the exit status is 0 only when none fail.
    import doctest
    import sys
    outcome = doctest.testmod(optionflags=doctest.ELLIPSIS)
    sys.exit(1 if outcome.failed != 0 else 0)

我是通過在攔截類本身中添加一個 if 語句,並在條件為真時將 DNS 解析重定向到 0.0.0.0 來做到這一點的。 我知道這段代碼還可以優化,其中某些部分也可以刪除。

# -*- coding: utf-8 -*-

"""
    InterceptResolver - proxy requests to upstream server
                        (optionally intercepting)
"""
from __future__ import print_function

import binascii,copy,socket,struct,sys

from dnslib import DNSRecord,RR,QTYPE,RCODE,parse_time
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
from dnslib.label import DNSLabel

# Custom DNSLogger class with variable verbose setting
class variableVerboseDNSLogger(DNSLogger):

    """
        DNSLogger variant whose output destinations are configurable.

        Every hook formats one line describing a stage of request
        handling; the line is printed when 'verbose' is set and appended
        to the in-memory list 'currentLog' when 'logToList' is set.
        Hooks that are not enabled are replaced on the instance by the
        no-op 'log_pass'.

        The hooks implemented are:

            log_recv          - Raw packet received
            log_send          - Raw packet sent
            log_request       - DNS Request
            log_reply         - DNS Response
            log_truncated     - Truncated
            log_error         - Decoding error
            log_data          - Dump full request/response
    """

    def __init__(self,log="",prefix=True,verbose=False,logToList=False):
        """
            Selectively enable log hooks depending on log argument
            (comma separated list of hooks to enable/disable)

            - If empty enable default log hooks
            - If entry starts with '+' (eg. +send,+recv) enable hook
            - If entry starts with '-' (eg. -data) disable hook
            - If entry doesn't start with +/- replace defaults
            - Empty entries (eg. "request,,reply") are ignored

            prefix    - enables/disables the log line prefix
            verbose   - print each log line when True
            logToList - append each log line to 'currentLog' when True
        """
        # The base __init__ is deliberately not called: its signature
        # depends on which dnslib server.py is installed.
        self.verbose = verbose
        self.logToList = logToList
        self.currentLog = []
        default = ["request","reply","truncated","error"]
        entries = log.split(",") if log else []
        # s[:1] instead of s[0]: an empty entry yields '' (which is "in"
        # '+-') and is skipped rather than raising IndexError.
        enabled = set([s for s in entries if s[:1] not in '+-'] or default)
        # Two passes on purpose: all '+' entries are applied before any
        # '-' entry, so a disable always wins over an enable.
        for entry in entries:
            if entry.startswith('+'):
                enabled.add(entry[1:])
        for entry in entries:
            if entry.startswith('-'):
                enabled.discard(entry[1:])
        for hook in ['log_recv','log_send','log_request','log_reply',
                     'log_truncated','log_error','log_data']:
            if hook[4:] not in enabled:
                setattr(self,hook,self.log_pass)
        self.prefix = prefix

    def _emit(self,log):
        """Send one formatted line to the enabled destinations."""
        if self.verbose:
            print(log)
        if self.logToList:
            self.currentLog.append(log)

    def log_pass(self,*args):
        """No-op installed in place of disabled hooks."""
        pass

    def log_prefix(self,handler):
        """Return 'timestamp [handler:resolver] ' or '' when disabled."""
        if self.prefix:
            # Imported here because the visible import line of this
            # script does not include 'time'.
            import time
            return "%s [%s:%s] " % (time.strftime("%Y-%m-%d %X"),
                               handler.__class__.__name__,
                               handler.server.resolver.__class__.__name__)
        else:
            return ""

    def log_recv(self,handler,data):
        """Log the raw bytes received from the client (hex encoded)."""
        self._emit("%sReceived: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_send(self,handler,data):
        """Log the raw bytes sent back to the client (hex encoded)."""
        self._emit("%sSent: [%s:%d] (%s) <%d> : %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    len(data),
                    binascii.hexlify(data)))

    def log_request(self,handler,request):
        """Log the decoded question, then dump it via log_data."""
        self._emit("%sRequest: [%s:%d] (%s) / '%s' (%s)" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    request.q.qname,
                    QTYPE[request.q.qtype]))
        self.log_data(request)

    def log_reply(self,handler,reply):
        """Log the reply: its RR types on NOERROR, else the rcode name."""
        if reply.header.rcode == RCODE.NOERROR:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr]))
        else:
            log = "%sReply: [%s:%d] (%s) / '%s' (%s) / %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    RCODE[reply.header.rcode])
        self._emit(log)
        self.log_data(reply)

    def log_truncated(self,handler,reply):
        """Log a reply that had to be truncated, then dump it."""
        self._emit("%sTruncated Reply: [%s:%d] (%s) / '%s' (%s) / RRs: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    reply.q.qname,
                    QTYPE[reply.q.qtype],
                    ",".join([QTYPE[a.rtype] for a in reply.rr])))
        self.log_data(reply)

    def log_error(self,handler,e):
        """Log a request that could not be handled."""
        self._emit("%sInvalid Request: [%s:%d] (%s) :: %s" % (
                    self.log_prefix(handler),
                    handler.client_address[0],
                    handler.client_address[1],
                    handler.protocol,
                    e))

    def log_data(self,dnsobj):
        """
            Print the full zone-format dump of 'dnsobj' and keep the same
            text in 'dataLog'. Printing does not depend on 'verbose'.
        """
        # Previously this ended with str(sep=""), which raises TypeError
        # (str() has no 'sep' argument) whenever the data hook is enabled.
        dump = "\n" + str(dnsobj.toZone("    ")) + "\n"
        print(dump)
        self.dataLog = dump


class InterceptResolver(BaseResolver):

    """
        Intercepting resolver
        Proxy requests to upstream server, answering blocked names
        locally instead of forwarding them
    """

    # Name that is always answered locally, and the address returned for it.
    BLOCKED_NAME = "google.com."
    SINKHOLE_ADDRESS = "0.0.0.0"

    def __init__(self,address,port,ttl,intercept,skip,nxdomain,forward,all_qtypes,timeout=0):
        """
            address/port    - upstream server
            ttl             - default ttl for intercept records
            intercept       - list of wildcard RRs to respond to (zone format)
            skip            - list of wildcard labels to skip
            nxdomain        - list of wildcard labels to return NXDOMAIN
            forward         - list of wildcard labels to forward
            all_qtypes      - intercept all qtypes if qname matches.
            timeout         - timeout for upstream server(s)

            NOTE: skip, forward, all_qtypes and the parsed intercept zone
            are stored on the instance, but resolve() below does not
            consult them.
        """
        self.address = address
        self.port = port
        self.ttl = parse_time(ttl)
        self.skip = skip
        self.nxdomain = nxdomain
        # Each forward entry is "label:ip[:port]"; port defaults to 53.
        self.forward = []
        for entry in forward:
            label, _, upstream = entry.partition(':')
            upstream_ip, _, upstream_port = upstream.partition(':')
            self.forward.append((label, upstream_ip, int(upstream_port or '53')))
        self.all_qtypes = all_qtypes
        self.timeout = timeout
        # Zone entries are kept as (name, type-name, RR) tuples.
        self.zone = []
        for zone_text in intercept:
            if zone_text == '-':
                # '-' means: read the zone records from stdin
                zone_text = sys.stdin.read()
            for rr in RR.fromZone(zone_text,ttl=self.ttl):
                self.zone.append((rr.rname,QTYPE[rr.rtype],rr))

    def resolve(self,request,handler):
        """
            Build the reply for one request.

            request - the parsed query
            handler - request handler; handler.protocol selects UDP or TCP
                      for the upstream query

            Order of checks:
              1. names matching an nxdomain glob   -> NXDOMAIN
              2. BLOCKED_NAME                      -> local sinkhole answer
              3. everything else                   -> upstream's answer,
                 or SERVFAIL if the upstream query times out

            Blocked names are decided before any upstream traffic, so a
            block costs no network round-trip and cannot degrade into
            SERVFAIL when the upstream is unreachable.
        """
        reply = request.reply()
        qname = request.q.qname
        qtype = QTYPE[request.q.qtype]

        print("QNAME label= " + str(qname) + "\n")

        # Check for NXDOMAIN
        if any(qname.matchGlob(label) for label in self.nxdomain):
            reply.header.rcode = getattr(RCODE,'NXDOMAIN')
            return reply

        # Blocked name: answer locally with the sinkhole address.
        if qname == self.BLOCKED_NAME:
            # add_answer() takes RR objects, not zone text, so the record
            # is parsed first; the owner name is the queried name so the
            # answer matches the question. Only A/ANY questions get the
            # A record - other types receive an empty (no data) reply.
            if qtype in ('A','ANY'):
                zone_line = "%s A %s" % (qname, self.SINKHOLE_ADDRESS)
                reply.add_answer(*RR.fromZone(zone_line,ttl=self.ttl))
            print("REPLY = " + str(reply))
            return reply

        # Send to upstream
        try:
            proxy_r = request.send(self.address,self.port,
                            tcp=(handler.protocol != 'udp'),
                            timeout=self.timeout)
            reply = DNSRecord.parse(proxy_r)
        except socket.timeout:
            # reply is still the empty local reply at this point
            reply.header.rcode = getattr(RCODE,'SERVFAIL')

        return reply

if __name__ == '__main__':

    import argparse
    import sys
    import time

    # Command-line options. They are parsed and handed to the resolver and
    # logger below; the listen address/port and the upstream server are
    # NOT configurable from the command line (see the constants further down).
    p = argparse.ArgumentParser(description="DNS Intercept Proxy, please ignore arguments and run it")
    p.add_argument("--intercept","-i",action="append",
                    metavar="<zone record>",
                    help="Intercept requests matching zone record (glob) ('-' for stdin)")
    p.add_argument("--skip","-s",action="append",
                    metavar="<label>",
                    help="Don't intercept matching label (glob)")
    p.add_argument("--nxdomain","-x",action="append",
                    metavar="<label>",
                    help="Return NXDOMAIN (glob)")
    p.add_argument("--forward","-f",action="append",
                   metavar="<label:dns server:port>",
                   help="forward requests matching label (glob) to dns server")
    p.add_argument("--ttl","-t",default="60s",
                    metavar="<ttl>",
                    help="Intercept TTL (default: 60s)")
    p.add_argument("--timeout","-o",type=float,default=5,
                    metavar="<timeout>",
                    help="Upstream timeout (default: 5s)")
    p.add_argument("--all-qtypes",action='store_true',default=False,
                   help="Return an empty response if qname matches, but qtype doesn't")
    p.add_argument("--log",default="request,reply,truncated,error",
                    help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)")
    p.add_argument("--log-prefix",action='store_true',default=False,
                    help="Log prefix (timestamp/handler/resolver) (default: False)")
    args = p.parse_args()

    # Fixed runtime settings.
    tcpEnabled = True
    listenAddress = ""                  # empty address: not tied to one interface
    port = 53                           # local listen port
    externalDNS = "***.***.***.***"     # upstream server (placeholder - set a real IP)
    externalDNSPort = 53

    resolver = InterceptResolver(address = externalDNS,
                                 port = externalDNSPort,
                                 # honour --ttl (its default is the former
                                 # hard-coded "60s")
                                 ttl = args.ttl,
                                 intercept = args.intercept or [],
                                 skip = args.skip or [],
                                 nxdomain = args.nxdomain or [],
                                 forward = args.forward or [],
                                 all_qtypes = args.all_qtypes,
                                 timeout = args.timeout)

    logger = variableVerboseDNSLogger(log = args.log,
                                      prefix = args.log_prefix,
                                      verbose = False,
                                      logToList = True)

    # Banner: local listen address -> upstream. The listen side previously
    # printed the upstream address, which misreported where the proxy binds.
    print("Starting Intercept Proxy (%s:%d -> %s:%d) [%s]" % (
                        listenAddress or "*",port,
                        externalDNS,externalDNSPort,
                        "UDP/TCP" if tcpEnabled else "UDP"))

    for rr in resolver.zone:
        print("    | ",rr[2].toZone(),sep="")
    if resolver.nxdomain:
        print("    NXDOMAIN:",", ".join(resolver.nxdomain))
    if resolver.skip:
        print("    Skipping:",", ".join(resolver.skip))
    if resolver.forward:
        print("    Forwarding:")
        for i in resolver.forward:
            print("    | ","%s:%s:%s" % i,sep="")

    # NOTE(review): DNSHandler.log is overwritten with an empty dict (the
    # former entries log_request / log_reply / log_truncated / log_error
    # were all commented out) - confirm that dnslib actually reads it.
    DNSHandler.log = {}

    udp_server = DNSServer(resolver,
                           port=port,
                           address=listenAddress,
                           logger=logger)
    udp_server.start_thread()

    if tcpEnabled:
        tcp_server = DNSServer(resolver,
                               port=port,
                               address=listenAddress,
                               tcp=True,
                               logger=logger)
        tcp_server.start_thread()

    # Keep the main thread alive while the UDP server thread runs.
    while udp_server.isAlive():
        time.sleep(1)
        # Drop the collected log lines once a second; presumably this keeps
        # the in-memory list from growing without bound.
        logger.currentLog = []

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM