From e45a6d5bed44f4cbc9e72fb382b4a008e3475b0d Mon Sep 17 00:00:00 2001 From: Matthew Chen Date: Wed, 1 Aug 2018 10:23:59 -0400 Subject: [PATCH] Code generate Swift wrappers for protocol buffers. --- Scripts/ProtoWrappers.py | 153 ++++++++---------- .../src/Protos/Generated/SSKProto.swift | 65 ++------ 2 files changed, 85 insertions(+), 133 deletions(-) diff --git a/Scripts/ProtoWrappers.py b/Scripts/ProtoWrappers.py index 4cfd253e5..32cde58ed 100755 --- a/Scripts/ProtoWrappers.py +++ b/Scripts/ProtoWrappers.py @@ -74,6 +74,9 @@ class LineWriter: def add(self, line): self.lines.append(('\t' * self.indent()) + line) + def add_raw(self, line): + self.lines.append(line) + def extend(self, text): for line in text.split('\n'): self.add(line) @@ -88,13 +91,6 @@ class LineWriter: lines = lines[:-1] self.lines = lines - # def prefixed_name(self, proto_name): - # names = self.all_context_proto_names() + [proto_name,] - # return self.args.wrapper_prefix + '_'.join(names) - - def is_top_level_entity(self): - return self.indent() == 0 - def newline(self): self.add('') @@ -278,6 +274,14 @@ class FileContext(BaseContext): import Foundation ''') + + writer.invalid_protobuf_error_name = '%sError' % self.args.wrapper_prefix + writer.extend((''' +public enum %s: Error { + case invalidProtobuf(description: String) +} +''' % writer.invalid_protobuf_error_name).strip()) + writer.newline() for child in self.children(): child.generate(writer) @@ -325,8 +329,6 @@ class MessageContext(BaseContext): child.prepare() def generate(self, writer): - is_top_level_entity = writer.is_top_level_entity() - writer.add('// MARK: - %s' % self.swift_name) writer.newline() @@ -335,15 +337,6 @@ class MessageContext(BaseContext): writer.push_context(self.proto_name, self.swift_name) - if is_top_level_entity: - writer.invalid_protobuf_error_name = '%sError' % self.swift_name - writer.extend((''' -public enum %s: Error { - case invalidProtobuf(description: String) -} -''' % writer.invalid_protobuf_error_name).strip()) - writer.newline() - for child in self.children(): child.generate(writer) @@ -436,6 +429,13 @@ public func serializedData() throws -> Data { writer.add('// MARK: - Begin Validation Logic for %s -' % self.swift_name) writer.newline() + + # Preserve existing validation logic. + validation_block = args.validation_map[self.swift_name] + if validation_block: + writer.add_raw(validation_block) + writer.newline() + writer.add('// MARK: - End Validation Logic for %s -' % self.swift_name) writer.newline() @@ -449,63 +449,6 @@ public func serializedData() throws -> Data { writer.add('}') writer.newline() - # @objc - # public init(serializedData: Data) throws { - # - # - # guard proto.hasSource else { - # throw EnvelopeError.invalidProtobuf(description: "missing required field: source") - # } - # self.source = proto.source - # - # guard proto.hasType else { - # throw EnvelopeError.invalidProtobuf(description: "missing required field: type") - # } - # self.type = { - # switch proto.type { - # case .unknown: - # return .unknown - # case .ciphertext: - # return .ciphertext - # case .keyExchange: - # return .keyExchange - # case .prekeyBundle: - # return .prekeyBundle - # case .receipt: - # return .receipt - # } - # }() - # - # guard proto.hasTimestamp else { - # throw EnvelopeError.invalidProtobuf(description: "missing required field: timestamp") - # } - # self.timestamp = proto.timestamp - # - # guard proto.hasSourceDevice else { - # throw EnvelopeError.invalidProtobuf(description: "missing required field: sourceDevice") - # } - # self.sourceDevice = proto.sourceDevice - # - # if proto.hasContent { - # self.content = proto.content - # } else { - # self.content = nil - # } - # - # if proto.hasLegacyMessage { - # self.legacyMessage = proto.legacyMessage - # } else { - # self.legacyMessage = nil - # } - # - # if proto.relay.count > 0 { - # self.relay = proto.relay - # } else { - # relay = nil - # } - # } - # - # asProtobuf() func writer.add('fileprivate var asProtobuf: %s {' % wrapped_swift_name) writer.push_indent() @@ -677,8 +620,8 @@ def line_parser(text): def parse_enum(args, proto_file_path, parser, parent_context, enum_name): - if args.verbose: - print '# enum:', enum_name + # if args.verbose: + # print '# enum:', enum_name context = EnumContext(args, parent_context, enum_name) @@ -700,8 +643,8 @@ def parse_enum(args, proto_file_path, parser, parent_context, enum_name): item_name = item_match.group(1).strip() item_index = item_match.group(2).strip() - if args.verbose: - print '\t enum item[%s]: %s' % (item_index, item_name) + # if args.verbose: + # print '\t enum item[%s]: %s' % (item_index, item_name) if item_name in context.item_names(): raise Exception('Duplicate enum name[%s]: %s' % (proto_file_path, item_name)) @@ -725,8 +668,8 @@ def optional_match_group(match, index): def parse_message(args, proto_file_path, parser, parent_context, message_name): - if args.verbose: - print '# message:', message_name + # if args.verbose: + # print '# message:', message_name context = MessageContext(args, parent_context, message_name) @@ -786,8 +729,8 @@ def parse_message(args, proto_file_path, parser, parent_context, message_name): } # print 'message_field:', message_field - if args.verbose: - print '\t message field[%s]: %s' % (item_index, str(message_field)) + # if args.verbose: + # print '\t message field[%s]: %s' % (item_index, str(message_field)) if item_name in context.field_names(): raise Exception('Duplicate message field name[%s]: %s' % (proto_file_path, item_name)) @@ -803,7 +746,45 @@ def parse_message(args, proto_file_path, parser, parent_context, message_name): raise Exception('Invalid message syntax[%s]: %s' % (proto_file_path, line)) - + +def preserve_validation_logic(args, proto_file_path, dst_file_path): + args.validation_map = {} + if os.path.exists(dst_file_path): + with open(dst_file_path, 'rt') as f: + old_text = f.read() + validation_start_regex = re.compile(r'// MARK: - Begin Validation Logic for ([^ ]+) -') + for match in validation_start_regex.finditer(old_text): + # print 'match' + name = match.group(1) + # print '\t name:', name + start = match.end(0) + # print '\t start:', start + end_marker = '// MARK: - End Validation Logic for %s -' % name + end = old_text.find(end_marker) + # print '\t end:', end + if end < start: + raise Exception('Malformed validation: %s' % proto_file_path) + validation_block = old_text[start:end] + # print '\t validation_block:', validation_block + + # Strip trailing whitespace. + validation_lines = validation_block.split('\n') + validation_lines = [line.rstrip() for line in validation_lines] + # Strip leading empty lines. + while len(validation_lines) > 0 and validation_lines[0] == '': + validation_lines = validation_lines[1:] + # Strip trailing empty lines. + while len(validation_lines) > 0 and validation_lines[-1] == '': + validation_lines = validation_lines[:-1] + validation_block = '\n'.join(validation_lines) + + if len(validation_block) > 0: + if args.verbose: + print 'Preserving validation logic for:', name + + args.validation_map[name] = validation_block + + def process_proto_file(args, proto_file_path, dst_file_path): with open(proto_file_path, 'rt') as f: text = f.read() @@ -856,6 +837,8 @@ def process_proto_file(args, proto_file_path, dst_file_path): raise Exception('Invalid syntax[%s]: %s' % (proto_file_path, line)) + preserve_validation_logic(args, proto_file_path, dst_file_path) + writer = LineWriter(args) context.prepare() context.generate(writer) @@ -896,5 +879,5 @@ if __name__ == "__main__": args.package = None process_proto_file(args, proto_file_path, dst_file_path) - print 'complete.' + # print 'complete.' \ No newline at end of file diff --git a/SignalServiceKit/src/Protos/Generated/SSKProto.swift b/SignalServiceKit/src/Protos/Generated/SSKProto.swift index c9dcb67f9..b7e64fed7 100644 --- a/SignalServiceKit/src/Protos/Generated/SSKProto.swift +++ b/SignalServiceKit/src/Protos/Generated/SSKProto.swift @@ -4,14 +4,14 @@ import Foundation +public enum SSKProtoError: Error { + case invalidProtobuf(description: String) +} + // MARK: - SSKProtoEnvelope @objc public class SSKProtoEnvelope: NSObject { - public enum SSKProtoEnvelopeError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoEnvelope_Type @objc public enum SSKProtoEnvelope_Type: Int32 { @@ -108,6 +108,19 @@ import Foundation // MARK: - Begin Validation Logic for SSKProtoEnvelope - + guard proto.hasSource else { + throw SSKProtoError.invalidProtobuf(description: "missing required field: source") + } + guard proto.hasType else { + throw SSKProtoError.invalidProtobuf(description: "missing required field: type") + } + guard proto.hasTimestamp else { + throw SSKProtoError.invalidProtobuf(description: "missing required field: timestamp") + } + guard proto.hasSourceDevice else { + throw SSKProtoError.invalidProtobuf(description: "missing required field: sourceDevice") + } + // MARK: - End Validation Logic for SSKProtoEnvelope - let result = SSKProtoEnvelope(type: type, relay: relay, source: source, timestamp: timestamp, sourceDevice: sourceDevice, legacyMessage: legacyMessage, content: content) @@ -147,10 +160,6 @@ import Foundation @objc public class SSKProtoContent: NSObject { - public enum SSKProtoContentError: Error { - case invalidProtobuf(description: String) - } - @objc public let dataMessage: SSKProtoDataMessage? @objc public let callMessage: SSKProtoCallMessage? @objc public let syncMessage: SSKProtoSyncMessage? @@ -240,10 +249,6 @@ import Foundation @objc public class SSKProtoCallMessage: NSObject { - public enum SSKProtoCallMessageError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoCallMessage_Offer @objc public class SSKProtoCallMessage_Offer: NSObject { @@ -615,10 +620,6 @@ import Foundation @objc public class SSKProtoDataMessage: NSObject { - public enum SSKProtoDataMessageError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoDataMessage_Flags @objc public enum SSKProtoDataMessage_Flags: Int32 { @@ -1551,10 +1552,6 @@ import Foundation @objc public class SSKProtoNullMessage: NSObject { - public enum SSKProtoNullMessageError: Error { - case invalidProtobuf(description: String) - } - @objc public let padding: Data? @objc public init(padding: Data?) { @@ -1600,10 +1597,6 @@ import Foundation @objc public class SSKProtoReceiptMessage: NSObject { - public enum SSKProtoReceiptMessageError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoReceiptMessage_Type @objc public enum SSKProtoReceiptMessage_Type: Int32 { @@ -1682,10 +1675,6 @@ import Foundation @objc public class SSKProtoVerified: NSObject { - public enum SSKProtoVerifiedError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoVerified_State @objc public enum SSKProtoVerified_State: Int32 { @@ -1786,10 +1775,6 @@ import Foundation @objc public class SSKProtoSyncMessage: NSObject { - public enum SSKProtoSyncMessageError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoSyncMessage_Sent @objc public class SSKProtoSyncMessage_Sent: NSObject { @@ -2317,10 +2302,6 @@ import Foundation @objc public class SSKProtoAttachmentPointer: NSObject { - public enum SSKProtoAttachmentPointerError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoAttachmentPointer_Flags @objc public enum SSKProtoAttachmentPointer_Flags: Int32 { @@ -2473,10 +2454,6 @@ import Foundation @objc public class SSKProtoGroupContext: NSObject { - public enum SSKProtoGroupContextError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoGroupContext_Type @objc public enum SSKProtoGroupContext_Type: Int32 { @@ -2597,10 +2574,6 @@ import Foundation @objc public class SSKProtoContactDetails: NSObject { - public enum SSKProtoContactDetailsError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoContactDetails_Avatar @objc public class SSKProtoContactDetails_Avatar: NSObject { @@ -2773,10 +2746,6 @@ import Foundation @objc public class SSKProtoGroupDetails: NSObject { - public enum SSKProtoGroupDetailsError: Error { - case invalidProtobuf(description: String) - } - // MARK: - SSKProtoGroupDetails_Avatar @objc public class SSKProtoGroupDetails_Avatar: NSObject {