Skip to content

Commit

Permalink
Merge PR #123
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengchun committed Feb 3, 2025
2 parents a5f9024 + 121dfc1 commit 15733e6
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 78 deletions.
60 changes: 35 additions & 25 deletions cached_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,48 @@ import (
)

type cachedReader struct {
buffer *bufio.Reader
cache []byte
cacheCap int
cacheLen int
buffer *bufio.Reader
cache []byte
caching bool
}

func newCachedReader(r *bufio.Reader) *cachedReader {
return &cachedReader{
buffer: r,
cache: make([]byte, 4096),
cacheCap: 4096,
cacheLen: 0,
caching: false,
buffer: r,
cache: make([]byte, 0, 4096),
caching: false,
}
}

func (c *cachedReader) StartCaching() {
c.cacheLen = 0
c.cache = c.cache[:0]
c.caching = true
}

func (c *cachedReader) ReadByte() (byte, error) {
if !c.caching {
return c.buffer.ReadByte()
}
b, err := c.buffer.ReadByte()
func (c *cachedReader) ReadByte() (b byte, err error) {
b, err = c.buffer.ReadByte()
if err != nil {
return b, err
return
}
if c.cacheLen < c.cacheCap {
c.cache[c.cacheLen] = b
c.cacheLen++
if c.caching {
c.cacheByte(b)
}
return b, err
return
}

func (c *cachedReader) Cache() []byte {
return c.cache[:c.cacheLen]
return c.cache
}

// CacheWithLimit returns at most the first n cached bytes.
//
// It returns nil when n is less than 1. When n is larger than the number of
// bytes currently cached, all cached bytes are returned.
//
// The result aliases the cache's backing array rather than copying it, so
// callers that need the bytes after further caching should copy them. The
// capacity of the result is clipped to its length so that an append by the
// caller allocates a fresh array instead of overwriting bytes held in the
// cache.
func (c *cachedReader) CacheWithLimit(n int) []byte {
	if n < 1 {
		return nil
	}
	if l := len(c.cache); n > l {
		n = l
	}
	// Three-index slice: len == cap == n, so the cache's spare capacity is
	// not reachable through the returned slice.
	return c.cache[:n:n]
}

func (c *cachedReader) StopCaching() {
Expand All @@ -55,15 +58,22 @@ func (c *cachedReader) Read(p []byte) (int, error) {
if err != nil {
return n, err
}
if c.caching && c.cacheLen < c.cacheCap {
if c.caching {
for i := 0; i < n; i++ {
c.cache[c.cacheLen] = p[i]
c.cacheLen++
if c.cacheLen >= c.cacheCap {
if !c.cacheByte(p[i]) {
break
}
}
}
return n, err
}

// cacheByte stores b at the end of the cache and reports whether it was
// stored. The cache never grows beyond its existing capacity: once it is
// full the byte is dropped and false is returned.
func (c *cachedReader) cacheByte(b byte) bool {
	if len(c.cache) == cap(c.cache) {
		return false
	}
	// There is spare capacity, so append reuses the existing backing array.
	c.cache = append(c.cache, b)
	return true
}
15 changes: 15 additions & 0 deletions cached_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,19 @@ func TestCaching(t *testing.T) {
if !bytes.Equal(cached, []byte("BCDEF")) {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(-1)
if cached != nil {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(3)
if !bytes.Equal(cached, []byte("BCD")) {
t.Fatalf("Incorrect cached buffer value")
}

cached = cachedReader.CacheWithLimit(1000)
if !bytes.Equal(cached, []byte("BCDEF")) {
t.Fatalf("Incorrect cached buffer value")
}
}
42 changes: 36 additions & 6 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ func WithPreserveSpace() OutputOption {
}
}

// WithoutPreserveSpace returns an OutputOption that sets the preserveSpaces
// flag of the output configuration to false.
func WithoutPreserveSpace() OutputOption {
	disable := func(cfg *outputConfiguration) {
		cfg.preserveSpaces = false
	}
	return disable
}

// WithIndentation sets the indentation string used for formatting the output.
func WithIndentation(indentation string) OutputOption {
return func(oc *outputConfiguration) {
Expand Down Expand Up @@ -328,7 +335,9 @@ func (n *Node) Write(writer io.Writer, self bool) error {

// WriteWithOptions writes xml with given options to given writer.
func (n *Node) WriteWithOptions(writer io.Writer, opts ...OutputOption) (err error) {
config := &outputConfiguration{}
config := &outputConfiguration{
preserveSpaces: true,
}
// Set the options
for _, opt := range opts {
opt(config)
Expand Down Expand Up @@ -400,11 +409,7 @@ func AddChild(parent, n *Node) {
parent.LastChild = n
}

// AddSibling adds a new node 'n' as a sibling of a given node 'sibling'.
// Note it is not necessarily true that the new node 'n' would be added
// immediately after 'sibling'. If 'sibling' isn't the last child of its
// parent, then the new node 'n' will be added at the end of the sibling
// chain of their parent.
// AddSibling adds a new node 'n' as the last node of the sibling chain of a given node 'sibling'.
func AddSibling(sibling, n *Node) {
for t := sibling.NextSibling; t != nil; t = t.NextSibling {
sibling = t
Expand All @@ -418,6 +423,19 @@ func AddSibling(sibling, n *Node) {
}
}

// AddImmediateSibling inserts the node 'n' directly after the given node
// 'sibling'.
//
// 'n' takes over sibling's parent and sibling's former next sibling. If
// 'sibling' had no next sibling and has a parent, that parent's LastChild
// is updated to point at 'n'.
func AddImmediateSibling(sibling, n *Node) {
	// Hook n in between sibling and whatever followed it.
	n.Parent = sibling.Parent
	n.PrevSibling = sibling
	n.NextSibling = sibling.NextSibling
	sibling.NextSibling = n

	switch {
	case n.NextSibling != nil:
		// Inserted in the middle of the chain: fix the back link.
		n.NextSibling.PrevSibling = n
	case n.Parent != nil:
		// Inserted at the end of the chain: n is the new last child.
		n.Parent.LastChild = n
	}
}

// RemoveFromTree removes a node and its subtree from the document
// tree it is in. If the node is the root of the tree, then it's no-op.
func RemoveFromTree(n *Node) {
Expand Down Expand Up @@ -445,3 +463,15 @@ func RemoveFromTree(n *Node) {
n.PrevSibling = nil
n.NextSibling = nil
}

// GetRoot returns the root of the tree that contains 'n', found by
// following Parent links until a node without a parent is reached.
// It returns nil when 'n' is nil.
func GetRoot(n *Node) *Node {
	var top *Node
	for cur := n; cur != nil; cur = cur.Parent {
		top = cur
	}
	return top
}
48 changes: 36 additions & 12 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -260,7 +260,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -270,7 +270,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ggg></ggg></aaa>`)
})

Expand All @@ -280,7 +280,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd></aaa>`)
})

Expand All @@ -290,7 +290,7 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, procInst.Type, DeclarationNode)
RemoveFromTree(procInst)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

Expand All @@ -300,19 +300,44 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, commentNode.Type, CommentNode)
RemoveFromTree(commentNode)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})

t.Run("remove call on root does nothing", func(t *testing.T) {
doc := parseXML()
RemoveFromTree(doc)
verifyNodePointers(t, doc)
testValue(t, doc.OutputXML(false),
testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
`<?procinst?><!--comment--><aaa><bbb></bbb><ddd><eee><fff></fff></eee></ddd><ggg></ggg></aaa>`)
})
}

// TestAddImmediateSibling checks that AddImmediateSibling places the new
// node directly after the chosen node rather than at the end of the
// sibling chain.
func TestAddImmediateSibling(t *testing.T) {
	s := `<?xml version="1.0" encoding="UTF-8"?>
<AAA>
<BBB id="1"/>
<CCC id="2">
<DDD/>
</CCC>
<CCC id="3">
<DDD/>
</CCC>
</AAA>`
	root, err := Parse(strings.NewReader(s))
	if err != nil {
		// Nothing below is meaningful without a parsed document.
		t.Fatalf("parse failed: %v", err)
	}

	aaa := findNode(root, "AAA")
	if aaa == nil {
		t.Fatalf("aaa is nil")
	}
	n := aaa.SelectElement("BBB")
	if n == nil {
		t.Fatalf("n is nil")
	}
	AddImmediateSibling(n, &Node{Type: ElementNode, Data: "r"})
	testValue(t, root.OutputXMLWithOptions(WithoutPreserveSpace()), `<?xml version="1.0" encoding="UTF-8"?><AAA><BBB id="1"></BBB><r></r><CCC id="2"><DDD></DDD></CCC><CCC id="3"><DDD></DDD></CCC></AAA>`)
}

func TestSelectElement(t *testing.T) {
s := `<?xml version="1.0" encoding="UTF-8"?>
<AAA>
Expand Down Expand Up @@ -497,7 +522,6 @@ func TestWriteWithNamespacePrefix(t *testing.T) {
}
}


func TestQueryWithPrefix(t *testing.T) {
s := `<?xml version="1.0" encoding="UTF-8"?><S:Envelope xmlns:S="http://schemas.xmlsoap.org/soap/envelope/"><S:Body test="1"><ns2:Fault xmlns:ns2="http://schemas.xmlsoap.org/soap/envelope/" xmlns:ns3="http://www.w3.org/2003/05/soap-envelope"><faultcode>ns2:Client</faultcode><faultstring>This is a client fault</faultstring></ns2:Fault></S:Body></S:Envelope>`
doc, _ := Parse(strings.NewReader(s))
Expand Down Expand Up @@ -582,7 +606,7 @@ func TestOutputXMLWithSpaceDirect(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}

output := html.UnescapeString(doc.OutputXML(true))
output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
Expand All @@ -606,7 +630,7 @@ func TestOutputXMLWithSpaceOverwrittenToPreserve(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}

output := html.UnescapeString(doc.OutputXML(true))
output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
Expand Down Expand Up @@ -680,8 +704,8 @@ func TestOutputXMLWithPreserveSpaceOption(t *testing.T) {
</student>
</class_list>`
doc, _ := Parse(strings.NewReader(s))
resultWithSpace := doc.OutputXMLWithOptions(WithPreserveSpace())
resultWithoutSpace := doc.OutputXMLWithOptions()
resultWithSpace := doc.OutputXMLWithOptions()
resultWithoutSpace := doc.OutputXMLWithOptions(WithoutPreserveSpace())
if !strings.Contains(resultWithSpace, "> Robert <") {
t.Errorf("output was not expected. expected %v but got %v", " Robert ", resultWithSpace)
}
Expand Down
36 changes: 26 additions & 10 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package xmlquery

import (
"bufio"
"bytes"
"encoding/xml"
"fmt"
"io"
Expand Down Expand Up @@ -39,15 +40,31 @@ func Parse(r io.Reader) (*Node, error) {
func ParseWithOptions(r io.Reader, options ParserOptions) (*Node, error) {
p := createParser(r)
options.apply(p)
for {
_, err := p.parse()
if err == io.EOF {
return p.doc, nil
var err error
for err == nil {
_, err = p.parse()
}

if err == io.EOF {
// additional check for validity
// according to: https://www.w3.org/TR/xml
// the document MUST contain at least ONE element
valid := false
for doc := p.doc; doc != nil; doc = doc.NextSibling {
for node := doc.FirstChild; node != nil; node = node.NextSibling {
if node.Type == ElementNode {
valid = true
break
}
}
}
if err != nil {
return nil, err
if !valid {
return nil, fmt.Errorf("xmlquery: invalid XML document")
}
return p.doc, nil
}

return nil, err
}

type parser struct {
Expand Down Expand Up @@ -168,7 +185,7 @@ func (p *parser) parse() (*Node, error) {

if node.NamespaceURI != "" {
if v, ok := p.space2prefix[node.NamespaceURI]; ok {
cached := string(p.reader.Cache())
cached := string(p.reader.CacheWithLimit(len(v.name) + len(node.Data) + 2))
if strings.HasPrefix(cached, fmt.Sprintf("%s:%s", v.name, node.Data)) || strings.HasPrefix(cached, fmt.Sprintf("<%s:%s", v.name, node.Data)) {
node.Prefix = v.name
}
Expand Down Expand Up @@ -228,12 +245,11 @@ func (p *parser) parse() (*Node, error) {
}
case xml.CharData:
// First, normalize the cache...
cached := strings.ToUpper(string(p.reader.Cache()))
cached := bytes.ToUpper(p.reader.CacheWithLimit(9))
nodeType := TextNode
if strings.HasPrefix(cached, "<![CDATA[") || strings.HasPrefix(cached, "![CDATA[") {
if bytes.HasPrefix(cached, []byte("<![CDATA[")) || bytes.HasPrefix(cached, []byte("![CDATA[")) {
nodeType = CharDataNode
}

node := &Node{Type: nodeType, Data: string(tok), level: p.level}
if p.level == p.prev.level {
AddSibling(p.prev, node)
Expand Down
Loading

0 comments on commit 15733e6

Please sign in to comment.