Code generate Swift wrappers for protocol buffers.

pull/1/head
Matthew Chen 6 years ago
parent 0cf199bd7e
commit e45a6d5bed

@ -74,6 +74,9 @@ class LineWriter:
def add(self, line):
self.lines.append(('\t' * self.indent()) + line)
def add_raw(self, line):
self.lines.append(line)
def extend(self, text):
for line in text.split('\n'):
self.add(line)
@ -88,13 +91,6 @@ class LineWriter:
lines = lines[:-1]
self.lines = lines
# def prefixed_name(self, proto_name):
# names = self.all_context_proto_names() + [proto_name,]
# return self.args.wrapper_prefix + '_'.join(names)
def is_top_level_entity(self):
return self.indent() == 0
def newline(self):
self.add('')
@ -278,6 +274,14 @@ class FileContext(BaseContext):
import Foundation
''')
writer.invalid_protobuf_error_name = '%sError' % self.args.wrapper_prefix
writer.extend(('''
public enum %s: Error {
case invalidProtobuf(description: String)
}
''' % writer.invalid_protobuf_error_name).strip())
writer.newline()
for child in self.children():
child.generate(writer)
@ -325,8 +329,6 @@ class MessageContext(BaseContext):
child.prepare()
def generate(self, writer):
is_top_level_entity = writer.is_top_level_entity()
writer.add('// MARK: - %s' % self.swift_name)
writer.newline()
@ -335,15 +337,6 @@ class MessageContext(BaseContext):
writer.push_context(self.proto_name, self.swift_name)
if is_top_level_entity:
writer.invalid_protobuf_error_name = '%sError' % self.swift_name
writer.extend(('''
public enum %s: Error {
case invalidProtobuf(description: String)
}
''' % writer.invalid_protobuf_error_name).strip())
writer.newline()
for child in self.children():
child.generate(writer)
@ -436,6 +429,13 @@ public func serializedData() throws -> Data {
writer.add('// MARK: - Begin Validation Logic for %s -' % self.swift_name)
writer.newline()
# Preserve existing validation logic.
validation_block = args.validation_map[self.swift_name]
if validation_block:
writer.add_raw(validation_block)
writer.newline()
writer.add('// MARK: - End Validation Logic for %s -' % self.swift_name)
writer.newline()
@ -449,63 +449,6 @@ public func serializedData() throws -> Data {
writer.add('}')
writer.newline()
# @objc
# public init(serializedData: Data) throws {
#
#
# guard proto.hasSource else {
# throw EnvelopeError.invalidProtobuf(description: "missing required field: source")
# }
# self.source = proto.source
#
# guard proto.hasType else {
# throw EnvelopeError.invalidProtobuf(description: "missing required field: type")
# }
# self.type = {
# switch proto.type {
# case .unknown:
# return .unknown
# case .ciphertext:
# return .ciphertext
# case .keyExchange:
# return .keyExchange
# case .prekeyBundle:
# return .prekeyBundle
# case .receipt:
# return .receipt
# }
# }()
#
# guard proto.hasTimestamp else {
# throw EnvelopeError.invalidProtobuf(description: "missing required field: timestamp")
# }
# self.timestamp = proto.timestamp
#
# guard proto.hasSourceDevice else {
# throw EnvelopeError.invalidProtobuf(description: "missing required field: sourceDevice")
# }
# self.sourceDevice = proto.sourceDevice
#
# if proto.hasContent {
# self.content = proto.content
# } else {
# self.content = nil
# }
#
# if proto.hasLegacyMessage {
# self.legacyMessage = proto.legacyMessage
# } else {
# self.legacyMessage = nil
# }
#
# if proto.relay.count > 0 {
# self.relay = proto.relay
# } else {
# relay = nil
# }
# }
#
# asProtobuf() func
writer.add('fileprivate var asProtobuf: %s {' % wrapped_swift_name)
writer.push_indent()
@ -677,8 +620,8 @@ def line_parser(text):
def parse_enum(args, proto_file_path, parser, parent_context, enum_name):
if args.verbose:
print '# enum:', enum_name
# if args.verbose:
# print '# enum:', enum_name
context = EnumContext(args, parent_context, enum_name)
@ -700,8 +643,8 @@ def parse_enum(args, proto_file_path, parser, parent_context, enum_name):
item_name = item_match.group(1).strip()
item_index = item_match.group(2).strip()
if args.verbose:
print '\t enum item[%s]: %s' % (item_index, item_name)
# if args.verbose:
# print '\t enum item[%s]: %s' % (item_index, item_name)
if item_name in context.item_names():
raise Exception('Duplicate enum name[%s]: %s' % (proto_file_path, item_name))
@ -725,8 +668,8 @@ def optional_match_group(match, index):
def parse_message(args, proto_file_path, parser, parent_context, message_name):
if args.verbose:
print '# message:', message_name
# if args.verbose:
# print '# message:', message_name
context = MessageContext(args, parent_context, message_name)
@ -786,8 +729,8 @@ def parse_message(args, proto_file_path, parser, parent_context, message_name):
}
# print 'message_field:', message_field
if args.verbose:
print '\t message field[%s]: %s' % (item_index, str(message_field))
# if args.verbose:
# print '\t message field[%s]: %s' % (item_index, str(message_field))
if item_name in context.field_names():
raise Exception('Duplicate message field name[%s]: %s' % (proto_file_path, item_name))
@ -803,7 +746,45 @@ def parse_message(args, proto_file_path, parser, parent_context, message_name):
raise Exception('Invalid message syntax[%s]: %s' % (proto_file_path, line))
def preserve_validation_logic(args, proto_file_path, dst_file_path):
args.validation_map = {}
if os.path.exists(dst_file_path):
with open(dst_file_path, 'rt') as f:
old_text = f.read()
validation_start_regex = re.compile(r'// MARK: - Begin Validation Logic for ([^ ]+) -')
for match in validation_start_regex.finditer(old_text):
# print 'match'
name = match.group(1)
# print '\t name:', name
start = match.end(0)
# print '\t start:', start
end_marker = '// MARK: - End Validation Logic for %s -' % name
end = old_text.find(end_marker)
# print '\t end:', end
if end < start:
raise Exception('Malformed validation: %s' % proto_file_path)
validation_block = old_text[start:end]
# print '\t validation_block:', validation_block
# Strip trailing whitespace.
validation_lines = validation_block.split('\n')
validation_lines = [line.rstrip() for line in validation_lines]
# Strip leading empty lines.
while len(validation_lines) > 0 and validation_lines[0] == '':
validation_lines = validation_lines[1:]
# Strip trailing empty lines.
while len(validation_lines) > 0 and validation_lines[-1] == '':
validation_lines = validation_lines[:-1]
validation_block = '\n'.join(validation_lines)
if len(validation_block) > 0:
if args.verbose:
print 'Preserving validation logic for:', name
args.validation_map[name] = validation_block
def process_proto_file(args, proto_file_path, dst_file_path):
with open(proto_file_path, 'rt') as f:
text = f.read()
@ -856,6 +837,8 @@ def process_proto_file(args, proto_file_path, dst_file_path):
raise Exception('Invalid syntax[%s]: %s' % (proto_file_path, line))
preserve_validation_logic(args, proto_file_path, dst_file_path)
writer = LineWriter(args)
context.prepare()
context.generate(writer)
@ -896,5 +879,5 @@ if __name__ == "__main__":
args.package = None
process_proto_file(args, proto_file_path, dst_file_path)
print 'complete.'
# print 'complete.'

@ -4,14 +4,14 @@
import Foundation
public enum SSKProtoError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoEnvelope
@objc public class SSKProtoEnvelope: NSObject {
public enum SSKProtoEnvelopeError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoEnvelope_Type
@objc public enum SSKProtoEnvelope_Type: Int32 {
@ -108,6 +108,19 @@ import Foundation
// MARK: - Begin Validation Logic for SSKProtoEnvelope -
guard proto.hasSource else {
throw SSKProtoError.invalidProtobuf(description: "missing required field: source")
}
guard proto.hasType else {
throw SSKProtoError.invalidProtobuf(description: "missing required field: type")
}
guard proto.hasTimestamp else {
throw SSKProtoError.invalidProtobuf(description: "missing required field: timestamp")
}
guard proto.hasSourceDevice else {
throw SSKProtoError.invalidProtobuf(description: "missing required field: sourceDevice")
}
// MARK: - End Validation Logic for SSKProtoEnvelope -
let result = SSKProtoEnvelope(type: type, relay: relay, source: source, timestamp: timestamp, sourceDevice: sourceDevice, legacyMessage: legacyMessage, content: content)
@ -147,10 +160,6 @@ import Foundation
@objc public class SSKProtoContent: NSObject {
public enum SSKProtoContentError: Error {
case invalidProtobuf(description: String)
}
@objc public let dataMessage: SSKProtoDataMessage?
@objc public let callMessage: SSKProtoCallMessage?
@objc public let syncMessage: SSKProtoSyncMessage?
@ -240,10 +249,6 @@ import Foundation
@objc public class SSKProtoCallMessage: NSObject {
public enum SSKProtoCallMessageError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoCallMessage_Offer
@objc public class SSKProtoCallMessage_Offer: NSObject {
@ -615,10 +620,6 @@ import Foundation
@objc public class SSKProtoDataMessage: NSObject {
public enum SSKProtoDataMessageError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoDataMessage_Flags
@objc public enum SSKProtoDataMessage_Flags: Int32 {
@ -1551,10 +1552,6 @@ import Foundation
@objc public class SSKProtoNullMessage: NSObject {
public enum SSKProtoNullMessageError: Error {
case invalidProtobuf(description: String)
}
@objc public let padding: Data?
@objc public init(padding: Data?) {
@ -1600,10 +1597,6 @@ import Foundation
@objc public class SSKProtoReceiptMessage: NSObject {
public enum SSKProtoReceiptMessageError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoReceiptMessage_Type
@objc public enum SSKProtoReceiptMessage_Type: Int32 {
@ -1682,10 +1675,6 @@ import Foundation
@objc public class SSKProtoVerified: NSObject {
public enum SSKProtoVerifiedError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoVerified_State
@objc public enum SSKProtoVerified_State: Int32 {
@ -1786,10 +1775,6 @@ import Foundation
@objc public class SSKProtoSyncMessage: NSObject {
public enum SSKProtoSyncMessageError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoSyncMessage_Sent
@objc public class SSKProtoSyncMessage_Sent: NSObject {
@ -2317,10 +2302,6 @@ import Foundation
@objc public class SSKProtoAttachmentPointer: NSObject {
public enum SSKProtoAttachmentPointerError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoAttachmentPointer_Flags
@objc public enum SSKProtoAttachmentPointer_Flags: Int32 {
@ -2473,10 +2454,6 @@ import Foundation
@objc public class SSKProtoGroupContext: NSObject {
public enum SSKProtoGroupContextError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoGroupContext_Type
@objc public enum SSKProtoGroupContext_Type: Int32 {
@ -2597,10 +2574,6 @@ import Foundation
@objc public class SSKProtoContactDetails: NSObject {
public enum SSKProtoContactDetailsError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoContactDetails_Avatar
@objc public class SSKProtoContactDetails_Avatar: NSObject {
@ -2773,10 +2746,6 @@ import Foundation
@objc public class SSKProtoGroupDetails: NSObject {
public enum SSKProtoGroupDetailsError: Error {
case invalidProtobuf(description: String)
}
// MARK: - SSKProtoGroupDetails_Avatar
@objc public class SSKProtoGroupDetails_Avatar: NSObject {

Loading…
Cancel
Save