Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/mssqlclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
parser.add_argument('-command', action='extend', nargs='*', help='Commands to execute in the SQL shell. Multiple commands can be passed.')
parser.add_argument('-file', type=argparse.FileType('r'), help='input file with commands to execute in the SQL shell')

parser.add_argument('--host-name', action='store', default='', help='HostName property to use when connecting to the MSSQLServer')
parser.add_argument('--app-name', action='store', default='', help='AppName property to use when connecting to the MSSQLServer')

group = parser.add_argument_group('authentication')

group.add_argument('-hashes', action="store", metavar = "LMHASH:NTHASH", help='NTLM hashes, format is LMHASH:NTHASH')
Expand Down Expand Up @@ -87,7 +90,7 @@
if options.aesKey is not None:
options.k = True

ms_sql = tds.MSSQL(options.target_ip, int(options.port), remoteName)
ms_sql = tds.MSSQL(options.target_ip, int(options.port), remoteName, workstation_id=options.host_name, application_name=options.app_name)
ms_sql.connect()
try:
if options.k is True:
Expand Down
28 changes: 21 additions & 7 deletions impacket/tds.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import math
import datetime
import string
from uuid import uuid4

from impacket import ntlm, uuid, LOG
from impacket.structure import Structure
Expand Down Expand Up @@ -469,7 +470,7 @@ class TDS_COLMETADATA(Structure):
)

class MSSQL:
def __init__(self, address, port=1433, remoteName = '', rowsPrinter=DummyPrint()):
def __init__(self, address, port=1433, remoteName = '', workstation_id: str = "", application_name: str = "", rowsPrinter=DummyPrint()):
#self.packetSize = 32764
self.packetSize = 32763
self.server = address
Expand All @@ -487,6 +488,9 @@ def __init__(self, address, port=1433, remoteName = '', rowsPrinter=DummyPrint()
self.__rowsPrinter = rowsPrinter
self.mssql_version = ""

self._workstation_id = workstation_id or f"DESKTOP-{uuid4().hex[:8].upper()}"
self._application_name = application_name or "Microsoft SQL Server Management Studio - Query"

# With Kerberos we need to know to which MSSQL instance we are going to connect (to compute the SPN)
# As such we need to be able to list these instances which is what this code does
def getInstances(self, timeout = 5):
Expand Down Expand Up @@ -550,9 +554,10 @@ def preLogin(self):
def encryptPassword(self, password ):
return bytes(bytearray([((x & 0x0f) << 4) + ((x & 0xf0) >> 4) ^ 0xa5 for x in bytearray(password)]))

def connect(self):
def connect(self, timeout=30):
af, socktype, proto, canonname, sa = socket.getaddrinfo(self.server, self.port, 0, socket.SOCK_STREAM)[0]
sock = socket.socket(af, socktype, proto)
sock.settimeout(timeout)

try:
sock.connect(sa)
Expand Down Expand Up @@ -808,8 +813,8 @@ def kerberosLogin(self, database, username, password='', domain='', hashes=None,
self.version["ProductMajorVersion"], self.version["ProductMinorVersion"], self.version["ProductBuild"] = 10, 0, 20348

login = TDS_LOGIN()
login['HostName'] = (''.join([random.choice(string.ascii_letters) for _ in range(8)])).encode('utf-16le')
login['AppName'] = (''.join([random.choice(string.ascii_letters) for _ in range(8)])).encode('utf-16le')
login['HostName'] = self.workstation_id.encode('utf-16le')
login['AppName'] = self.application_name.encode('utf-16le')
login['ServerName'] = self.remoteName.encode('utf-16le')
login['CltIntName'] = login['AppName']
login['ClientPID'] = random.randint(0,1024)
Expand Down Expand Up @@ -1013,8 +1018,8 @@ def login(self, database, username, password='', domain='', hashes = None, useWi
self.version["ProductMajorVersion"], self.version["ProductMinorVersion"], self.version["ProductBuild"] = 10, 0, 20348

login = TDS_LOGIN()
login['HostName'] = (''.join([random.choice(string.ascii_letters) for i in range(8)])).encode('utf-16le')
login['AppName'] = (''.join([random.choice(string.ascii_letters) for i in range(8)])).encode('utf-16le')
login['HostName'] = self.workstation_id.encode('utf-16le')
login['AppName'] = self.application_name.encode('utf-16le')
login['ServerName'] = self.remoteName.encode('utf-16le')
login['CltIntName'] = login['AppName']
login['ClientPID'] = random.randint(0,1024)
Expand Down Expand Up @@ -1713,4 +1718,13 @@ def RunSQLStatement(self,db,sql_query,wait=True,**kwArgs):
self.RunSQLQuery(db,sql_query,wait=wait)
if self.lastError:
raise self.lastError
return True
return True

# Properties
@property
def workstation_id(self):
return self._workstation_id

@property
def application_name(self):
return self._application_name