Commit 3d890f23 authored by Joel Sommers's avatar Joel Sommers
Browse files

Bugfix issue 4 + additional tests to cover the situation

parent 2c6c2d1e
......@@ -35,6 +35,11 @@ class ICMP(PacketHeaderBase):
self._code = self._valid_codes_map[self._type].EchoRequest
self._icmpdata = ICMPEchoRequest()
self._checksum = 0
# because of dependencies between type/code, must ensure that
# if icmptype is a keyword arg it gets set *first*
if 'icmptype' in kwargs:
self.icmptype = kwargs['icmptype']
del kwargs['icmptype']
super().__init__(**kwargs)
def size(self):
......@@ -78,9 +83,12 @@ class ICMP(PacketHeaderBase):
return self._code
@icmptype.setter
def icmptype(self,value):
def icmptype(self, value):
if not isinstance(value, self._valid_types):
raise ValueError("ICMP type must be an {} enumeration".format(type(self._valid_types)))
value = self._valid_types(value)
# JS: revised following line as above; too restrictive
# raise ValueError("ICMP type must be an {} enumeration".format(type(self._valid_types)))
cls = self._classtype_from_icmptype(value)
if not issubclass(self.icmpdata.__class__, cls):
self.icmpdata = cls()
......@@ -95,12 +103,16 @@ class ICMP(PacketHeaderBase):
def icmpcode(self,value):
if issubclass(value.__class__, IntEnum):
validcodes = self._valid_codes_map[self._type]
if value not in validcodes:
raise ValueError("Invalid code {} for type {}".format(value, self._type))
self._check_typecode_consistency(value)
self._code = value
elif isinstance(value, int):
self._code = self._valid_codes_map[self.icmptype](value)
def _check_typecode_consistency(self, xcode):
validcodes = self._valid_codes_map[self._type]
if xcode not in validcodes:
raise ValueError("Invalid code {} for type {}".format(xcode, self._type.name, self._type))
def __str__(self):
typecode = self.icmptype.name
if self.icmptype.name != self.icmpcode.name:
......
......@@ -21,6 +21,11 @@ class ICMPv6(ICMP):
self._code = self._valid_codes_map[self._type].EchoRequest
self._icmpdata = ICMPv6ClassFromType(self._type)()
self._checksum = 0
# if kwargs are given, must ensure that type gets set
# before code due to dependencies on validity.
if 'icmptype' in kwargs:
self.icmptype = kwargs['icmptype']
del kwargs['icmptype']
super().__init__(**kwargs)
def checksum(self):
......
......@@ -4,10 +4,16 @@ from switchyard.lib.packet.common import ICMPType
import unittest
class ICMPPacketTests(unittest.TestCase):
def testBadCode(self):
def testBadTypeCode(self):
i = ICMP()
with self.assertRaises(ValueError):
i.icmptype = 0
i.icmptype = 2
with self.assertRaises(ValueError):
i.icmptype = 19
with self.assertRaises(ValueError):
i.icmptype = 49
with self.assertRaises(ValueError):
i.icmpcode = ICMPType.EchoRequest
......@@ -15,6 +21,10 @@ class ICMPPacketTests(unittest.TestCase):
with self.assertRaises(ValueError):
i.icmpcode = 1
i.icmptype = 0 # echo reply; any code other than 0 is invalid
with self.assertRaises(ValueError):
i.icmpcode = 1
def testChangeICMPIdentity(self):
i = ICMP() # echorequest, by default
i.icmptype = ICMPType.EchoReply
......@@ -122,6 +132,42 @@ class ICMPPacketTests(unittest.TestCase):
self.assertEqual(i.icmptype, ICMPType.SourceQuench)
self.assertIsInstance(i.icmpdata, ICMPSourceQuench)
# not enough bytes
with self.assertRaises(Exception):
i.from_bytes(b'\x04\x00\xfb\xff\x00\x00\x00')
def testUnreachableMtu(self):
i = ICMP()
i.icmptype = ICMPType.DestinationUnreachable
i.icmpdata.nexthopmtu = 5
i.icmpdata.origdgramlen = 42
self.assertEqual(i.to_bytes()[-1], 5)
self.assertEqual(i.icmpdata.origdgramlen, 42)
def testICMPKwArgsValid(self):
icmptype = ICMPType.DestinationUnreachable
# valid combination
i = ICMP(icmptype=icmptype, icmpcode=ICMPTypeCodeMap[icmptype].NetworkUnreachable)
self.assertIsInstance(i.icmpdata, ICMPDestinationUnreachable)
i2 = ICMP()
i2.from_bytes(i.to_bytes())
self.assertIsInstance(i2.icmpdata, ICMPDestinationUnreachable)
self.assertEqual(i2.icmptype, icmptype)
self.assertEqual(i2.icmpcode, ICMPTypeCodeMap[icmptype].NetworkUnreachable)
def testICMPKwArgsInvalid1(self):
with self.assertRaises(ValueError):
i = ICMP(icmptype=0, icmpcode=45)
def testICMPKwArgsInvalid2(self):
with self.assertRaises(ValueError):
i = ICMP(icmptype=ICMPType.EchoRequest, icmpcode=ICMPTypeCodeMap[ICMPType.DestinationUnreachable].CommunicationAdministrativelyProhibited)
def testStringify(self):
i = ICMP(icmptype=3, icmpcode=8)
s = str(i)
self.assertTrue(s.startswith('ICMP DestinationUnreachable:SourceHostIsolated'))
if __name__ == '__main__':
unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment